# -*- coding: utf-8 -*-
"""The tracing decorators for agent, formatter, toolkit, chat and embedding
models."""
import inspect
from functools import wraps
from typing import (
Generator,
AsyncGenerator,
Callable,
Any,
Coroutine,
TypeVar,
TYPE_CHECKING,
)
import aioitertools
from .. import _config
from ..embedding import EmbeddingModelBase, EmbeddingResponse
from .._logging import logger
from ..message import Msg, ToolUseBlock
from ..model import ChatModelBase, ChatResponse
from ._attributes import SpanAttributes, OperationNameValues
from ._extractor import (
_get_common_attributes,
_get_agent_request_attributes,
_get_agent_span_name,
_get_agent_response_attributes,
_get_llm_request_attributes,
_get_llm_span_name,
_get_llm_response_attributes,
_get_tool_request_attributes,
_get_tool_span_name,
_get_tool_response_attributes,
_get_formatter_request_attributes,
_get_formatter_span_name,
_get_formatter_response_attributes,
_get_generic_function_request_attributes,
_get_generic_function_span_name,
_get_generic_function_response_attributes,
_get_embedding_request_attributes,
_get_embedding_span_name,
_get_embedding_response_attributes,
)
from ._setup import _get_tracer
# The names below are only used in type annotations in this module. They are
# imported for static type checkers only, so none of these imports run when
# the module is loaded at runtime.
# NOTE(review): presumably this avoids circular imports with the agent /
# formatter / tool packages and a hard import of OpenTelemetry -- confirm.
if TYPE_CHECKING:
    from opentelemetry.trace import Span
    from ..agent import AgentBase
    from ..formatter import FormatterBase
    from ..tool import (
        Toolkit,
        ToolResponse,
    )
else:
    # Runtime placeholders: plain strings bound to the same names so that
    # annotations referencing them can still be evaluated.
    AgentBase = "AgentBase"
    FormatterBase = "FormatterBase"
    Span = "Span"
    Toolkit = "Toolkit"
    ToolResponse = "ToolResponse"

# Type variable for the items yielded by the traced (async) generators.
T = TypeVar("T")
def _check_tracing_enabled() -> bool:
    """Report whether tracing is currently switched on in AgentScope.

    The answer is read from the ``trace_enabled`` flag of the ``_config``
    module every time this function is called.

    TODO: We expect an OpenTelemetry official interface to check if the
     tracer is initialized. Leaving this function here as a temporary
     solution.

    Returns:
        `bool`:
            The current value of ``_config.trace_enabled``.
    """
    tracing_flag = _config.trace_enabled
    return tracing_flag
def _set_span_success_status(span: Span) -> None:
    """Mark the span as successful and end it.

    Args:
        span (`Span`):
            The OpenTelemetry span to be finalized. Its status is set to
            ``StatusCode.OK`` and then ``span.end()`` is called.
    """
    # Imported lazily so OpenTelemetry is only needed once a span is
    # actually being finalized.
    from opentelemetry.trace import StatusCode

    ok_code = StatusCode.OK
    span.set_status(ok_code)
    span.end()
def _set_span_error_status(span: Span, e: Exception) -> None:
    """Mark the span as failed, attach the exception and end the span.

    Args:
        span (`Span`):
            The OpenTelemetry span to be finalized.
        e (`Exception`):
            The exception to be recorded. Its string form is used as the
            status description.
    """
    # Imported lazily so OpenTelemetry is only needed once a span is
    # actually being finalized.
    from opentelemetry.trace import StatusCode

    error_code = StatusCode.ERROR
    span.set_status(error_code, str(e))
    span.record_exception(e)
    span.end()
def _trace_sync_generator_wrapper(
    res: Generator[T, None, None],
    span: Span,
) -> Generator[T, None, None]:
    """Trace the sync generator output with OpenTelemetry.

    Every chunk produced by ``res`` is re-yielded unchanged. If iteration
    stops without an ``Exception``, the last yielded chunk (``None`` when
    nothing was yielded) is recorded on the span as the function response
    and the span is ended with a success status. If an ``Exception`` is
    raised, the span is ended with an error status and the exception is
    re-raised.

    Args:
        res (`Generator[T, None, None]`):
            The generator to be traced.
        span (`Span`):
            The OpenTelemetry span to be used for tracing. It is ended by
            this wrapper.

    Yields:
        `T`:
            The output of the generator.

    Raises:
        `Exception`:
            Any exception raised while iterating ``res``; it is re-raised
            unchanged.
    """
    has_error = False
    last_chunk = None
    try:
        for chunk in res:
            last_chunk = chunk
            yield chunk

    except Exception as e:
        has_error = True
        _set_span_error_status(span, e)
        # A bare ``raise`` re-raises the same exception without touching
        # its ``__cause__``/``__context__``; ``raise e from None`` would
        # erase a cause chain created by the traced code.
        raise

    finally:
        # ``finally`` also runs when the consumer closes the generator
        # early, so the span is ended in that case too.
        if not has_error:
            # Set the last chunk as output
            span.set_attributes(
                _get_generic_function_response_attributes(last_chunk),
            )
            _set_span_success_status(span)
async def _trace_async_generator_wrapper(
    res: AsyncGenerator[T, None],
    span: Span,
) -> AsyncGenerator[T, None]:
    """Trace the async generator output with OpenTelemetry.

    Every chunk produced by ``res`` is re-yielded unchanged. If iteration
    stops without an ``Exception``, the last yielded chunk (``None`` when
    nothing was yielded) is converted into response attributes and the
    span is ended with a success status. The converter is selected from
    the span's ``GEN_AI_OPERATION_NAME`` attribute: chat, tool execution,
    or the generic-function converter as fallback. If an ``Exception`` is
    raised, the span is ended with an error status and the exception is
    re-raised.

    Args:
        res (`AsyncGenerator[T, None]`):
            The generator or async generator to be traced.
        span (`Span`):
            The OpenTelemetry span to be used for tracing. It is ended by
            this wrapper.

    Yields:
        `T`:
            The output of the async generator.

    Raises:
        `Exception`:
            Any exception raised while iterating ``res``; it is re-raised
            unchanged.
    """
    has_error = False
    last_chunk = None
    try:
        async for chunk in aioitertools.iter(res):
            last_chunk = chunk
            yield chunk

    except Exception as e:
        has_error = True
        _set_span_error_status(span, e)
        # A bare ``raise`` re-raises the same exception without touching
        # its ``__cause__``/``__context__``; ``raise e from None`` would
        # erase a cause chain created by the traced code.
        raise

    finally:
        # ``finally`` also runs when the consumer closes the generator
        # early, so the span is ended in that case too.
        if not has_error:
            # Look the operation name up once; spans that expose no
            # ``attributes`` fall through to the generic converter.
            operation_name = getattr(span, "attributes", {}).get(
                SpanAttributes.GEN_AI_OPERATION_NAME,
            )

            # Set the last chunk as output
            if operation_name == OperationNameValues.CHAT:
                response_attributes = _get_llm_response_attributes(last_chunk)
            elif operation_name == OperationNameValues.EXECUTE_TOOL:
                response_attributes = _get_tool_response_attributes(last_chunk)
            else:
                response_attributes = (
                    _get_generic_function_response_attributes(
                        last_chunk,
                    )
                )

            span.set_attributes(response_attributes)
            _set_span_success_status(span)
# [docs]
def trace(
    name: str | None = None,
) -> Callable:
    """A generic tracing decorator for synchronous and asynchronous functions.

    When tracing is disabled the decorated function is called directly and
    no span is created.

    Args:
        name (`str | None`, optional):
            The name of the span to be created. If not provided,
            the name of the function will be used.

    Returns:
        `Callable`:
            Returns a decorator that wraps the given function with
            OpenTelemetry tracing.
    """

    def decorator(
        func: Callable,
    ) -> Callable:
        """A decorator that wraps the given function with OpenTelemetry tracing

        Args:
            func (`Callable`):
                The function to be traced, which can be sync or async function,
                and returns an object or a generator.

        Returns:
            `Callable`:
                A wrapper function that traces the function call and handles
                input/output and exceptions.
        """

        def _start_span(args: tuple, kwargs: dict) -> Any:
            """Open the span for one call of ``func``.

            The span is created with ``end_on_exit=False`` because a
            generator result is consumed after the ``with`` block exits;
            the span is ended explicitly once the output is known.
            """
            tracer = _get_tracer()

            function_name = name if name else func.__name__
            request_attributes = _get_generic_function_request_attributes(
                function_name,
                args,
                kwargs,
            )
            span_name = _get_generic_function_span_name(request_attributes)
            return tracer.start_as_current_span(
                name=span_name,
                attributes=request_attributes,
                end_on_exit=False,
            )

        def _finish_span(res: Any, span: Span) -> Any:
            """Record ``res`` on the span, or hand over to a generator wrapper.

            Generator results are wrapped so that the span is finalized
            when the generator is consumed; any other result is recorded
            immediately and the span is ended with a success status.
            """
            # If generator or async generator
            if isinstance(res, AsyncGenerator):
                return _trace_async_generator_wrapper(res, span)
            if isinstance(res, Generator):
                return _trace_sync_generator_wrapper(res, span)

            # non-generator result
            span.set_attributes(
                _get_generic_function_response_attributes(res),
            )
            _set_span_success_status(span)
            return res

        # Async function
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def wrapper(
                *args: Any,
                **kwargs: Any,
            ) -> Any:
                """The wrapper function for tracing the async function call."""
                if not _check_tracing_enabled():
                    return await func(*args, **kwargs)

                with _start_span(args, kwargs) as span:
                    try:
                        res = await func(*args, **kwargs)
                        return _finish_span(res, span)

                    except Exception as e:
                        _set_span_error_status(span, e)
                        # Bare ``raise`` keeps the exception's own
                        # ``__cause__``/``__context__`` chain intact.
                        raise

            return wrapper

        # Sync function
        @wraps(func)
        def sync_wrapper(
            *args: Any,
            **kwargs: Any,
        ) -> Any:
            """The wrapper function for tracing the sync function call."""
            if not _check_tracing_enabled():
                return func(*args, **kwargs)

            with _start_span(args, kwargs) as span:
                try:
                    res = func(*args, **kwargs)
                    return _finish_span(res, span)

                except Exception as e:
                    _set_span_error_status(span, e)
                    # Bare ``raise`` keeps the exception's own
                    # ``__cause__``/``__context__`` chain intact.
                    raise

        return sync_wrapper

    return decorator
# [docs]
def trace_reply(
    func: Callable[..., Coroutine[Any, Any, Msg]],
) -> Callable[..., Coroutine[Any, Any, Msg]]:
    """Trace the agent reply call with OpenTelemetry.

    Tracing is skipped (and ``func`` is awaited directly) when tracing is
    disabled or when the first argument is not an ``AgentBase`` instance;
    the latter case logs a warning.

    Args:
        func (`Callable[..., Coroutine[Any, Any, Msg]]`):
            The agent async reply function to be traced.

    Returns:
        `Callable[..., Coroutine[Any, Any, Msg]]`:
            A wrapper function that traces the agent reply call and handles
            input/output and exceptions.
    """

    @wraps(func)
    async def wrapper(
        self: "AgentBase",
        *args: Any,
        **kwargs: Any,
    ) -> Msg:
        """The wrapper function for tracing the agent reply function call."""
        if not _check_tracing_enabled():
            return await func(self, *args, **kwargs)

        # Imported here because ``AgentBase`` is only a string placeholder
        # at module level at runtime.
        from ..agent import AgentBase

        if not isinstance(self, AgentBase):
            logger.warning(
                "Skipping tracing for %s as the first argument "
                "is not an instance of AgentBase, but %s",
                func.__name__,
                type(self),
            )
            return await func(self, *args, **kwargs)

        tracer = _get_tracer()

        # Prepare the attributes for the span
        request_attributes = _get_agent_request_attributes(self, args, kwargs)
        span_name = _get_agent_span_name(request_attributes)
        function_name = f"{self.__class__.__name__}.{func.__name__}"

        # Begin the agent reply span; it is ended explicitly below
        with tracer.start_as_current_span(
            name=span_name,
            attributes={
                **request_attributes,
                **_get_common_attributes(),
                SpanAttributes.AGENTSCOPE_FUNCTION_NAME: function_name,
            },
            end_on_exit=False,
        ) as span:
            try:
                # Call the agent reply function
                res = await func(self, *args, **kwargs)

                # Set the output attribute
                span.set_attributes(_get_agent_response_attributes(res))
                _set_span_success_status(span)
                return res

            except Exception as e:
                _set_span_error_status(span, e)
                # Bare ``raise`` keeps the exception's own
                # ``__cause__``/``__context__`` chain intact.
                raise

    return wrapper
# [docs]
def trace_embedding(
    func: Callable[..., Coroutine[Any, Any, EmbeddingResponse]],
) -> Callable[..., Coroutine[Any, Any, EmbeddingResponse]]:
    """Trace the embedding call with OpenTelemetry.

    Tracing is skipped (and ``func`` is awaited directly) when tracing is
    disabled or when the first argument is not an ``EmbeddingModelBase``
    instance; the latter case logs a warning.

    Args:
        func (`Callable[..., Coroutine[Any, Any, EmbeddingResponse]]`):
            The async embedding function to be traced.

    Returns:
        `Callable[..., Coroutine[Any, Any, EmbeddingResponse]]`:
            A wrapper function that traces the embedding call and handles
            input/output and exceptions.
    """

    @wraps(func)
    async def wrapper(
        self: EmbeddingModelBase,
        *args: Any,
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """The wrapper function for tracing the embedding call."""
        if not _check_tracing_enabled():
            return await func(self, *args, **kwargs)

        if not isinstance(self, EmbeddingModelBase):
            logger.warning(
                "Skipping tracing for %s as the first argument "
                "is not an instance of EmbeddingModelBase, but %s",
                func.__name__,
                type(self),
            )
            return await func(self, *args, **kwargs)

        tracer = _get_tracer()

        # Prepare the attributes for the span
        request_attributes = _get_embedding_request_attributes(
            self,
            args,
            kwargs,
        )
        span_name = _get_embedding_span_name(request_attributes)
        function_name = f"{self.__class__.__name__}.{func.__name__}"

        # Begin the embedding span; it is ended explicitly below
        with tracer.start_as_current_span(
            name=span_name,
            attributes={
                **request_attributes,
                **_get_common_attributes(),
                SpanAttributes.AGENTSCOPE_FUNCTION_NAME: function_name,
            },
            end_on_exit=False,
        ) as span:
            try:
                # Call the embedding function
                res = await func(self, *args, **kwargs)

                # Set the output attribute
                span.set_attributes(_get_embedding_response_attributes(res))
                _set_span_success_status(span)
                return res

            except Exception as e:
                _set_span_error_status(span, e)
                # Bare ``raise`` keeps the exception's own
                # ``__cause__``/``__context__`` chain intact.
                raise

    return wrapper
# [docs]
def trace_llm(
    func: Callable[
        ...,
        Coroutine[
            Any,
            Any,
            ChatResponse | AsyncGenerator[ChatResponse, None],
        ],
    ],
) -> Callable[
    ...,
    Coroutine[Any, Any, ChatResponse | AsyncGenerator[ChatResponse, None]],
]:
    """Trace the LLM call with OpenTelemetry.

    Tracing is skipped (and ``func`` is awaited directly) when tracing is
    disabled or when the first argument is not a ``ChatModelBase``
    instance; the latter case logs a warning.

    Args:
        func (`Callable`):
            The function to be traced, which should be a coroutine that
            returns either a `ChatResponse` or an `AsyncGenerator`
            of `ChatResponse`.

    Returns:
        `Callable`:
            A wrapper function that traces the LLM call and handles
            input/output and exceptions.
    """

    @wraps(func)
    async def async_wrapper(
        self: ChatModelBase,
        *args: Any,
        **kwargs: Any,
    ) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
        """The wrapper function for tracing the LLM call."""
        if not _check_tracing_enabled():
            return await func(self, *args, **kwargs)

        if not isinstance(self, ChatModelBase):
            logger.warning(
                "Skipping tracing for %s as the first argument "
                "is not an instance of ChatModelBase, but %s",
                func.__name__,
                type(self),
            )
            return await func(self, *args, **kwargs)

        tracer = _get_tracer()

        # Prepare the attributes for the span
        request_attributes = _get_llm_request_attributes(self, args, kwargs)
        span_name = _get_llm_span_name(request_attributes)
        function_name = f"{self.__class__.__name__}.__call__"

        # Begin the llm call span; it is ended explicitly below, or by the
        # generator wrapper when the response is streamed
        with tracer.start_as_current_span(
            name=span_name,
            attributes={
                **request_attributes,
                **_get_common_attributes(),
                SpanAttributes.AGENTSCOPE_FUNCTION_NAME: function_name,
            },
            end_on_exit=False,
        ) as span:
            try:
                # Must be an async calling
                res = await func(self, *args, **kwargs)

                # If the result is a AsyncGenerator
                if isinstance(res, AsyncGenerator):
                    return _trace_async_generator_wrapper(res, span)

                # non-generator result
                span.set_attributes(_get_llm_response_attributes(res))
                _set_span_success_status(span)
                return res

            except Exception as e:
                _set_span_error_status(span, e)
                # Bare ``raise`` keeps the exception's own
                # ``__cause__``/``__context__`` chain intact.
                raise

    return async_wrapper