diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/content_recording.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/content_recording.py new file mode 100644 index 0000000000..caf86b4ee3 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/content_recording.py @@ -0,0 +1,109 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thin integration layer over the shared genai content capture utilities. + +Provides clear APIs for the LangChain callback handler to decide what +content should be recorded on spans and events. +""" + +from typing import Optional + +from opentelemetry.util.genai.types import ContentCapturingMode +from opentelemetry.util.genai.utils import ( + get_content_capturing_mode, + is_experimental_mode, + should_emit_event, +) + + +class ContentPolicy: + """Determines what content should be recorded on spans and events. + + Wraps the shared genai utility functions to provide a clean API + for the callback handler. All properties are evaluated lazily so + that environment variable changes are picked up immediately. 
+ """ + + @property + def should_record_content_on_spans(self) -> bool: + """Whether message/tool content should be recorded as span attributes.""" + return self.mode in ( + ContentCapturingMode.SPAN_ONLY, + ContentCapturingMode.SPAN_AND_EVENT, + ) + + @property + def should_emit_events(self) -> bool: + """Whether content events should be emitted.""" + return should_emit_event() + + @property + def record_content(self) -> bool: + """Whether content should be recorded at all (spans or events).""" + return self.should_record_content_on_spans or self.should_emit_events + + @property + def mode(self) -> ContentCapturingMode: + """The current content capturing mode. + + Returns ``NO_CONTENT`` when not running in experimental mode. + """ + if not is_experimental_mode(): + return ContentCapturingMode.NO_CONTENT + return get_content_capturing_mode() + + +# -- Helper functions for specific content types ------------------------------ +# All opt-in content types follow the same underlying policy today. Separate +# helpers are provided so call-sites read clearly and so that per-type +# overrides can be added later without changing every caller. 
+ + +def should_record_messages(policy: ContentPolicy) -> bool: + """Whether input/output messages should be recorded on spans.""" + return policy.should_record_content_on_spans + + +def should_record_tool_content(policy: ContentPolicy) -> bool: + """Whether tool arguments and results should be recorded on spans.""" + return policy.should_record_content_on_spans + + +def should_record_retriever_content(policy: ContentPolicy) -> bool: + """Whether retriever queries and document content should be recorded.""" + return policy.should_record_content_on_spans + + +def should_record_system_instructions(policy: ContentPolicy) -> bool: + """Whether system instructions should be recorded on spans.""" + return policy.should_record_content_on_spans + + +# -- Default singleton -------------------------------------------------------- + +_default_policy: Optional[ContentPolicy] = None + + +def get_content_policy() -> ContentPolicy: + """Get the content policy based on current environment configuration. + + Returns a module-level singleton. Because the policy reads + environment variables lazily on every property access, a single + instance is sufficient. + """ + global _default_policy # noqa: PLW0603 + if _default_policy is None: + _default_policy = ContentPolicy() + return _default_policy diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/event_emitter.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/event_emitter.py new file mode 100644 index 0000000000..c4b931d08f --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/event_emitter.py @@ -0,0 +1,208 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Event emission for non-LLM GenAI operations in LangChain. + +Emits semantic-convention-aligned log-record events for tool, agent, and +retriever spans. LLM event emission is handled by the shared +``TelemetryHandler`` and is **not** duplicated here. + +All event emission is gated behind the content policy so that events are +only produced when the user opts in via the +``OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT`` / +``OTEL_INSTRUMENTATION_GENAI_EMIT_EVENT`` environment variables. +""" + +from __future__ import annotations + +from typing import Any, Optional + +from opentelemetry._logs import Logger, LoggerProvider, LogRecord, get_logger +from opentelemetry.context import get_current +from opentelemetry.instrumentation.langchain.content_recording import ( + get_content_policy, +) +from opentelemetry.instrumentation.langchain.version import __version__ +from opentelemetry.semconv.schemas import Schemas +from opentelemetry.trace import Span +from opentelemetry.trace.propagation import set_span_in_context + +_REDACTED = "[redacted]" + + +class EventEmitter: + """Emits GenAI semantic convention events for LangChain operations. + + Events are emitted as ``LogRecord`` instances linked to the active span + context, following the same pattern used by the OpenAI v2 instrumentor + and the shared ``_maybe_emit_llm_event`` helper in ``span_utils``. 
+ """ + + def __init__( + self, logger_provider: Optional[LoggerProvider] = None + ) -> None: + self._logger: Logger = get_logger( + __name__, + __version__, + logger_provider, + schema_url=Schemas.V1_37_0.value, + ) + + # ------------------------------------------------------------------ + # Tool events + # ------------------------------------------------------------------ + + def emit_tool_call_event( + self, + span: Span, + tool_name: str, + arguments: Optional[str] = None, + tool_call_id: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.tool.call`` event when a tool is invoked.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": tool_name} + if tool_call_id: + body["id"] = tool_call_id + if arguments is not None: + body["arguments"] = ( + arguments if policy.record_content else _REDACTED + ) + + self._emit(span, "gen_ai.tool.call", body) + + def emit_tool_result_event( + self, + span: Span, + tool_name: str, + result: Optional[str] = None, + tool_call_id: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.tool.result`` event when a tool returns.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": tool_name} + if tool_call_id: + body["id"] = tool_call_id + if result is not None: + body["result"] = result if policy.record_content else _REDACTED + + self._emit(span, "gen_ai.tool.result", body) + + # ------------------------------------------------------------------ + # Agent events + # ------------------------------------------------------------------ + + def emit_agent_start_event( + self, + span: Span, + agent_name: str, + input_messages: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.agent.start`` event when an agent begins.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": agent_name} + if input_messages is not None: + body["input"] = ( + input_messages if 
policy.record_content else _REDACTED + ) + + self._emit(span, "gen_ai.agent.start", body) + + def emit_agent_end_event( + self, + span: Span, + agent_name: str, + output_messages: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.agent.end`` event when an agent completes.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": agent_name} + if output_messages is not None: + body["output"] = ( + output_messages if policy.record_content else _REDACTED + ) + + self._emit(span, "gen_ai.agent.end", body) + + # ------------------------------------------------------------------ + # Retriever events + # ------------------------------------------------------------------ + + def emit_retriever_query_event( + self, + span: Span, + retriever_name: str, + query: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.retriever.query`` event for a retriever query.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": retriever_name} + if query is not None: + body["query"] = query if policy.record_content else _REDACTED + + self._emit(span, "gen_ai.retriever.query", body) + + def emit_retriever_result_event( + self, + span: Span, + retriever_name: str, + documents: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.retriever.result`` event with retrieved docs.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": retriever_name} + if documents is not None: + body["documents"] = ( + documents if policy.record_content else _REDACTED + ) + + self._emit(span, "gen_ai.retriever.result", body) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _should_emit() -> bool: + """Check whether event emission is enabled via content policy.""" + return 
get_content_policy().should_emit_events + + def _emit(self, span: Span, event_name: str, body: dict[str, Any]) -> None: + """Create a ``LogRecord`` linked to *span* and emit it.""" + context = set_span_in_context(span, get_current()) + self._logger.emit( + LogRecord( + event_name=event_name, + body=body, + context=context, + ) + ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/message_formatting.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/message_formatting.py new file mode 100644 index 0000000000..30918c51ab --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/message_formatting.py @@ -0,0 +1,541 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Message, tool, and document serialization with content-redaction support. + +Converts LangChain message objects into the compact JSON format expected by +OpenTelemetry GenAI semantic convention span attributes +(``gen_ai.input_messages``, ``gen_ai.output_messages``, +``gen_ai.system_instructions``, ``gen_ai.tool_definitions``, etc.). 
+ +Redaction behaviour +------------------- +When *record_content* is ``False``: + +* Text content → ``"[redacted]"`` +* Tool call arguments → ``"[redacted]"`` +* Tool call results → ``"[redacted]"`` +* Document page content → omitted (only metadata is kept) +* System instruction content → ``"[redacted]"`` +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, cast + +from opentelemetry.util.genai.utils import gen_ai_json_dumps + +logger = logging.getLogger(__name__) + +_REDACTED = "[redacted]" + + +def _as_dict(value: Any) -> Optional[Dict[str, Any]]: + if isinstance(value, dict): + return cast(Dict[str, Any], value) + return None + + +def _as_sequence(value: Any) -> Optional[Sequence[Any]]: + if isinstance(value, (list, tuple)): + return cast(Sequence[Any], value) + return None + + +# --------------------------------------------------------------------------- +# Role mapping +# --------------------------------------------------------------------------- + +# LangChain message type → OpenTelemetry GenAI role +_ROLE_MAP: Dict[str, str] = { + "human": "user", + "HumanMessage": "user", + "ai": "assistant", + "AIMessage": "assistant", + "AIMessageChunk": "assistant", + "system": "system", + "SystemMessage": "system", + "tool": "tool", + "ToolMessage": "tool", + "function": "tool", + "FunctionMessage": "tool", + "chat": "user", + "ChatMessage": "user", +} + + +def message_role(message: Any) -> str: + """Map a LangChain message to its GenAI role. + + Handles ``BaseMessage`` subclasses (via ``.type``), plain dicts + (via ``"role"`` or ``"type"`` keys), and falls back to ``"user"``. 
+ """ + # BaseMessage subclass + msg_type = getattr(message, "type", None) + if isinstance(msg_type, str): + mapped = _ROLE_MAP.get(msg_type) + if mapped is not None: + return mapped + + # Dict-like message + message_dict = _as_dict(message) + if message_dict is not None: + for key in ("role", "type"): + value = message_dict.get(key) + if isinstance(value, str): + mapped = _ROLE_MAP.get(value) + if mapped is not None: + return mapped + # If the value itself is already a canonical role, accept it + if value in ("user", "assistant", "system", "tool"): + return value + + # Class-name fallback + cls_name = type(message).__name__ + mapped = _ROLE_MAP.get(cls_name) + if mapped is not None: + return mapped + + return "user" + + +# --------------------------------------------------------------------------- +# Content extraction +# --------------------------------------------------------------------------- + + +def message_content(message: Any) -> Optional[str]: + """Extract text content from a LangChain message. + + Returns ``None`` when no text content is available. Multi-part content + lists are concatenated with newlines. 
+ """ + raw: Any = getattr(message, "content", None) + message_dict = _as_dict(message) + if raw is None and message_dict is not None: + raw = message_dict.get("content") + + if raw is None: + return None + + if isinstance(raw, str): + return raw if raw else None + + # Multi-part content (list of strings / dicts with "text" key) + raw_parts = _as_sequence(raw) + if raw_parts is not None: + parts: list[str] = [] + for item in raw_parts: + if isinstance(item, str): + parts.append(item) + else: + item_dict = _as_dict(item) + if item_dict is None: + continue + text_value = item_dict.get("text") + if isinstance(text_value, str) and text_value: + parts.append(text_value) + return "\n".join(parts) if parts else None + + return str(raw) if raw else None + + +# --------------------------------------------------------------------------- +# Tool-call extraction +# --------------------------------------------------------------------------- + + +def extract_tool_calls(message: Any) -> List[Dict[str, Any]]: + """Extract tool calls from an ``AIMessage`` or dict. + + Returns a (possibly empty) list of dicts, each with keys + ``"id"``, ``"name"``, and ``"arguments"``. 
+ """ + tool_calls: Any = getattr(message, "tool_calls", None) + message_dict = _as_dict(message) + if tool_calls is None and message_dict is not None: + tool_calls = message_dict.get("tool_calls") + + tool_call_items = _as_sequence(tool_calls) + if not tool_call_items: + return [] + + result: List[Dict[str, Any]] = [] + for tc in tool_call_items: + entry: Dict[str, Any] = {} + + tc_dict = _as_dict(tc) + if tc_dict is not None: + entry["id"] = tc_dict.get("id") or "" + entry["name"] = tc_dict.get("name") or "" + entry["arguments"] = tc_dict.get("args") or tc_dict.get( + "arguments" + ) + else: + entry["id"] = getattr(tc, "id", "") or "" + entry["name"] = getattr(tc, "name", "") or "" + entry["arguments"] = getattr(tc, "args", None) or getattr( + tc, "arguments", None + ) + + result.append(entry) + return result + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _format_tool_call_part( + tc: Dict[str, Any], record_content: bool +) -> Dict[str, Any]: + """Build a serialised tool-call part dict.""" + part: Dict[str, Any] = {"type": "tool_call"} + if tc.get("id"): + part["id"] = tc["id"] + if tc.get("name"): + part["name"] = tc["name"] + + args = tc.get("arguments") + if record_content: + if args is not None: + part["arguments"] = args + else: + part["arguments"] = _REDACTED + + return part + + +def _format_tool_response_part( + message: Any, record_content: bool +) -> Dict[str, Any]: + """Build a serialised tool-call-response part dict.""" + part: Dict[str, Any] = {"type": "tool_call_response"} + + tool_call_id = getattr(message, "tool_call_id", None) + message_dict = _as_dict(message) + if tool_call_id is None and message_dict is not None: + tool_call_id = message_dict.get("tool_call_id") + if tool_call_id: + part["id"] = tool_call_id + + if record_content: + content = message_content(message) + if content is not None: + 
part["result"] = content + else: + part["result"] = _REDACTED + + return part + + +def _format_text_parts( + message: Any, record_content: bool +) -> List[Dict[str, Any]]: + """Build text-content part dicts for a message.""" + content = message_content(message) + if content is None: + return [] + + return [ + { + "type": "text", + "content": content if record_content else _REDACTED, + } + ] + + +def _format_single_message( + message: Any, record_content: bool +) -> Dict[str, Any]: + """Serialise one LangChain message into the GenAI convention dict.""" + role = message_role(message) + parts: List[Dict[str, Any]] = [] + + if role == "assistant": + # Tool calls first, then text + for tc in extract_tool_calls(message): + parts.append(_format_tool_call_part(tc, record_content)) + parts.extend(_format_text_parts(message, record_content)) + + elif role == "tool": + parts.append(_format_tool_response_part(message, record_content)) + + else: + # user, system, or any other role + parts.extend(_format_text_parts(message, record_content)) + + result: Dict[str, Any] = {"role": role} + if parts: + result["parts"] = parts + return result + + +def _flatten_messages(raw_messages: Any) -> List[Any]: + """Accept messages in multiple shapes and return a flat list. + + LangChain callbacks may pass ``list[list[BaseMessage]]`` (grouped by + prompt) or a simple ``list[BaseMessage]``. 
+ """ + if not raw_messages: + return [] + + raw_sequence = _as_sequence(raw_messages) + if raw_sequence is None: + return [raw_messages] + + # Check for nested lists (list[list[BaseMessage]]) + flat: list[Any] = [] + for item in raw_sequence: + nested_items = _as_sequence(item) + if nested_items is not None: + flat.extend(nested_items) + else: + flat.append(item) + return flat + + +# --------------------------------------------------------------------------- +# Public API – prepare_messages +# --------------------------------------------------------------------------- + + +def prepare_messages( + raw_messages: Any, + *, + record_content: bool, + include_roles: Optional[Set[str]] = None, +) -> Tuple[Optional[str], Optional[str]]: + """Serialise LangChain messages to JSON strings for span attributes. + + Returns ``(formatted_json, system_instructions_json)``: + + * *formatted_json* – JSON array of non-system messages, suitable for + ``gen_ai.input_messages`` / ``gen_ai.output_messages``. + * *system_instructions_json* – JSON array of system-message *parts* + only, suitable for ``gen_ai.system_instructions``. + + Either value may be ``None`` when no messages of that kind exist. + + Parameters + ---------- + raw_messages: + Messages as received from LangChain callbacks. May be a flat list or + a nested ``list[list[BaseMessage]]``. + record_content: + When ``False``, text payloads and tool arguments/results are replaced + with ``"[redacted]"``. + include_roles: + Optional filter. When provided, only messages whose mapped role is in + the set are included. 
+ """ + messages = _flatten_messages(raw_messages) + if not messages: + return None, None + + formatted: List[Dict[str, Any]] = [] + system_parts: List[Dict[str, Any]] = [] + + for msg in messages: + role = message_role(msg) + + if include_roles is not None and role not in include_roles: + continue + + if role == "system": + # System messages contribute to system_instructions only + content = message_content(msg) + if content is not None: + system_parts.append( + { + "type": "text", + "content": content if record_content else _REDACTED, + } + ) + continue + + formatted.append(_format_single_message(msg, record_content)) + + formatted_json = gen_ai_json_dumps(formatted) if formatted else None + system_json = gen_ai_json_dumps(system_parts) if system_parts else None + + return formatted_json, system_json + + +# --------------------------------------------------------------------------- +# Document formatting (for retrievers) +# --------------------------------------------------------------------------- + + +def format_documents( + documents: Optional[Sequence[Any]], *, record_content: bool +) -> Optional[str]: + """Format retrieved documents as a JSON string for span attributes. + + Each document is serialised as a dict with optional ``page_content`` + (when *record_content* is ``True``) and ``metadata`` fields. + + Returns ``None`` when *documents* is empty or ``None``. 
+ """ + if not documents: + return None + + result: List[Dict[str, Any]] = [] + for doc in documents: + entry: Dict[str, Any] = {} + doc_dict = _as_dict(doc) + + # page_content + page_content = getattr(doc, "page_content", None) + if page_content is None and doc_dict is not None: + page_content = doc_dict.get("page_content") + + if record_content and page_content is not None: + entry["page_content"] = str(page_content) + + # metadata + metadata = getattr(doc, "metadata", None) + if metadata is None and doc_dict is not None: + metadata = doc_dict.get("metadata") + if metadata: + entry["metadata"] = metadata + + if entry: + result.append(entry) + + return gen_ai_json_dumps(result) if result else None + + +# --------------------------------------------------------------------------- +# Tool result serialization +# --------------------------------------------------------------------------- + + +def serialize_tool_result(output: Any, record_content: bool) -> str: + """Serialise a tool result for span attributes. + + When *record_content* is ``False`` the literal ``"[redacted]"`` is + returned. 
+ """ + if not record_content: + return _REDACTED + + if isinstance(output, str): + return output + + # Try common attribute shapes produced by LangChain tools + content = getattr(output, "content", None) + if content is not None: + return str(content) + + output_dict = _as_dict(output) + if output_dict is not None: + content = output_dict.get("content") or output_dict.get("output") + if content is not None: + return str(content) + + # Fallback: JSON-encode arbitrary values + try: + return gen_ai_json_dumps(output) + except (TypeError, ValueError): + return str(output) + + +# --------------------------------------------------------------------------- +# Tool definitions formatting +# --------------------------------------------------------------------------- + + +def format_tool_definitions(definitions: Optional[Any]) -> Optional[str]: + """Format tool definitions for ``gen_ai.tool_definitions`` span attribute. + + Accepts a list of LangChain tool objects, dicts, or any mix thereof and + returns a compact JSON string. Returns ``None`` when *definitions* is + empty or ``None``. 
+ """ + if not definitions: + return None + + definition_items = _as_sequence(definitions) + if definition_items is None: + definition_items = [definitions] + + result: List[Dict[str, Any]] = [] + for defn in definition_items: + entry: Dict[str, Any] = {} + defn_dict = _as_dict(defn) + + if defn_dict is not None: + # Already a dict – keep recognised keys + if "name" in defn_dict: + entry["name"] = defn_dict["name"] + if "description" in defn_dict: + entry["description"] = defn_dict["description"] + if "parameters" in defn_dict: + entry["parameters"] = defn_dict["parameters"] + + func_dict = _as_dict(defn_dict.get("function")) + if func_dict is not None: + func_name = func_dict.get("name") + if "name" not in entry and func_name is not None: + entry["name"] = func_name + func_description = func_dict.get("description") + if "description" not in entry and func_description is not None: + entry["description"] = func_description + func_parameters = func_dict.get("parameters") + if "parameters" not in entry and func_parameters is not None: + entry["parameters"] = func_parameters + + entry.setdefault("type", defn_dict.get("type", "function")) + else: + # Object with attributes (e.g. 
a LangChain BaseTool) + name = getattr(defn, "name", None) + if name is not None: + entry["name"] = str(name) + + description = getattr(defn, "description", None) + if description is not None: + entry["description"] = str(description) + + args_schema = getattr(defn, "args_schema", None) + if args_schema is not None: + schema_method = getattr(args_schema, "schema", None) + if callable(schema_method): + try: + entry["parameters"] = schema_method() + except Exception: # noqa: BLE001 + pass + + entry.setdefault("type", "function") + + if entry: + result.append(entry) + + return gen_ai_json_dumps(result) if result else None + + +# --------------------------------------------------------------------------- +# JSON helper +# --------------------------------------------------------------------------- + + +def as_json_attribute(value: Any) -> str: + """Return a JSON string suitable for OpenTelemetry string attributes. + + Uses the same compact encoder (no whitespace, base64 for bytes) as + the rest of the GenAI instrumentation. + """ + return gen_ai_json_dumps(value) diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/operation_mapping.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/operation_mapping.py new file mode 100644 index 0000000000..6402a5a93e --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/operation_mapping.py @@ -0,0 +1,256 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Callback-to-semconv operation mapping for LangChain callbacks. + +Maps each LangChain callback to the correct GenAI semantic convention +operation name. Direct callbacks (``on_chat_model_start``, +``on_llm_start``, ``on_tool_start``, ``on_retriever_start``) have a +fixed 1-to-1 mapping. ``on_chain_start`` requires heuristic +classification because LangChain emits this callback for agents, +workflows, and internal plumbing alike. +""" + +from __future__ import annotations + +from typing import Any, Optional +from uuid import UUID + +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) + +__all__ = [ + "OperationName", + "classify_chain_run", + "resolve_agent_name", + "should_ignore_chain", +] + +# --------------------------------------------------------------------------- +# Operation name constants (sourced from the GenAI semconv enum where +# available, with string fallbacks for values not yet in the enum). +# --------------------------------------------------------------------------- + + +class OperationName: + """Canonical GenAI semantic convention operation names.""" + + CHAT: str = GenAI.GenAiOperationNameValues.CHAT.value + TEXT_COMPLETION: str = GenAI.GenAiOperationNameValues.TEXT_COMPLETION.value + INVOKE_AGENT: str = GenAI.GenAiOperationNameValues.INVOKE_AGENT.value + EXECUTE_TOOL: str = GenAI.GenAiOperationNameValues.EXECUTE_TOOL.value + # invoke_workflow is not yet in the semconv enum; use the expected + # string value so the mapping is forward-compatible. 
+ INVOKE_WORKFLOW: str = "invoke_workflow" + + +# --------------------------------------------------------------------------- +# LangGraph markers – names and prefixes produced by LangGraph that must +# be recognized when classifying ``on_chain_start`` callbacks. +# --------------------------------------------------------------------------- + +LANGGRAPH_NODE_KEY = "langgraph_node" +LANGGRAPH_START_NODE = "__start__" +MIDDLEWARE_PREFIX = "Middleware." +LANGGRAPH_IDENTIFIER = "LangGraph" + +# Metadata keys used by callers to override classification. +_META_AGENT_SPAN = "otel_agent_span" +_META_WORKFLOW_SPAN = "otel_workflow_span" +_META_AGENT_NAME = "agent_name" +_META_AGENT_TYPE = "agent_type" +_META_OTEL_TRACE = "otel_trace" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def resolve_agent_name( + serialized: dict[str, Any], + metadata: Optional[dict[str, Any]], + kwargs: dict[str, Any], +) -> Optional[str]: + """Derive the best-effort agent name from callback arguments. + + Checks (in priority order): + 1. ``metadata["agent_name"]`` + 2. ``kwargs["name"]`` + 3. ``serialized["name"]`` + 4. 
``metadata["langgraph_node"]`` (if present and not a start node) + """ + if metadata: + name = metadata.get(_META_AGENT_NAME) + if name: + return str(name) + + name = kwargs.get("name") + if name: + return str(name) + + name = serialized.get("name") + if name: + return str(name) + + if metadata: + node = metadata.get(LANGGRAPH_NODE_KEY) + if node and node != LANGGRAPH_START_NODE: + return str(node) + + return None + + +def _has_agent_signals(metadata: Optional[dict[str, Any]]) -> bool: + """Return True when metadata contains any signal that the chain is an agent.""" + if not metadata: + return False + return bool( + metadata.get(_META_AGENT_SPAN) + or metadata.get(_META_AGENT_NAME) + or metadata.get(_META_AGENT_TYPE) + ) + + +def _is_langgraph_agent_node( + serialized: dict[str, Any], + metadata: Optional[dict[str, Any]], + kwargs: dict[str, Any], +) -> bool: + """Detect a LangGraph agent node that is not a start/middleware node.""" + if not metadata: + return False + + node = metadata.get(LANGGRAPH_NODE_KEY) + if not node: + return False + + # Exclude start and middleware nodes. + if node == LANGGRAPH_START_NODE: + return False + + name = resolve_agent_name(serialized, metadata, kwargs) + if name and name.startswith(MIDDLEWARE_PREFIX): + return False + + return True + + +def _looks_like_workflow( + serialized: dict[str, Any], + metadata: Optional[dict[str, Any]], + parent_run_id: Optional[UUID], +) -> bool: + """Return True if the chain looks like a top-level workflow/graph.""" + if parent_run_id is not None: + return False + + # An explicit workflow override is authoritative. + if metadata and metadata.get(_META_WORKFLOW_SPAN): + return True + + # Heuristic: check for LangGraph identifier in the serialized repr. 
+ name = serialized.get("name", "") + graph_id = ( + serialized.get("graph", {}).get("id", "") + if isinstance(serialized.get("graph"), dict) + else "" + ) + return LANGGRAPH_IDENTIFIER in name or LANGGRAPH_IDENTIFIER in graph_id + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def should_ignore_chain( + metadata: Optional[dict[str, Any]], + agent_name: Optional[str], + parent_run_id: Optional[UUID], + kwargs: dict[str, Any], +) -> bool: + """Return True if the chain callback should be silently suppressed. + + Suppression happens when: + * The node is the LangGraph ``__start__`` node. + * The name carries the ``Middleware.`` prefix. + * ``metadata["otel_trace"]`` is explicitly ``False``. + * ``metadata["otel_agent_span"]`` is explicitly ``False`` and no other + agent signals are present. + """ + if metadata: + node = metadata.get(LANGGRAPH_NODE_KEY) + if node == LANGGRAPH_START_NODE: + return True + + if metadata.get(_META_OTEL_TRACE) is False: + return True + + if ( + metadata.get(_META_AGENT_SPAN) is False + and not metadata.get(_META_AGENT_NAME) + and not metadata.get(_META_AGENT_TYPE) + ): + return True + + if agent_name and agent_name.startswith(MIDDLEWARE_PREFIX): + return True + + name_from_kwargs = kwargs.get("name", "") + if isinstance(name_from_kwargs, str) and name_from_kwargs.startswith( + MIDDLEWARE_PREFIX + ): + return True + + return False + + +def classify_chain_run( + serialized: dict[str, Any], + metadata: Optional[dict[str, Any]], + kwargs: dict[str, Any], + parent_run_id: Optional[UUID] = None, +) -> Optional[str]: + """Classify a ``on_chain_start`` callback into a semconv operation. + + Returns one of the :class:`OperationName` constants, or ``None`` when + the chain should be suppressed (no span emitted). + + Classification order: + 1. Check for explicit suppression signals. + 2. 
Check for agent signals → ``invoke_agent``. + 3. Check for workflow signals → ``invoke_workflow``. + 4. Default: ``None`` (suppress – unclassified chains are not emitted). + """ + agent_name = resolve_agent_name(serialized, metadata, kwargs) + + # 1. Suppress known noise. + if should_ignore_chain(metadata, agent_name, parent_run_id, kwargs): + return None + + # 2. Agent detection. + if _has_agent_signals(metadata): + return OperationName.INVOKE_AGENT + + if _is_langgraph_agent_node(serialized, metadata, kwargs): + return OperationName.INVOKE_AGENT + + # 3. Workflow / orchestration detection. + if _looks_like_workflow(serialized, metadata, parent_run_id): + return OperationName.INVOKE_WORKFLOW + + # 4. Default: suppress unclassified chains. + return None diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/semconv_attributes.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/semconv_attributes.py new file mode 100644 index 0000000000..a5e5249eeb --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/semconv_attributes.py @@ -0,0 +1,313 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Per-operation attribute matrix based on OTel GenAI semantic conventions. 
+ +Single source of truth for which attributes apply to which operations +in the LangChain instrumentor. Attribute requirement levels follow: +https://opentelemetry.io/docs/specs/semconv/gen-ai/ +""" + +from __future__ import annotations + +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.semconv._incubating.attributes import ( + server_attributes as Server, +) +from opentelemetry.semconv.attributes import ( + error_attributes as Error, +) +from opentelemetry.trace import SpanKind + +# --------------------------------------------------------------------------- +# Operation name constants +# --------------------------------------------------------------------------- + +OP_CHAT = GenAI.GenAiOperationNameValues.CHAT.value # "chat" +OP_TEXT_COMPLETION = ( + GenAI.GenAiOperationNameValues.TEXT_COMPLETION.value +) # "text_completion" +OP_INVOKE_AGENT = ( + GenAI.GenAiOperationNameValues.INVOKE_AGENT.value +) # "invoke_agent" +OP_EXECUTE_TOOL = ( + GenAI.GenAiOperationNameValues.EXECUTE_TOOL.value +) # "execute_tool" + +# These operations are not yet in the semconv enum; define as literals. +OP_INVOKE_WORKFLOW = "invoke_workflow" +OP_RETRIEVAL = "retrieval" + +# --------------------------------------------------------------------------- +# Attribute key aliases (not yet in the released semconv package) +# --------------------------------------------------------------------------- + +GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS = "gen_ai.usage.cache_read.input_tokens" +GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS = ( + "gen_ai.usage.cache_creation.input_tokens" +) +GEN_AI_AGENT_VERSION = "gen_ai.agent.version" +GEN_AI_WORKFLOW_NAME = "gen_ai.workflow.name" + +# --------------------------------------------------------------------------- +# Attribute sets per operation, grouped by requirement level +# +# Requirement levels (per OpenTelemetry specification): +# REQUIRED – MUST be provided. 
+# CONDITIONALLY_REQ – MUST be provided when the stated condition is met. +# RECOMMENDED – SHOULD be provided. +# OPT_IN – MAY be provided; typically gated by a config flag. +# --------------------------------------------------------------------------- + +# ---- chat / text_completion (inference client spans) ---------------------- + +INFERENCE_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_OPERATION_NAME, + GenAI.GEN_AI_PROVIDER_NAME, + } +) + +INFERENCE_CONDITIONALLY_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_REQUEST_MODEL, # if available + Error.ERROR_TYPE, # if response is an error + Server.SERVER_PORT, # if server.address is set + GenAI.GEN_AI_REQUEST_SEED, # if present in request + GenAI.GEN_AI_REQUEST_CHOICE_COUNT, # if != 1 + GenAI.GEN_AI_OUTPUT_TYPE, # if applicable + GenAI.GEN_AI_CONVERSATION_ID, # if available + } +) + +INFERENCE_RECOMMENDED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_REQUEST_MAX_TOKENS, + GenAI.GEN_AI_REQUEST_TEMPERATURE, + GenAI.GEN_AI_REQUEST_TOP_P, + GenAI.GEN_AI_REQUEST_TOP_K, + GenAI.GEN_AI_REQUEST_STOP_SEQUENCES, + GenAI.GEN_AI_REQUEST_FREQUENCY_PENALTY, + GenAI.GEN_AI_REQUEST_PRESENCE_PENALTY, + GenAI.GEN_AI_RESPONSE_ID, + GenAI.GEN_AI_RESPONSE_MODEL, + GenAI.GEN_AI_RESPONSE_FINISH_REASONS, + GenAI.GEN_AI_USAGE_INPUT_TOKENS, + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, + GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS, + GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS, + Server.SERVER_ADDRESS, + } +) + +INFERENCE_OPT_IN: frozenset[str] = frozenset( + { + GenAI.GEN_AI_SYSTEM_INSTRUCTIONS, + GenAI.GEN_AI_INPUT_MESSAGES, + GenAI.GEN_AI_OUTPUT_MESSAGES, + GenAI.GEN_AI_TOOL_DEFINITIONS, + } +) + +# ---- invoke_agent -------------------------------------------------------- + +AGENT_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_OPERATION_NAME, + GenAI.GEN_AI_PROVIDER_NAME, + } +) + +AGENT_CONDITIONALLY_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_AGENT_ID, + GenAI.GEN_AI_AGENT_NAME, + 
GenAI.GEN_AI_AGENT_DESCRIPTION, + GEN_AI_AGENT_VERSION, + GenAI.GEN_AI_REQUEST_MODEL, + GenAI.GEN_AI_DATA_SOURCE_ID, + Error.ERROR_TYPE, # if response is an error + GenAI.GEN_AI_CONVERSATION_ID, + } +) + +AGENT_RECOMMENDED: frozenset[str] = frozenset( + { + Server.SERVER_ADDRESS, + # All inference request/response attributes are also recommended + GenAI.GEN_AI_REQUEST_MAX_TOKENS, + GenAI.GEN_AI_REQUEST_TEMPERATURE, + GenAI.GEN_AI_REQUEST_TOP_P, + GenAI.GEN_AI_REQUEST_TOP_K, + GenAI.GEN_AI_REQUEST_STOP_SEQUENCES, + GenAI.GEN_AI_REQUEST_FREQUENCY_PENALTY, + GenAI.GEN_AI_REQUEST_PRESENCE_PENALTY, + GenAI.GEN_AI_RESPONSE_ID, + GenAI.GEN_AI_RESPONSE_MODEL, + GenAI.GEN_AI_RESPONSE_FINISH_REASONS, + GenAI.GEN_AI_USAGE_INPUT_TOKENS, + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, + GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS, + GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS, + } +) + +AGENT_OPT_IN: frozenset[str] = frozenset( + { + GenAI.GEN_AI_SYSTEM_INSTRUCTIONS, + GenAI.GEN_AI_INPUT_MESSAGES, + GenAI.GEN_AI_OUTPUT_MESSAGES, + } +) + +# ---- execute_tool -------------------------------------------------------- + +TOOL_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_OPERATION_NAME, + } +) + +TOOL_CONDITIONALLY_REQUIRED: frozenset[str] = frozenset( + { + Error.ERROR_TYPE, # if response is an error + } +) + +TOOL_RECOMMENDED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_TOOL_NAME, + GenAI.GEN_AI_TOOL_CALL_ID, + GenAI.GEN_AI_TOOL_DESCRIPTION, + GenAI.GEN_AI_TOOL_TYPE, + } +) + +TOOL_OPT_IN: frozenset[str] = frozenset( + { + GenAI.GEN_AI_TOOL_CALL_ARGUMENTS, + GenAI.GEN_AI_TOOL_CALL_RESULT, + } +) + +# ---- invoke_workflow ----------------------------------------------------- + +WORKFLOW_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_OPERATION_NAME, + } +) + +WORKFLOW_CONDITIONALLY_REQUIRED: frozenset[str] = frozenset( + { + Error.ERROR_TYPE, # if response is an error + GEN_AI_WORKFLOW_NAME, # if available + } +) + +WORKFLOW_RECOMMENDED: frozenset[str] = frozenset() + 
+WORKFLOW_OPT_IN: frozenset[str] = frozenset( + { + GenAI.GEN_AI_INPUT_MESSAGES, + GenAI.GEN_AI_OUTPUT_MESSAGES, + } +) + +# --------------------------------------------------------------------------- +# Aggregate lookup: operation → (required, conditionally_required, +# recommended, opt_in) +# --------------------------------------------------------------------------- + +OPERATION_ATTRIBUTES: dict[ + str, + tuple[ + frozenset[str], + frozenset[str], + frozenset[str], + frozenset[str], + ], +] = { + OP_CHAT: ( + INFERENCE_REQUIRED, + INFERENCE_CONDITIONALLY_REQUIRED, + INFERENCE_RECOMMENDED, + INFERENCE_OPT_IN, + ), + OP_TEXT_COMPLETION: ( + INFERENCE_REQUIRED, + INFERENCE_CONDITIONALLY_REQUIRED, + INFERENCE_RECOMMENDED, + INFERENCE_OPT_IN, + ), + OP_INVOKE_AGENT: ( + AGENT_REQUIRED, + AGENT_CONDITIONALLY_REQUIRED, + AGENT_RECOMMENDED, + AGENT_OPT_IN, + ), + OP_EXECUTE_TOOL: ( + TOOL_REQUIRED, + TOOL_CONDITIONALLY_REQUIRED, + TOOL_RECOMMENDED, + TOOL_OPT_IN, + ), + OP_INVOKE_WORKFLOW: ( + WORKFLOW_REQUIRED, + WORKFLOW_CONDITIONALLY_REQUIRED, + WORKFLOW_RECOMMENDED, + WORKFLOW_OPT_IN, + ), +} + +# --------------------------------------------------------------------------- +# SpanKind helper +# --------------------------------------------------------------------------- + +_CLIENT_OPERATIONS: frozenset[str] = frozenset( + {OP_CHAT, OP_TEXT_COMPLETION, OP_INVOKE_AGENT} +) + + +def get_operation_span_kind(operation: str) -> SpanKind: + """Return the correct SpanKind for the given operation. + + * ``chat``, ``text_completion``, ``invoke_agent`` → ``SpanKind.CLIENT`` + * ``execute_tool``, ``invoke_workflow``, and others → ``SpanKind.INTERNAL`` + """ + if operation in _CLIENT_OPERATIONS: + return SpanKind.CLIENT + return SpanKind.INTERNAL + + +# --------------------------------------------------------------------------- +# Metric applicability +# +# Maps metric instrument names to the set of operations they apply to. 
+# --------------------------------------------------------------------------- + +METRIC_OPERATION_DURATION = "gen_ai.client.operation.duration" +METRIC_TOKEN_USAGE = "gen_ai.client.token.usage" +METRIC_TIME_TO_FIRST_CHUNK = "gen_ai.client.operation.time_to_first_chunk" +METRIC_TIME_PER_OUTPUT_CHUNK = "gen_ai.client.operation.time_per_output_chunk" + +METRIC_APPLICABLE_OPERATIONS: dict[str, frozenset[str]] = { + METRIC_OPERATION_DURATION: frozenset({OP_CHAT, OP_TEXT_COMPLETION}), + METRIC_TOKEN_USAGE: frozenset({OP_CHAT, OP_TEXT_COMPLETION}), + # Streaming-only metrics (chat / text_completion) + METRIC_TIME_TO_FIRST_CHUNK: frozenset({OP_CHAT, OP_TEXT_COMPLETION}), + METRIC_TIME_PER_OUTPUT_CHUNK: frozenset({OP_CHAT, OP_TEXT_COMPLETION}), +} diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/span_manager.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/span_manager.py index 7ce588b618..a9c6758aa6 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/span_manager.py +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/span_manager.py @@ -12,106 +12,417 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Span lifecycle manager for the LangChain instrumentor. + +Manages creation, parent-context resolution, ignored-run walk-through, +per-thread agent stacks, and clean teardown for all GenAI operation types. 
+""" + +from __future__ import annotations + +import threading +import time from dataclasses import dataclass, field -from typing import Dict, List, Optional -from uuid import UUID +from typing import Any, Dict, List, Optional +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + OP_CHAT, + OP_INVOKE_AGENT, + OP_TEXT_COMPLETION, + get_operation_span_kind, +) from opentelemetry.semconv._incubating.attributes import ( gen_ai_attributes as GenAI, ) -from opentelemetry.semconv.attributes import ( - error_attributes, -) +from opentelemetry.semconv.attributes import error_attributes from opentelemetry.trace import Span, SpanKind, Tracer, set_span_in_context from opentelemetry.trace.status import Status, StatusCode __all__ = ["_SpanManager"] +# Operations that produce model-level duration metrics. +_MODEL_OPERATIONS: frozenset[str] = frozenset({OP_CHAT, OP_TEXT_COMPLETION}) + + +def _empty_attributes() -> Dict[str, Any]: + return {} + @dataclass -class _SpanState: +class SpanRecord: + """Rich record stored for every active span.""" + + run_id: str span: Span - children: List[UUID] = field(default_factory=lambda: list()) + operation: str + parent_run_id: Optional[str] = None + attributes: Dict[str, Any] = field(default_factory=_empty_attributes) + # Mutable scratch space for streaming timing, thread keys, etc. + stash: Dict[str, Any] = field(default_factory=_empty_attributes) class _SpanManager: - def __init__( + """Thread-safe span lifecycle manager for every GenAI operation type.""" + + def __init__(self, tracer: Tracer) -> None: + self._tracer = tracer + self._lock = threading.Lock() + + # run_id (str) → SpanRecord + self._spans: Dict[str, SpanRecord] = {} + + # Runs we decided to skip (e.g. internal LangChain plumbing) but + # whose children should still be linked to the correct parent. + self._ignored_runs: set[str] = set() + # Maps an ignored run_id to the parent_run_id it was called with, + # so children can walk through to the real ancestor. 
+ self._run_parent_override: Dict[str, Optional[str]] = {} + + # Per-thread stacks of invoke_agent run_ids for hierarchy tracking + # in concurrent execution. key = thread_key (str). + self._agent_stack_by_thread: Dict[str, List[str]] = {} + + # Per-thread stacks for LangGraph Command(goto=...) transitions. + # key = thread_key (str), value = stack of parent_run_ids. + self._goto_parent_stack: Dict[str, List[str]] = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def start_span( + self, + run_id: str | object, + name: str, + operation: str, + kind: Optional[SpanKind] = None, + parent_run_id: Optional[str | object] = None, + attributes: Optional[Dict[str, Any]] = None, + thread_key: Optional[str] = None, + ) -> SpanRecord: + """Create and register a new span. + + Parameters + ---------- + run_id: + Unique identifier for this run (UUID or str). + name: + Human-readable span name (e.g. ``"chat gpt-4o"``). + operation: + One of the ``OP_*`` constants from ``semconv_attributes``. + kind: + Override the SpanKind. When *None* the kind is derived from + *operation* via ``get_operation_span_kind``. + parent_run_id: + The run_id of the logical parent (may be an ignored run). + attributes: + Initial span attributes to set immediately. + thread_key: + Identifies the execution thread; used for agent stack tracking. + """ + rid = str(run_id) + prid = str(parent_run_id) if parent_run_id is not None else None + + if kind is None: + kind = get_operation_span_kind(operation) + + # Walk through ignored runs so children attach to the correct + # visible ancestor. + resolved_prid = self._resolve_parent_id(prid) + + # Build parent context. 
+ ctx = None + with self._lock: + if resolved_prid is not None: + parent_record = self._spans.get(resolved_prid) + if parent_record is not None: + ctx = set_span_in_context(parent_record.span) + + span = self._tracer.start_span(name=name, kind=kind, context=ctx) + + attrs = attributes or {} + for attr_key, attr_val in attrs.items(): + span.set_attribute(attr_key, attr_val) + + stash: Dict[str, Any] = {} + if operation in _MODEL_OPERATIONS: + stash["started_at"] = time.perf_counter() + if thread_key is not None: + stash["thread_key"] = thread_key + + record = SpanRecord( + run_id=rid, + span=span, + operation=operation, + parent_run_id=prid, + attributes=attrs, + stash=stash, + ) + + with self._lock: + self._spans[rid] = record + + # Maintain per-thread agent stack. + if operation == OP_INVOKE_AGENT and thread_key is not None: + self._agent_stack_by_thread.setdefault(thread_key, []).append( + rid + ) + + return record + + def end_span( self, - tracer: Tracer, + run_id: str | object, + status: Optional[StatusCode] = None, + error: Optional[BaseException] = None, ) -> None: - self._tracer = tracer + """Finalise and end the span identified by *run_id*. + + Parameters + ---------- + run_id: + The run whose span should be ended. + status: + Explicit status code. When *error* is provided this defaults to + ``StatusCode.ERROR``. + error: + If supplied the span is marked as failed with ``error.type`` + recorded as an attribute. + """ + rid = str(run_id) + + with self._lock: + record = self._spans.pop(rid, None) + if record is None: + return + + span = record.span - # Map from run_id -> _SpanState, to keep track of spans and parent/child relationships - # TODO: Use weak references or a TTL cache to avoid memory leaks in long-running processes. 
See #3735 - self.spans: Dict[UUID, _SpanState] = {} + if error is not None: + span.set_attribute( + error_attributes.ERROR_TYPE, type(error).__qualname__ + ) + span.set_status(Status(StatusCode.ERROR, str(error))) + elif status is not None: + span.set_status(Status(status)) + + # Pop from agent stack if applicable. + thread_key = record.stash.get("thread_key") + if record.operation == OP_INVOKE_AGENT and thread_key is not None: + with self._lock: + stack = self._agent_stack_by_thread.get(thread_key) + if stack: + try: + stack.remove(rid) + except ValueError: + pass + if not stack: + del self._agent_stack_by_thread[thread_key] + + span.end() + + def get_record(self, run_id: str | object) -> Optional[SpanRecord]: + """Return the ``SpanRecord`` for *run_id*, or ``None``.""" + rid = str(run_id) + with self._lock: + return self._spans.get(rid) + + # ------------------------------------------------------------------ + # Ignored-run management + # ------------------------------------------------------------------ + + def ignore_run( + self, + run_id: str | object, + parent_run_id: Optional[str | object] = None, + ) -> None: + """Mark *run_id* as ignored. + + Any future child whose ``parent_run_id`` points at an ignored run + will be re-parented to the ignored run's own parent via + ``resolve_parent_id``. 
+ """ + rid = str(run_id) + prid = str(parent_run_id) if parent_run_id is not None else None + with self._lock: + self._ignored_runs.add(rid) + self._run_parent_override[rid] = prid - def _create_span( + def is_ignored(self, run_id: str | object) -> bool: + rid = str(run_id) + with self._lock: + return rid in self._ignored_runs + + def clear_ignored_run(self, run_id: str | object) -> None: + """Remove ignored-run bookkeeping for *run_id*.""" + rid = str(run_id) + with self._lock: + self._ignored_runs.discard(rid) + self._run_parent_override.pop(rid, None) + + def resolve_parent_id( + self, parent_run_id: Optional[str | object] + ) -> Optional[str]: + """Public wrapper around the internal resolver.""" + prid = str(parent_run_id) if parent_run_id is not None else None + return self._resolve_parent_id(prid) + + # ------------------------------------------------------------------ + # Token usage accumulation + # ------------------------------------------------------------------ + + def _accumulate_on_record( self, - run_id: UUID, - parent_run_id: Optional[UUID], - span_name: str, - kind: SpanKind = SpanKind.INTERNAL, - ) -> Span: - if parent_run_id is not None and parent_run_id in self.spans: - parent_state = self.spans[parent_run_id] - parent_span = parent_state.span - ctx = set_span_in_context(parent_span) - span = self._tracer.start_span( - name=span_name, kind=kind, context=ctx + record: SpanRecord, + input_tokens: Optional[int], + output_tokens: Optional[int], + ) -> None: + """Add token counts to *record*. 
Caller **must** hold ``self._lock``.""" + if input_tokens is not None: + existing = record.attributes.get( + GenAI.GEN_AI_USAGE_INPUT_TOKENS, 0 ) - parent_state.children.append(run_id) - else: - # top-level or missing parent - span = self._tracer.start_span(name=span_name, kind=kind) - set_span_in_context(span) + new_val = (existing or 0) + input_tokens + record.span.set_attribute(GenAI.GEN_AI_USAGE_INPUT_TOKENS, new_val) + record.attributes[GenAI.GEN_AI_USAGE_INPUT_TOKENS] = new_val + if output_tokens is not None: + existing = record.attributes.get( + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, 0 + ) + new_val = (existing or 0) + output_tokens + record.span.set_attribute( + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, new_val + ) + record.attributes[GenAI.GEN_AI_USAGE_OUTPUT_TOKENS] = new_val - span_state = _SpanState(span=span) - self.spans[run_id] = span_state + def accumulate_usage_to_parent( + self, + record: SpanRecord, + input_tokens: Optional[int], + output_tokens: Optional[int], + ) -> None: + """Propagate token usage from a model span to its parent agent span.""" + if input_tokens is None and output_tokens is None: + return - return span + parent_key = record.parent_run_id + visited: set[str] = set() + with self._lock: + while parent_key: + if parent_key in visited: + break + visited.add(parent_key) + parent_record = self._spans.get(parent_key) + if not parent_record: + break + if parent_record.operation == OP_INVOKE_AGENT: + self._accumulate_on_record( + parent_record, input_tokens, output_tokens + ) + break + parent_key = parent_record.parent_run_id - def create_chat_span( + def accumulate_llm_usage_to_agent( self, - run_id: UUID, - parent_run_id: Optional[UUID], - request_model: str, - ) -> Span: - span = self._create_span( - run_id=run_id, - parent_run_id=parent_run_id, - span_name=f"{GenAI.GenAiOperationNameValues.CHAT.value} {request_model}", - kind=SpanKind.CLIENT, - ) - span.set_attribute( - GenAI.GEN_AI_OPERATION_NAME, - GenAI.GenAiOperationNameValues.CHAT.value, - ) 
- if request_model: - span.set_attribute(GenAI.GEN_AI_REQUEST_MODEL, request_model) - - return span - - def end_span(self, run_id: UUID) -> None: - state = self.spans[run_id] - for child_id in state.children: - child_state = self.spans.get(child_id) - if child_state: - child_state.span.end() - del self.spans[child_id] - state.span.end() - del self.spans[run_id] - - def get_span(self, run_id: UUID) -> Optional[Span]: - state = self.spans.get(run_id) - return state.span if state else None - - def handle_error(self, error: BaseException, run_id: UUID): - span = self.get_span(run_id) - if span is None: - # If the span does not exist, we cannot set the error status + parent_run_id: Optional[str | object], + input_tokens: Optional[int], + output_tokens: Optional[int], + ) -> None: + """Propagate LLM token usage up to the nearest agent span. + + Unlike ``accumulate_usage_to_parent`` (which starts from a + ``SpanRecord``'s parent), this resolves through ignored runs first + and then walks up to find the nearest ``invoke_agent`` ancestor. + Designed to be called from ``on_llm_end`` where the LLM span is + managed by :class:`TelemetryHandler`, not :class:`_SpanManager`. 
+ """ + if input_tokens is None and output_tokens is None: return - span.set_status(Status(StatusCode.ERROR, str(error))) - span.set_attribute( - error_attributes.ERROR_TYPE, type(error).__qualname__ - ) - self.end_span(run_id) + + prid = str(parent_run_id) if parent_run_id is not None else None + resolved = self._resolve_parent_id(prid) + if resolved is None: + return + + visited: set[str] = set() + current = resolved + with self._lock: + while current: + if current in visited: + break + visited.add(current) + record = self._spans.get(current) + if not record: + break + if record.operation == OP_INVOKE_AGENT: + self._accumulate_on_record( + record, input_tokens, output_tokens + ) + return + current = record.parent_run_id + + def nearest_agent_parent(self, record: SpanRecord) -> Optional[str]: + """Walk up the parent chain to find the nearest invoke_agent ancestor. + + Returns the run_id of the nearest agent span, or *None*. + """ + parent_key = record.parent_run_id + visited: set[str] = set() + with self._lock: + while parent_key: + if parent_key in visited: + break + visited.add(parent_key) + parent_record = self._spans.get(parent_key) + if not parent_record: + break + if parent_record.operation == OP_INVOKE_AGENT: + return parent_key + parent_key = parent_record.parent_run_id + return None + + # ------------------------------------------------------------------ + # LangGraph goto support + # ------------------------------------------------------------------ + + def push_goto_parent(self, thread_key: str, parent_run_id: str) -> None: + """Push a goto parent onto the per-thread stack.""" + with self._lock: + self._goto_parent_stack.setdefault(thread_key, []).append( + parent_run_id + ) + + def pop_goto_parent(self, thread_key: str) -> Optional[str]: + """Pop and return the most recent goto parent, or *None*.""" + with self._lock: + stack = self._goto_parent_stack.get(thread_key) + if stack: + val = stack.pop() + if not stack: + del 
self._goto_parent_stack[thread_key] + return val + return None + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _resolve_parent_id( + self, parent_run_id: Optional[str] + ) -> Optional[str]: + """Walk through ignored runs to find the nearest visible ancestor.""" + if parent_run_id is None: + return None + + visited: set[str] = set() + current = parent_run_id + with self._lock: + while current in self._ignored_runs: + if current in visited: + # Cycle guard. + return None + visited.add(current) + current = self._run_parent_override.get(current) + if current is None: + return None + return current diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/utils.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/utils.py new file mode 100644 index 0000000000..23345b38cd --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/utils.py @@ -0,0 +1,354 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

import logging
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
from urllib.parse import urlparse

from opentelemetry.context import attach, detach
from opentelemetry.propagate import extract

# Provider name constants aligned with OpenTelemetry semantic conventions
_PROVIDER_AZURE_OPENAI = "azure.ai.openai"
_PROVIDER_OPENAI = "openai"
_PROVIDER_AWS_BEDROCK = "aws.bedrock"
_PROVIDER_GCP_GEN_AI = "gcp.gen_ai"
_PROVIDER_ANTHROPIC = "anthropic"
_PROVIDER_COHERE = "cohere"
_PROVIDER_OLLAMA = "ollama"

# Mapping from LangChain ls_provider values to normalized provider names
_LS_PROVIDER_MAP: Dict[str, str] = {
    "azure": _PROVIDER_AZURE_OPENAI,
    "azure_openai": _PROVIDER_AZURE_OPENAI,
    "azure-openai": _PROVIDER_AZURE_OPENAI,
    "openai": _PROVIDER_OPENAI,
    "github": _PROVIDER_AZURE_OPENAI,
    "google": _PROVIDER_GCP_GEN_AI,
    "google_genai": _PROVIDER_GCP_GEN_AI,
    "anthropic": _PROVIDER_ANTHROPIC,
    "cohere": _PROVIDER_COHERE,
    "ollama": _PROVIDER_OLLAMA,
}

# Substrings in base_url mapped to provider names (checked in order;
# "azure" must precede "openai" so Azure OpenAI endpoints win).
_URL_PROVIDER_RULES: List[tuple[str, str]] = [
    ("azure", _PROVIDER_AZURE_OPENAI),
    ("openai", _PROVIDER_OPENAI),
    ("ollama", _PROVIDER_OLLAMA),
    ("bedrock", _PROVIDER_AWS_BEDROCK),
    ("amazonaws.com", _PROVIDER_AWS_BEDROCK),
    ("anthropic", _PROVIDER_ANTHROPIC),
    ("googleapis", _PROVIDER_GCP_GEN_AI),
]

# Substrings in serialized class identifiers mapped to provider names
_CLASS_PROVIDER_RULES: List[tuple[str, str]] = [
    ("ChatOpenAI", _PROVIDER_OPENAI),
    ("ChatBedrock", _PROVIDER_AWS_BEDROCK),
    ("Bedrock", _PROVIDER_AWS_BEDROCK),
    ("ChatAnthropic", _PROVIDER_ANTHROPIC),
    ("ChatGoogleGenerativeAI", _PROVIDER_GCP_GEN_AI),
    ("ChatVertexAI", _PROVIDER_GCP_GEN_AI),
    ("Ollama", _PROVIDER_OLLAMA),
]


def _as_dict(value: Any) -> Optional[Dict[str, Any]]:
    """Return *value* when it is a dict, else ``None`` (type-narrowing helper)."""
    if isinstance(value, dict):
        return cast(Dict[str, Any], value)
    return None


def _get_class_identifier(serialized: Dict[str, Any]) -> Optional[str]:
    """Extract a class identifier string from serialized data.

    Checks ``serialized["id"]`` (a list of path components) first,
    then falls back to ``serialized["name"]``.
    """
    id_parts = serialized.get("id")
    if isinstance(id_parts, list) and id_parts:
        return str(cast(List[Any], id_parts)[-1])
    name = serialized.get("name")
    if name:
        return str(name)
    return None


def _infer_from_ls_provider(
    metadata: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Infer provider from LangChain's ls_provider metadata hint."""
    if metadata is None:
        return None
    ls_provider = metadata.get("ls_provider")
    if ls_provider is None:
        return None

    ls_lower = str(ls_provider).lower()

    # Direct map lookup
    mapped = _LS_PROVIDER_MAP.get(ls_lower)
    if mapped is not None:
        return mapped

    # Substring check for bedrock variants (e.g. "amazon_bedrock")
    if "bedrock" in ls_lower:
        return _PROVIDER_AWS_BEDROCK

    return None


def _infer_from_url(
    invocation_params: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Infer provider from a base URL in invocation params."""
    if invocation_params is None:
        return None
    base_url = invocation_params.get("base_url") or invocation_params.get(
        "openai_api_base"
    )
    if not base_url:
        return None

    url_lower = str(base_url).lower()
    for substring, provider in _URL_PROVIDER_RULES:
        if substring in url_lower:
            return provider
    return None


def _infer_from_class(
    serialized: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Infer provider from the serialized class name or id."""
    if serialized is None:
        return None
    class_id = _get_class_identifier(serialized)
    if class_id is None:
        return None

    for substring, provider in _CLASS_PROVIDER_RULES:
        if substring in class_id:
            return provider
    return None


def _infer_from_kwargs(
    serialized: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Infer provider from serialized kwargs (endpoint fields)."""
    if serialized is None:
        return None
    ser_kwargs = _as_dict(serialized.get("kwargs"))
    if ser_kwargs is None:
        return None

    if ser_kwargs.get("azure_endpoint"):
        return _PROVIDER_AZURE_OPENAI

    openai_api_base = ser_kwargs.get("openai_api_base")
    if isinstance(openai_api_base, str):
        # BUGFIX: a real Azure base URL normally carries a scheme and often
        # a trailing slash or path (e.g. "https://res.openai.azure.com/"),
        # so a plain endswith() on the whole string missed it. Also check
        # the parsed hostname; the original check is kept for scheme-less
        # values where urlparse() yields no hostname.
        host = urlparse(openai_api_base).hostname
        if openai_api_base.endswith(".azure.com") or (
            host is not None and host.endswith(".azure.com")
        ):
            return _PROVIDER_AZURE_OPENAI

    return None


def infer_provider_name(
    serialized: Optional[Dict[str, Any]],
    metadata: Optional[Dict[str, Any]],
    invocation_params: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Infer the GenAI provider name from available LangChain callback data.

    Sources are checked in decreasing order of specificity:
    1. ``metadata["ls_provider"]`` — LangChain's own provider hint
    2. ``invocation_params["base_url"]`` — URL-based inference
    3. ``serialized["id"]`` / ``serialized["name"]`` — class name based
    4. ``serialized["kwargs"]`` — endpoint-based

    Returns ``None`` if the provider cannot be determined.
    """
    return (
        _infer_from_ls_provider(metadata)
        or _infer_from_url(invocation_params)
        or _infer_from_class(serialized)
        or _infer_from_kwargs(serialized)
    )


def _extract_url(
    serialized: Optional[Dict[str, Any]],
    invocation_params: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Find the first available URL from invocation params or serialized kwargs."""
    if invocation_params:
        url = invocation_params.get("base_url") or invocation_params.get(
            "openai_api_base"
        )
        if url:
            return str(url)

    if serialized:
        ser_kwargs = _as_dict(serialized.get("kwargs"))
        if ser_kwargs is not None:
            url = ser_kwargs.get("openai_api_base") or ser_kwargs.get(
                "azure_endpoint"
            )
            if url:
                return str(url)

    return None


def infer_server_address(
    serialized: Optional[Dict[str, Any]],
    invocation_params: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Extract the server hostname from available URL sources.

    Checks ``invocation_params["base_url"]``,
    ``invocation_params["openai_api_base"]``,
    ``serialized["kwargs"]["openai_api_base"]``, and
    ``serialized["kwargs"]["azure_endpoint"]``.
    """
    url = _extract_url(serialized, invocation_params)
    if url is None:
        return None

    parsed = urlparse(url)
    return parsed.hostname or None


def infer_server_port(
    serialized: Optional[Dict[str, Any]],
    invocation_params: Optional[Dict[str, Any]],
) -> Optional[int]:
    """Extract the server port from available URL sources.

    Only returns a value when the port is explicitly specified in the URL
    (not inferred default ports).
    """
    url = _extract_url(serialized, invocation_params)
    if url is None:
        return None

    parsed = urlparse(url)
    return parsed.port  # None when port is not explicitly set


_logger = logging.getLogger(__name__)

# Header keys recognised by the W3C Trace Context specification.
_TRACE_HEADER_KEYS = ("traceparent", "tracestate")

# Common nested attribute names where HTTP / trace headers may reside.
_NESTED_HEADER_KEYS = (
    "headers",
    "header",
    "http_headers",
    "request_headers",
    "metadata",
    "request",
)


def extract_trace_headers(container: Any) -> Optional[Dict[str, str]]:
    """Extract W3C trace context headers (traceparent/tracestate).

    The top level of *container* is searched first, then each
    well-known nested location in order; the first spot that yields at
    least one non-empty string header wins.
    """

    def _collect(candidate: Any) -> Dict[str, str]:
        # Gather non-empty string values for the recognised header keys.
        if not isinstance(candidate, dict):
            return {}
        return {
            key: candidate[key]
            for key in _TRACE_HEADER_KEYS
            if isinstance(candidate.get(key), str) and candidate[key]
        }

    if not isinstance(container, dict):
        return None

    # 1. Top-level keys take precedence over anything nested.
    headers = _collect(container)
    if headers:
        return headers

    # 2. Fall back to the first nested container that has any header.
    for nested_key in _NESTED_HEADER_KEYS:
        headers = _collect(container.get(nested_key))
        if headers:
            return headers

    return None


@contextmanager
def propagated_context(
    headers: Optional[Mapping[str, str]],
) -> Iterator[None]:
    """Temporarily adopt an upstream trace context from W3C headers.

    Deserializes *headers* with OpenTelemetry's ``extract()`` and keeps
    the resulting context attached while the with-block runs. Failures
    to attach or detach are logged at debug level and never propagate.
    """
    if not headers:
        yield
        return

    token = None
    try:
        token = attach(extract(headers))
    except Exception:  # noqa: BLE001
        _logger.debug(
            "Failed to extract/attach propagation context", exc_info=True
        )

    try:
        yield
    finally:
        if token is not None:
            try:
                detach(token)
            except Exception:  # noqa: BLE001
                _logger.debug(
                    "Failed to detach propagation context", exc_info=True
                )


def extract_propagation_context(
    metadata: Optional[Dict[str, Any]],
    inputs: Any,
    kwargs: Dict[str, Any],
) -> Optional[Dict[str, str]]:
    """Return W3C trace headers found in metadata, inputs, or kwargs.

    The three sources are consulted in that order; the first one that
    contains any trace headers wins.
    """
    for candidate in (metadata, inputs, kwargs):
        if candidate is None:
            continue
        headers = extract_trace_headers(candidate)
        if headers:
            return headers
    return None
+ +import pytest + +from opentelemetry.instrumentation._semconv import ( + _OpenTelemetrySemanticConventionStability, +) +from opentelemetry.instrumentation.langchain.content_recording import ( + ContentPolicy, + should_record_messages, + should_record_retriever_content, + should_record_system_instructions, + should_record_tool_content, +) +from opentelemetry.util.genai.types import ContentCapturingMode + + +@pytest.fixture(autouse=True) +def _reset_semconv_stability(monkeypatch): + """Reset semconv stability cache so each test can set its own env vars.""" + orig_initialized = _OpenTelemetrySemanticConventionStability._initialized + orig_mapping = _OpenTelemetrySemanticConventionStability._OTEL_SEMCONV_STABILITY_SIGNAL_MAPPING.copy() + + _OpenTelemetrySemanticConventionStability._initialized = False + _OpenTelemetrySemanticConventionStability._OTEL_SEMCONV_STABILITY_SIGNAL_MAPPING = {} + + monkeypatch.delenv("OTEL_SEMCONV_STABILITY_OPT_IN", raising=False) + monkeypatch.delenv( + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", raising=False + ) + monkeypatch.delenv("OTEL_INSTRUMENTATION_GENAI_EMIT_EVENT", raising=False) + + yield + + _OpenTelemetrySemanticConventionStability._initialized = orig_initialized + _OpenTelemetrySemanticConventionStability._OTEL_SEMCONV_STABILITY_SIGNAL_MAPPING = orig_mapping + + +def _enter_experimental(monkeypatch, capture_mode): + """Set env vars for experimental mode and re-initialize stability.""" + monkeypatch.setenv( + "OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental" + ) + monkeypatch.setenv( + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", capture_mode + ) + _OpenTelemetrySemanticConventionStability._initialize() + + +# --------------------------------------------------------------------------- +# ContentPolicy – experimental mode with each ContentCapturingMode +# --------------------------------------------------------------------------- + + +class TestContentPolicySpanOnly: + """SPAN_ONLY: content on 
spans, no events.""" + + def test_should_record_content_on_spans(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + assert ContentPolicy().should_record_content_on_spans is True + + def test_should_emit_events(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + assert ContentPolicy().should_emit_events is False + + def test_record_content(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + assert ContentPolicy().record_content is True + + def test_mode(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + assert ContentPolicy().mode == ContentCapturingMode.SPAN_ONLY + + +class TestContentPolicyEventOnly: + """EVENT_ONLY: events enabled without duplicating content on spans.""" + + def test_should_record_content_on_spans(self, monkeypatch): + _enter_experimental(monkeypatch, "EVENT_ONLY") + assert ContentPolicy().should_record_content_on_spans is False + + def test_should_emit_events(self, monkeypatch): + _enter_experimental(monkeypatch, "EVENT_ONLY") + assert ContentPolicy().should_emit_events is True + + def test_record_content(self, monkeypatch): + _enter_experimental(monkeypatch, "EVENT_ONLY") + assert ContentPolicy().record_content is True + + def test_mode(self, monkeypatch): + _enter_experimental(monkeypatch, "EVENT_ONLY") + assert ContentPolicy().mode == ContentCapturingMode.EVENT_ONLY + + +class TestContentPolicySpanAndEvent: + """SPAN_AND_EVENT: both spans and events active.""" + + def test_should_record_content_on_spans(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_AND_EVENT") + assert ContentPolicy().should_record_content_on_spans is True + + def test_should_emit_events(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_AND_EVENT") + assert ContentPolicy().should_emit_events is True + + def test_record_content(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_AND_EVENT") + assert ContentPolicy().record_content is True + + def test_mode(self, 
monkeypatch): + _enter_experimental(monkeypatch, "SPAN_AND_EVENT") + assert ContentPolicy().mode == ContentCapturingMode.SPAN_AND_EVENT + + +class TestContentPolicyNoContent: + """NO_CONTENT: nothing recorded.""" + + def test_should_record_content_on_spans(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + assert ContentPolicy().should_record_content_on_spans is False + + def test_should_emit_events(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + assert ContentPolicy().should_emit_events is False + + def test_record_content(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + assert ContentPolicy().record_content is False + + def test_mode(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + assert ContentPolicy().mode == ContentCapturingMode.NO_CONTENT + + +class TestContentPolicyRecordContentCombined: + """record_content is True when either spans or events are enabled.""" + + @pytest.mark.parametrize( + "capture_mode, expected", + [ + ("SPAN_ONLY", True), + ("EVENT_ONLY", True), + ("SPAN_AND_EVENT", True), + ("NO_CONTENT", False), + ], + ) + def test_record_content(self, monkeypatch, capture_mode, expected): + _enter_experimental(monkeypatch, capture_mode) + assert ContentPolicy().record_content is expected + + +# --------------------------------------------------------------------------- +# ContentPolicy – outside experimental mode +# --------------------------------------------------------------------------- + + +class TestContentPolicyNonExperimental: + """Without experimental opt-in everything is disabled.""" + + def test_should_record_content_on_spans(self): + _OpenTelemetrySemanticConventionStability._initialize() + assert ContentPolicy().should_record_content_on_spans is False + + def test_should_emit_events(self): + _OpenTelemetrySemanticConventionStability._initialize() + assert ContentPolicy().should_emit_events is False + + def test_record_content(self): + 
_OpenTelemetrySemanticConventionStability._initialize() + assert ContentPolicy().record_content is False + + def test_mode_is_no_content(self): + _OpenTelemetrySemanticConventionStability._initialize() + assert ContentPolicy().mode == ContentCapturingMode.NO_CONTENT + + +# --------------------------------------------------------------------------- +# Helper functions – delegates to policy.should_record_content_on_spans +# --------------------------------------------------------------------------- + + +class _StubPolicy: + """Minimal stand-in for ContentPolicy with a fixed boolean.""" + + def __init__(self, value: bool): + self.should_record_content_on_spans = value + + +class TestShouldRecordMessages: + def test_true_when_policy_enabled(self): + assert should_record_messages(_StubPolicy(True)) is True + + def test_false_when_policy_disabled(self): + assert should_record_messages(_StubPolicy(False)) is False + + +class TestShouldRecordToolContent: + def test_true_when_policy_enabled(self): + assert should_record_tool_content(_StubPolicy(True)) is True + + def test_false_when_policy_disabled(self): + assert should_record_tool_content(_StubPolicy(False)) is False + + +class TestShouldRecordRetrieverContent: + def test_true_when_policy_enabled(self): + assert should_record_retriever_content(_StubPolicy(True)) is True + + def test_false_when_policy_disabled(self): + assert should_record_retriever_content(_StubPolicy(False)) is False + + +class TestShouldRecordSystemInstructions: + def test_true_when_policy_enabled(self): + assert should_record_system_instructions(_StubPolicy(True)) is True + + def test_false_when_policy_disabled(self): + assert should_record_system_instructions(_StubPolicy(False)) is False + + +# --------------------------------------------------------------------------- +# Helper functions – integration with real ContentPolicy via env vars +# --------------------------------------------------------------------------- + + +class 
TestHelperFunctionsIntegration: + """Verify helpers produce correct results with a real ContentPolicy.""" + + def test_all_helpers_true_with_span_only(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + policy = ContentPolicy() + assert should_record_messages(policy) is True + assert should_record_tool_content(policy) is True + assert should_record_retriever_content(policy) is True + assert should_record_system_instructions(policy) is True + + def test_all_helpers_false_with_no_content(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + policy = ContentPolicy() + assert should_record_messages(policy) is False + assert should_record_tool_content(policy) is False + assert should_record_retriever_content(policy) is False + assert should_record_system_instructions(policy) is False + + def test_all_helpers_false_outside_experimental(self): + _OpenTelemetrySemanticConventionStability._initialize() + policy = ContentPolicy() + assert should_record_messages(policy) is False + assert should_record_tool_content(policy) is False + assert should_record_retriever_content(policy) is False + assert should_record_system_instructions(policy) is False diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_event_emitter.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_event_emitter.py new file mode 100644 index 0000000000..301cafe272 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_event_emitter.py @@ -0,0 +1,115 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from opentelemetry.instrumentation.langchain.event_emitter import EventEmitter + + +def _make_policy(*, should_emit_events: bool, record_content: bool): + policy = mock.MagicMock() + policy.should_emit_events = should_emit_events + policy.record_content = record_content + return policy + + +def _make_emitter(): + emitter = EventEmitter() + emitter._logger = mock.MagicMock() + return emitter + + +def test_emits_tool_call_event_with_content(monkeypatch): + policy = _make_policy(should_emit_events=True, record_content=True) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.event_emitter.get_content_policy", + lambda: policy, + ) + emitter = _make_emitter() + + emitter.emit_tool_call_event( + mock.MagicMock(), + "calculator", + '{"x": 1}', + "call_123", + ) + + record = emitter._logger.emit.call_args.args[0] + assert record.event_name == "gen_ai.tool.call" + assert record.body == { + "name": "calculator", + "id": "call_123", + "arguments": '{"x": 1}', + } + + +def test_redacts_tool_result_when_content_recording_disabled(monkeypatch): + policy = _make_policy(should_emit_events=True, record_content=False) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.event_emitter.get_content_policy", + lambda: policy, + ) + emitter = _make_emitter() + + emitter.emit_tool_result_event( + mock.MagicMock(), + "calculator", + '{"result": 2}', + ) + + record = emitter._logger.emit.call_args.args[0] + assert record.event_name == "gen_ai.tool.result" + assert record.body == { + "name": "calculator", + "result": 
"[redacted]", + } + + +def test_skips_agent_event_when_disabled(monkeypatch): + policy = _make_policy(should_emit_events=False, record_content=True) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.event_emitter.get_content_policy", + lambda: policy, + ) + emitter = _make_emitter() + + emitter.emit_agent_start_event( + mock.MagicMock(), + "planner", + '[{"content": "hi"}]', + ) + + emitter._logger.emit.assert_not_called() + + +def test_emits_retriever_result_event(monkeypatch): + policy = _make_policy(should_emit_events=True, record_content=True) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.event_emitter.get_content_policy", + lambda: policy, + ) + emitter = _make_emitter() + + emitter.emit_retriever_result_event( + mock.MagicMock(), + "vector_store", + '[{"metadata": {"source": "a.txt"}}]', + ) + + record = emitter._logger.emit.call_args.args[0] + assert record.event_name == "gen_ai.retriever.result" + assert record.body == { + "name": "vector_store", + "documents": '[{"metadata": {"source": "a.txt"}}]', + } diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_operation_mapping.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_operation_mapping.py new file mode 100644 index 0000000000..63ee5b6fac --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_operation_mapping.py @@ -0,0 +1,336 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import uuid4 + +from opentelemetry.instrumentation.langchain.operation_mapping import ( + OperationName, + classify_chain_run, + resolve_agent_name, + should_ignore_chain, +) + +# --------------------------------------------------------------------------- +# classify_chain_run +# --------------------------------------------------------------------------- + + +class TestClassifyChainRunAgentDetection: + """Agent signals → invoke_agent.""" + + def test_otel_agent_span_true(self): + result = classify_chain_run( + serialized={}, + metadata={"otel_agent_span": True}, + kwargs={}, + ) + assert result == OperationName.INVOKE_AGENT + + def test_metadata_agent_name(self): + result = classify_chain_run( + serialized={}, + metadata={"agent_name": "my-agent"}, + kwargs={}, + ) + assert result == OperationName.INVOKE_AGENT + + def test_metadata_agent_type(self): + result = classify_chain_run( + serialized={}, + metadata={"agent_type": "react"}, + kwargs={}, + ) + assert result == OperationName.INVOKE_AGENT + + def test_langgraph_agent_node(self): + result = classify_chain_run( + serialized={}, + metadata={"langgraph_node": "researcher"}, + kwargs={}, + ) + assert result == OperationName.INVOKE_AGENT + + def test_agent_signals_override_workflow(self): + """Agent signals take priority over workflow heuristics.""" + result = classify_chain_run( + serialized={"name": "LangGraph"}, + metadata={"otel_agent_span": True}, + kwargs={}, + parent_run_id=None, + ) + assert result == OperationName.INVOKE_AGENT + + +class TestClassifyChainRunWorkflowDetection: + """Workflow signals → invoke_workflow.""" + + def test_top_level_langgraph_by_name(self): + result = classify_chain_run( + serialized={"name": "LangGraph"}, + metadata={}, + kwargs={}, + parent_run_id=None, + ) + assert result == OperationName.INVOKE_WORKFLOW + + def test_top_level_langgraph_by_graph_id(self): + 
result = classify_chain_run( + serialized={"name": "other", "graph": {"id": "LangGraph-abc"}}, + metadata={}, + kwargs={}, + parent_run_id=None, + ) + assert result == OperationName.INVOKE_WORKFLOW + + def test_otel_workflow_span_true(self): + result = classify_chain_run( + serialized={}, + metadata={"otel_workflow_span": True}, + kwargs={}, + parent_run_id=None, + ) + assert result == OperationName.INVOKE_WORKFLOW + + def test_not_workflow_when_has_parent(self): + """LangGraph name alone is not enough when there is a parent run.""" + result = classify_chain_run( + serialized={"name": "LangGraph"}, + metadata={}, + kwargs={}, + parent_run_id=uuid4(), + ) + assert result is None + + +class TestClassifyChainRunSuppression: + """Chains that should be suppressed (return None).""" + + def test_start_node_suppressed(self): + result = classify_chain_run( + serialized={}, + metadata={"langgraph_node": "__start__"}, + kwargs={}, + ) + assert result is None + + def test_middleware_prefix_suppressed(self): + result = classify_chain_run( + serialized={"name": "Middleware.auth"}, + metadata={"langgraph_node": "Middleware.auth"}, + kwargs={"name": "Middleware.auth"}, + ) + assert result is None + + def test_otel_trace_false(self): + result = classify_chain_run( + serialized={}, + metadata={"otel_trace": False}, + kwargs={}, + ) + assert result is None + + def test_otel_agent_span_false_no_other_signals(self): + result = classify_chain_run( + serialized={}, + metadata={"otel_agent_span": False}, + kwargs={}, + ) + assert result is None + + def test_unclassified_generic_chain(self): + result = classify_chain_run( + serialized={"name": "RunnableSequence"}, + metadata={}, + kwargs={}, + parent_run_id=uuid4(), + ) + assert result is None + + +# --------------------------------------------------------------------------- +# should_ignore_chain +# --------------------------------------------------------------------------- + + +class TestShouldIgnoreChain: + """Suppression logic for 
known noise chains.""" + + def test_ignores_start_node(self): + assert should_ignore_chain( + metadata={"langgraph_node": "__start__"}, + agent_name=None, + parent_run_id=None, + kwargs={}, + ) + + def test_ignores_middleware_agent_name(self): + assert should_ignore_chain( + metadata={}, + agent_name="Middleware.something", + parent_run_id=None, + kwargs={}, + ) + + def test_ignores_middleware_in_kwargs_name(self): + assert should_ignore_chain( + metadata={}, + agent_name=None, + parent_run_id=None, + kwargs={"name": "Middleware.guard"}, + ) + + def test_ignores_otel_trace_false(self): + assert should_ignore_chain( + metadata={"otel_trace": False}, + agent_name=None, + parent_run_id=None, + kwargs={}, + ) + + def test_ignores_otel_agent_span_false_no_signals(self): + assert should_ignore_chain( + metadata={"otel_agent_span": False}, + agent_name=None, + parent_run_id=None, + kwargs={}, + ) + + def test_does_not_ignore_otel_agent_span_false_with_agent_name(self): + """otel_agent_span=False is overridden when agent_name is present.""" + assert not should_ignore_chain( + metadata={"otel_agent_span": False, "agent_name": "planner"}, + agent_name="planner", + parent_run_id=None, + kwargs={}, + ) + + def test_does_not_ignore_normal_agent_node(self): + assert not should_ignore_chain( + metadata={"langgraph_node": "researcher"}, + agent_name="researcher", + parent_run_id=uuid4(), + kwargs={}, + ) + + +# --------------------------------------------------------------------------- +# resolve_agent_name +# --------------------------------------------------------------------------- + + +class TestResolveAgentName: + """Best-effort agent name resolution from callback arguments.""" + + def test_from_metadata_agent_name(self): + assert ( + resolve_agent_name( + serialized={}, + metadata={"agent_name": "planner"}, + kwargs={}, + ) + == "planner" + ) + + def test_from_kwargs_name(self): + assert ( + resolve_agent_name( + serialized={}, + metadata={}, + kwargs={"name": 
"tool-caller"}, + ) + == "tool-caller" + ) + + def test_from_serialized_name(self): + assert ( + resolve_agent_name( + serialized={"name": "MyAgent"}, + metadata={}, + kwargs={}, + ) + == "MyAgent" + ) + + def test_from_langgraph_node(self): + assert ( + resolve_agent_name( + serialized={}, + metadata={"langgraph_node": "researcher"}, + kwargs={}, + ) + == "researcher" + ) + + def test_langgraph_start_node_excluded(self): + assert ( + resolve_agent_name( + serialized={}, + metadata={"langgraph_node": "__start__"}, + kwargs={}, + ) + is None + ) + + def test_returns_none_when_nothing_available(self): + assert ( + resolve_agent_name(serialized={}, metadata={}, kwargs={}) is None + ) + + def test_returns_none_with_none_metadata(self): + assert ( + resolve_agent_name(serialized={}, metadata=None, kwargs={}) is None + ) + + def test_priority_metadata_over_kwargs(self): + """metadata agent_name has higher priority than kwargs name.""" + assert ( + resolve_agent_name( + serialized={"name": "serialized"}, + metadata={"agent_name": "meta"}, + kwargs={"name": "kw"}, + ) + == "meta" + ) + + def test_priority_kwargs_over_serialized(self): + assert ( + resolve_agent_name( + serialized={"name": "serialized"}, + metadata={}, + kwargs={"name": "kw"}, + ) + == "kw" + ) + + +# --------------------------------------------------------------------------- +# OperationName constants +# --------------------------------------------------------------------------- + + +class TestOperationNameConstants: + def test_chat(self): + assert OperationName.CHAT == "chat" + + def test_text_completion(self): + assert OperationName.TEXT_COMPLETION == "text_completion" + + def test_invoke_agent(self): + assert OperationName.INVOKE_AGENT == "invoke_agent" + + def test_execute_tool(self): + assert OperationName.EXECUTE_TOOL == "execute_tool" + + def test_invoke_workflow(self): + assert OperationName.INVOKE_WORKFLOW == "invoke_workflow" diff --git 
a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_span_hierarchy.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_span_hierarchy.py new file mode 100644 index 0000000000..484e047617 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_span_hierarchy.py @@ -0,0 +1,299 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for span hierarchy and parent-child resolution in _SpanManager.""" + +from unittest import mock + +import pytest + +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + OP_CHAT, + OP_INVOKE_AGENT, +) +from opentelemetry.instrumentation.langchain.span_manager import _SpanManager +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.trace.status import StatusCode + + +def _make_mock_span(): + span = mock.MagicMock() + span.is_recording.return_value = True + return span + + +def _make_tracer(): + tracer = mock.MagicMock() + tracer.start_span.side_effect = lambda **kwargs: _make_mock_span() + return tracer + + +@pytest.fixture() +def tracer(): + return _make_tracer() + + +@pytest.fixture() +def mgr(tracer): + return _SpanManager(tracer) + + +# ------------------------------------------------------------------ +# resolve_parent_id +# ------------------------------------------------------------------ + + +class TestResolveParentId: + def 
test_returns_parent_when_parent_exists_in_spans(self, mgr): + """Parent run_id is returned directly when it is not ignored.""" + mgr.start_span( + run_id="parent-1", + name="agent", + operation=OP_INVOKE_AGENT, + ) + assert mgr.resolve_parent_id("parent-1") == "parent-1" + + def test_walks_through_ignored_runs(self, mgr): + """Children of ignored runs are re-parented to the visible ancestor.""" + mgr.start_span( + run_id="grandparent", + name="agent", + operation=OP_INVOKE_AGENT, + ) + # Middle run is ignored; its parent is the grandparent. + mgr.ignore_run("ignored-middle", parent_run_id="grandparent") + + resolved = mgr.resolve_parent_id("ignored-middle") + assert resolved == "grandparent" + + def test_cycle_in_ignored_chain_returns_none(self, mgr): + """A cycle among ignored runs must not loop forever.""" + mgr.ignore_run("a", parent_run_id="b") + mgr.ignore_run("b", parent_run_id="a") + + assert mgr.resolve_parent_id("a") is None + + def test_returns_none_when_parent_not_found(self, mgr): + assert mgr.resolve_parent_id(None) is None + # Unknown non-ignored id is returned as-is (it is "visible"). 
+ assert mgr.resolve_parent_id("nonexistent") == "nonexistent" + + +# ------------------------------------------------------------------ +# Agent stacks +# ------------------------------------------------------------------ + + +class TestAgentStacks: + def test_agent_stacks_track_per_thread(self, mgr): + """Each thread_key gets its own independent agent stack.""" + mgr.start_span( + run_id="agent-t1", + name="agent", + operation=OP_INVOKE_AGENT, + thread_key="thread-1", + ) + mgr.start_span( + run_id="agent-t2", + name="agent", + operation=OP_INVOKE_AGENT, + thread_key="thread-2", + ) + + assert mgr._agent_stack_by_thread["thread-1"] == ["agent-t1"] + assert mgr._agent_stack_by_thread["thread-2"] == ["agent-t2"] + + def test_start_span_adds_to_agent_stack(self, mgr): + """invoke_agent spans are pushed onto the per-thread agent stack.""" + mgr.start_span( + run_id="agent-a", + name="outer", + operation=OP_INVOKE_AGENT, + thread_key="t", + ) + mgr.start_span( + run_id="agent-b", + name="inner", + operation=OP_INVOKE_AGENT, + thread_key="t", + ) + + assert mgr._agent_stack_by_thread["t"] == ["agent-a", "agent-b"] + + def test_end_span_removes_from_agent_stack(self, mgr): + """Ending an invoke_agent span pops it from the agent stack.""" + mgr.start_span( + run_id="agent-x", + name="agent", + operation=OP_INVOKE_AGENT, + thread_key="tk", + ) + assert "tk" in mgr._agent_stack_by_thread + + mgr.end_span("agent-x") + + # Stack should be cleaned up entirely when empty. 
+ assert "tk" not in mgr._agent_stack_by_thread + + +# ------------------------------------------------------------------ +# Goto routing +# ------------------------------------------------------------------ + + +class TestGotoRouting: + def test_push_pop_lifo(self, mgr): + """push/pop follows LIFO order.""" + mgr.push_goto_parent("t1", "parent-a") + mgr.push_goto_parent("t1", "parent-b") + + assert mgr.pop_goto_parent("t1") == "parent-b" + assert mgr.pop_goto_parent("t1") == "parent-a" + + def test_pop_empty_returns_none(self, mgr): + assert mgr.pop_goto_parent("nonexistent-thread") is None + + def test_cleanup_of_empty_stacks(self, mgr): + """Once the last goto parent is popped the thread key is removed.""" + mgr.push_goto_parent("t1", "p1") + mgr.pop_goto_parent("t1") + + assert "t1" not in mgr._goto_parent_stack + + +# ------------------------------------------------------------------ +# accumulate_usage_to_parent +# ------------------------------------------------------------------ + + +class TestAccumulateUsageToParent: + def test_accumulates_on_nearest_agent_parent(self, mgr): + """Token counts propagate to the nearest invoke_agent ancestor.""" + agent_rec = mgr.start_span( + run_id="agent", + name="agent", + operation=OP_INVOKE_AGENT, + ) + chat_rec = mgr.start_span( + run_id="chat", + name="chat gpt-4o", + operation=OP_CHAT, + parent_run_id="agent", + ) + + mgr.accumulate_usage_to_parent( + chat_rec, input_tokens=10, output_tokens=5 + ) + + agent_span = agent_rec.span + agent_span.set_attribute.assert_any_call( + GenAI.GEN_AI_USAGE_INPUT_TOKENS, 10 + ) + agent_span.set_attribute.assert_any_call( + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, 5 + ) + + def test_noop_when_no_agent_parent(self, mgr): + """No error when the parent chain has no invoke_agent span.""" + chat_rec = mgr.start_span( + run_id="chat-orphan", + name="chat", + operation=OP_CHAT, + ) + # Should not raise. 
+ mgr.accumulate_usage_to_parent( + chat_rec, input_tokens=1, output_tokens=2 + ) + + def test_handles_none_token_values(self, mgr): + """Both tokens None → early return, no side-effects.""" + agent_rec = mgr.start_span( + run_id="agent", + name="agent", + operation=OP_INVOKE_AGENT, + ) + chat_rec = mgr.start_span( + run_id="chat", + name="chat", + operation=OP_CHAT, + parent_run_id="agent", + ) + + # Reset call tracking after start_span's own set_attribute calls. + agent_rec.span.set_attribute.reset_mock() + + mgr.accumulate_usage_to_parent( + chat_rec, input_tokens=None, output_tokens=None + ) + + # No token attributes should have been set on the agent span. + agent_rec.span.set_attribute.assert_not_called() + + +# ------------------------------------------------------------------ +# start_span / end_span +# ------------------------------------------------------------------ + + +class TestStartEndSpan: + def test_creates_span_with_correct_parent_context(self, mgr, tracer): + """start_span passes the parent span's context to the tracer.""" + mgr.start_span( + run_id="parent", + name="parent-agent", + operation=OP_INVOKE_AGENT, + ) + + mgr.start_span( + run_id="child", + name="child-chat", + operation=OP_CHAT, + parent_run_id="parent", + ) + + # The second start_span call should pass a non-None context. 
+ calls = tracer.start_span.call_args_list + assert len(calls) == 2 + child_call_kwargs = calls[1][1] + assert child_call_kwargs["context"] is not None + + def test_end_span_with_error_sets_attributes(self, mgr): + """end_span records error type and ERROR status on the span.""" + rec = mgr.start_span( + run_id="err-run", + name="failing", + operation=OP_CHAT, + ) + span = rec.span + + mgr.end_span("err-run", error=ValueError("boom")) + + span.set_attribute.assert_any_call("error.type", "ValueError") + span.set_status.assert_called_once() + status_arg = span.set_status.call_args[0][0] + assert status_arg.status_code == StatusCode.ERROR + + def test_end_span_removes_record(self, mgr): + """After end_span the run_id is no longer tracked.""" + mgr.start_span( + run_id="temp", + name="temp", + operation=OP_CHAT, + ) + assert mgr.get_record("temp") is not None + + mgr.end_span("temp") + assert mgr.get_record("temp") is None diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_usage_propagation.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_usage_propagation.py new file mode 100644 index 0000000000..99b7a5ade3 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_usage_propagation.py @@ -0,0 +1,237 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for token usage accumulation from LLM spans to parent agent spans.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + OP_CHAT, + OP_INVOKE_AGENT, +) +from opentelemetry.instrumentation.langchain.span_manager import ( + SpanRecord, + _SpanManager, +) +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) + +INPUT_TOKENS = GenAI.GEN_AI_USAGE_INPUT_TOKENS +OUTPUT_TOKENS = GenAI.GEN_AI_USAGE_OUTPUT_TOKENS + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_mock_span(): + span = MagicMock() + span.set_attribute = MagicMock() + return span + + +def _make_tracer(): + """Return a mock Tracer whose start_span returns fresh mock spans.""" + tracer = MagicMock() + tracer.start_span = MagicMock(side_effect=lambda **kw: _make_mock_span()) + return tracer + + +def _make_manager(): + tracer = _make_tracer() + return _SpanManager(tracer), tracer + + +def _register_record(mgr, run_id, operation, parent_run_id=None): + """Register a SpanRecord directly in the manager for test isolation.""" + rid = str(run_id) + prid = str(parent_run_id) if parent_run_id is not None else None + span = _make_mock_span() + record = SpanRecord( + run_id=rid, + span=span, + operation=operation, + parent_run_id=prid, + ) + mgr._spans[rid] = record + return record + + +# ------------------------------------------------------------------ +# accumulate_usage_to_parent +# ------------------------------------------------------------------ + + +class TestAccumulateUsageToParent: + """Tests for _SpanManager.accumulate_usage_to_parent.""" + + def test_accumulates_tokens_on_nearest_agent_parent(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm_rec = 
_register_record( + mgr, llm_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, input_tokens=10, output_tokens=20 + ) + + agent_rec.span.set_attribute.assert_any_call(INPUT_TOKENS, 10) + agent_rec.span.set_attribute.assert_any_call(OUTPUT_TOKENS, 20) + + def test_accumulates_across_multiple_llm_calls(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm1_id = str(uuid4()) + llm2_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm1_rec = _register_record( + mgr, llm1_id, OP_CHAT, parent_run_id=agent_id + ) + llm2_rec = _register_record( + mgr, llm2_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm1_rec, input_tokens=10, output_tokens=5 + ) + mgr.accumulate_usage_to_parent( + llm2_rec, input_tokens=20, output_tokens=15 + ) + + # After two calls the values should be additive. + agent_rec.span.set_attribute.assert_any_call(INPUT_TOKENS, 30) + agent_rec.span.set_attribute.assert_any_call(OUTPUT_TOKENS, 20) + + def test_skips_non_agent_parents(self): + """Walk up through a non-agent (chat) intermediate to the agent.""" + mgr, _ = _make_manager() + agent_id = str(uuid4()) + chain_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + _register_record(mgr, chain_id, OP_CHAT, parent_run_id=agent_id) + llm_rec = _register_record( + mgr, llm_id, OP_CHAT, parent_run_id=chain_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, input_tokens=7, output_tokens=3 + ) + + agent_rec.span.set_attribute.assert_any_call(INPUT_TOKENS, 7) + agent_rec.span.set_attribute.assert_any_call(OUTPUT_TOKENS, 3) + + def test_noop_when_both_tokens_none(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm_rec = _register_record( + mgr, llm_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, 
input_tokens=None, output_tokens=None + ) + + agent_rec.span.set_attribute.assert_not_called() + + def test_handles_only_input_tokens(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm_rec = _register_record( + mgr, llm_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, input_tokens=42, output_tokens=None + ) + + agent_rec.span.set_attribute.assert_called_once_with(INPUT_TOKENS, 42) + + def test_handles_only_output_tokens(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm_rec = _register_record( + mgr, llm_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, input_tokens=None, output_tokens=99 + ) + + agent_rec.span.set_attribute.assert_called_once_with(OUTPUT_TOKENS, 99) + + +# ------------------------------------------------------------------ +# accumulate_llm_usage_to_agent +# ------------------------------------------------------------------ + + +class TestAccumulateLlmUsageToAgent: + """Tests for _SpanManager.accumulate_llm_usage_to_agent.""" + + def test_resolves_through_ignored_runs(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + ignored_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + mgr.ignore_run(ignored_id, parent_run_id=agent_id) + + mgr.accumulate_llm_usage_to_agent( + parent_run_id=ignored_id, input_tokens=15, output_tokens=25 + ) + + agent_rec.span.set_attribute.assert_any_call(INPUT_TOKENS, 15) + agent_rec.span.set_attribute.assert_any_call(OUTPUT_TOKENS, 25) + + def test_noop_when_parent_run_id_is_none(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + + mgr.accumulate_llm_usage_to_agent( + parent_run_id=None, input_tokens=10, output_tokens=20 + ) + + 
agent_rec.span.set_attribute.assert_not_called() + + def test_noop_when_no_agent_in_chain(self): + mgr, _ = _make_manager() + chat_id = str(uuid4()) + + chat_rec = _register_record(mgr, chat_id, OP_CHAT) + + mgr.accumulate_llm_usage_to_agent( + parent_run_id=chat_id, input_tokens=10, output_tokens=20 + ) + + chat_rec.span.set_attribute.assert_not_called() diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_utils.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_utils.py new file mode 100644 index 0000000000..5a166306dc --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_utils.py @@ -0,0 +1,323 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from opentelemetry.instrumentation.langchain.utils import ( + infer_provider_name, + infer_server_address, + infer_server_port, +) + +# --------------------------------------------------------------------------- +# infer_provider_name +# --------------------------------------------------------------------------- + + +class TestInferProviderNameFromMetadata: + """Provider resolution via metadata ls_provider field.""" + + @pytest.mark.parametrize( + "ls_provider, expected", + [ + ("openai", "openai"), + ("anthropic", "anthropic"), + ("cohere", "cohere"), + ("ollama", "ollama"), + ], + ) + def test_direct_mapping(self, ls_provider, expected): + metadata = {"ls_provider": ls_provider} + assert infer_provider_name({}, metadata, {}) == expected + + @pytest.mark.parametrize( + "ls_provider", + ["azure", "azure_openai"], + ) + def test_azure_variants(self, ls_provider): + metadata = {"ls_provider": ls_provider} + assert infer_provider_name({}, metadata, {}) == "azure.ai.openai" + + def test_github_maps_to_azure(self): + metadata = {"ls_provider": "github"} + assert infer_provider_name({}, metadata, {}) == "azure.ai.openai" + + @pytest.mark.parametrize( + "ls_provider", + ["amazon_bedrock", "bedrock", "aws_bedrock"], + ) + def test_bedrock_variants(self, ls_provider): + metadata = {"ls_provider": ls_provider} + assert infer_provider_name({}, metadata, {}) == "aws.bedrock" + + def test_google(self): + metadata = {"ls_provider": "google"} + assert infer_provider_name({}, metadata, {}) == "gcp.gen_ai" + + +class TestInferProviderNameFromBaseUrl: + """Provider resolution via base_url in invocation_params.""" + + @pytest.mark.parametrize( + "url, expected", + [ + ("https://my-resource.openai.azure.com/v1", "azure.ai.openai"), + ("https://api.openai.com/v1", "openai"), + ( + "https://bedrock-runtime.us-east-1.amazonaws.com", + "aws.bedrock", + ), + ("https://api.anthropic.com/v1", "anthropic"), + ( + "https://us-central1-aiplatform.googleapis.com", + 
"gcp.gen_ai", + ), + ], + ) + def test_url_patterns(self, url, expected): + invocation_params = {"base_url": url} + assert infer_provider_name({}, {}, invocation_params) == expected + + def test_azure_keyword_in_url(self): + invocation_params = { + "base_url": "https://custom-azure-endpoint.example.com/v1" + } + assert ( + infer_provider_name({}, {}, invocation_params) == "azure.ai.openai" + ) + + def test_ollama_keyword_in_url(self): + invocation_params = { + "base_url": "https://my-ollama-server.local:11434/api" + } + assert infer_provider_name({}, {}, invocation_params) == "ollama" + + def test_amazonaws_in_url(self): + invocation_params = { + "base_url": "https://runtime.sagemaker.us-west-2.amazonaws.com" + } + assert infer_provider_name({}, {}, invocation_params) == "aws.bedrock" + + def test_openai_com_in_url(self): + invocation_params = {"base_url": "https://api.openai.com/v2/chat"} + assert infer_provider_name({}, {}, invocation_params) == "openai" + + +class TestInferProviderNameFromSerializedClassName: + """Provider resolution via serialized name/id fields.""" + + @pytest.mark.parametrize( + "class_name, expected", + [ + ("ChatOpenAI", "openai"), + ("ChatBedrock", "aws.bedrock"), + ("ChatAnthropic", "anthropic"), + ("ChatGoogleGenerativeAI", "gcp.gen_ai"), + ], + ) + def test_class_name(self, class_name, expected): + serialized = {"name": class_name} + assert infer_provider_name(serialized, {}, {}) == expected + + @pytest.mark.parametrize( + "class_name, expected", + [ + ("ChatOpenAI", "openai"), + ("ChatBedrock", "aws.bedrock"), + ("ChatAnthropic", "anthropic"), + ("ChatGoogleGenerativeAI", "gcp.gen_ai"), + ], + ) + def test_class_name_via_id(self, class_name, expected): + serialized = {"id": ["langchain_openai", "chat_models", class_name]} + assert infer_provider_name(serialized, {}, {}) == expected + + +class TestInferProviderNameFromSerializedKwargs: + """Provider resolution via kwargs in serialized dict.""" + + def test_azure_endpoint_kwarg(self): + 
serialized = { + "kwargs": {"azure_endpoint": "https://my-model.openai.azure.com/"} + } + assert infer_provider_name(serialized, {}, {}) == "azure.ai.openai" + + +class TestInferProviderNameReturnsNone: + """Returns None when no provider signals are available.""" + + def test_empty_inputs(self): + assert infer_provider_name({}, {}, {}) is None + + def test_none_inputs(self): + assert infer_provider_name({}, None, None) is None + + def test_unrecognized_metadata(self): + metadata = {"ls_provider": "some_unknown_provider"} + assert infer_provider_name({}, metadata, {}) is None + + def test_unrecognized_url(self): + invocation_params = {"base_url": "https://custom-llm.example.com/v1"} + assert infer_provider_name({}, {}, invocation_params) is None + + def test_unrecognized_class_name(self): + serialized = {"name": "ChatCustomLLM"} + assert infer_provider_name(serialized, {}, {}) is None + + +class TestInferProviderNamePriority: + """Metadata takes priority over invocation_params over serialized.""" + + def test_metadata_over_invocation_params(self): + metadata = {"ls_provider": "anthropic"} + invocation_params = {"base_url": "https://api.openai.com/v1"} + assert ( + infer_provider_name({}, metadata, invocation_params) == "anthropic" + ) + + def test_metadata_over_serialized(self): + metadata = {"ls_provider": "anthropic"} + serialized = {"name": "ChatOpenAI"} + assert infer_provider_name(serialized, metadata, {}) == "anthropic" + + def test_invocation_params_over_serialized(self): + invocation_params = {"base_url": "https://api.anthropic.com/v1"} + serialized = {"name": "ChatOpenAI"} + assert ( + infer_provider_name(serialized, {}, invocation_params) + == "anthropic" + ) + + +# --------------------------------------------------------------------------- +# infer_server_address +# --------------------------------------------------------------------------- + + +class TestInferServerAddress: + """Extract hostname from various URL sources.""" + + def 
test_from_invocation_params_base_url(self): + invocation_params = {"base_url": "https://api.openai.com/v1"} + assert infer_server_address({}, invocation_params) == "api.openai.com" + + def test_from_serialized_openai_api_base(self): + serialized = { + "kwargs": {"openai_api_base": "https://my-model.openai.azure.com/"} + } + assert ( + infer_server_address(serialized, {}) == "my-model.openai.azure.com" + ) + + def test_from_serialized_azure_endpoint(self): + serialized = { + "kwargs": { + "azure_endpoint": "https://my-resource.openai.azure.com/" + } + } + assert ( + infer_server_address(serialized, {}) + == "my-resource.openai.azure.com" + ) + + def test_returns_none_when_no_url(self): + assert infer_server_address({}, {}) is None + + def test_returns_none_for_empty_inputs(self): + assert infer_server_address({}, None) is None + + def test_returns_none_for_none_serialized_kwargs(self): + serialized = {"kwargs": {}} + assert infer_server_address(serialized, {}) is None + + def test_strips_port_from_hostname(self): + invocation_params = {"base_url": "http://localhost:11434/v1"} + assert infer_server_address({}, invocation_params) == "localhost" + + def test_handles_url_with_path(self): + invocation_params = { + "base_url": "https://api.openai.com/v1/chat/completions" + } + assert infer_server_address({}, invocation_params) == "api.openai.com" + + def test_handles_malformed_url(self): + invocation_params = {"base_url": "not-a-valid-url"} + result = infer_server_address({}, invocation_params) + # Should not raise; either returns None or a best-effort parse + assert result is None or isinstance(result, str) + + def test_handles_empty_string_url(self): + invocation_params = {"base_url": ""} + result = infer_server_address({}, invocation_params) + assert result is None or isinstance(result, str) + + def test_invocation_params_base_url_takes_priority(self): + serialized = { + "kwargs": {"openai_api_base": "https://fallback.example.com/v1"} + } + invocation_params = 
{"base_url": "https://primary.example.com/v1"} + assert ( + infer_server_address(serialized, invocation_params) + == "primary.example.com" + ) + + +# --------------------------------------------------------------------------- +# infer_server_port +# --------------------------------------------------------------------------- + + +class TestInferServerPort: + """Extract port from URL sources.""" + + def test_explicit_port(self): + invocation_params = {"base_url": "http://localhost:11434/v1"} + assert infer_server_port({}, invocation_params) == 11434 + + def test_no_explicit_port_returns_none(self): + invocation_params = {"base_url": "https://api.openai.com/v1"} + assert infer_server_port({}, invocation_params) is None + + def test_standard_http_port_returned_when_explicit(self): + # urlparse returns port when explicitly specified, even if standard + invocation_params = {"base_url": "http://api.example.com:80/v1"} + assert infer_server_port({}, invocation_params) == 80 + + def test_standard_https_port_returned_when_explicit(self): + invocation_params = {"base_url": "https://api.example.com:443/v1"} + assert infer_server_port({}, invocation_params) == 443 + + def test_custom_port(self): + invocation_params = {"base_url": "https://api.example.com:8443/v1"} + assert infer_server_port({}, invocation_params) == 8443 + + def test_returns_none_when_no_url(self): + assert infer_server_port({}, {}) is None + + def test_returns_none_for_none_inputs(self): + assert infer_server_port({}, None) is None + + def test_port_from_serialized_openai_api_base(self): + serialized = { + "kwargs": {"openai_api_base": "http://localhost:8080/v1"} + } + assert infer_server_port(serialized, {}) == 8080 + + def test_port_from_serialized_azure_endpoint(self): + serialized = { + "kwargs": { + "azure_endpoint": "https://my-resource.openai.azure.com:9090/" + } + } + assert infer_server_port(serialized, {}) == 9090 diff --git 
a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_w3c_propagation.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_w3c_propagation.py new file mode 100644 index 0000000000..d303642f33 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_w3c_propagation.py @@ -0,0 +1,200 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from opentelemetry.instrumentation.langchain.utils import ( + extract_propagation_context, + extract_trace_headers, + propagated_context, +) +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.trace import get_current_span + +# Valid W3C traceparent components for test fixtures. 
+_TRACE_ID = "0af7651916cd43dd8448eb211c80319c" +_SPAN_ID = "b7ad6b7169203331" +_TRACEPARENT = f"00-{_TRACE_ID}-{_SPAN_ID}-01" +_TRACESTATE = "congo=t61rcWkgMzE" + + +# --------------------------------------------------------------------------- +# extract_trace_headers +# --------------------------------------------------------------------------- + + +class TestExtractTraceHeaders: + """Tests for extract_trace_headers().""" + + def test_top_level_traceparent(self): + container = {"traceparent": _TRACEPARENT} + result = extract_trace_headers(container) + assert result == {"traceparent": _TRACEPARENT} + + def test_top_level_traceparent_and_tracestate(self): + container = {"traceparent": _TRACEPARENT, "tracestate": _TRACESTATE} + result = extract_trace_headers(container) + assert result == { + "traceparent": _TRACEPARENT, + "tracestate": _TRACESTATE, + } + + def test_nested_headers_key(self): + container = {"headers": {"traceparent": _TRACEPARENT}} + result = extract_trace_headers(container) + assert result == {"traceparent": _TRACEPARENT} + + def test_nested_metadata_key(self): + container = { + "metadata": { + "traceparent": _TRACEPARENT, + "tracestate": _TRACESTATE, + } + } + result = extract_trace_headers(container) + assert result == { + "traceparent": _TRACEPARENT, + "tracestate": _TRACESTATE, + } + + def test_nested_request_headers_key(self): + container = {"request_headers": {"traceparent": _TRACEPARENT}} + result = extract_trace_headers(container) + assert result == {"traceparent": _TRACEPARENT} + + def test_empty_container_returns_none(self): + assert extract_trace_headers({}) is None + + def test_no_trace_headers_returns_none(self): + container = {"foo": "bar", "headers": {"content-type": "text/plain"}} + assert extract_trace_headers(container) is None + + def test_non_dict_container_returns_none(self): + assert extract_trace_headers(None) is None + assert extract_trace_headers("string") is None + assert extract_trace_headers(42) is None + assert 
extract_trace_headers(["traceparent", _TRACEPARENT]) is None + + def test_empty_string_traceparent_ignored(self): + container = {"traceparent": ""} + assert extract_trace_headers(container) is None + + def test_top_level_takes_precedence_over_nested(self): + other_traceparent = ( + "00-11111111111111111111111111111111-2222222222222222-01" + ) + container = { + "traceparent": _TRACEPARENT, + "headers": {"traceparent": other_traceparent}, + } + result = extract_trace_headers(container) + assert result["traceparent"] == _TRACEPARENT + + +# --------------------------------------------------------------------------- +# propagated_context +# --------------------------------------------------------------------------- + + +class TestPropagatedContext: + """Tests for the propagated_context() context manager.""" + + def test_noop_when_headers_is_none(self): + with propagated_context(None): + span = get_current_span() + assert not span.get_span_context().is_valid + + def test_noop_when_headers_is_empty(self): + with propagated_context({}): + span = get_current_span() + assert not span.get_span_context().is_valid + + def test_attaches_and_detaches_valid_traceparent(self): + provider = TracerProvider() + tracer = provider.get_tracer("test") + + with tracer.start_as_current_span("outer"): + outer_ctx = get_current_span().get_span_context() + + headers = {"traceparent": _TRACEPARENT} + with propagated_context(headers): + inner_ctx = get_current_span().get_span_context() + # The propagated context should carry the injected trace id. + assert format(inner_ctx.trace_id, "032x") == _TRACE_ID + + # After exiting, we should be back to the outer span. 
+ restored_ctx = get_current_span().get_span_context() + assert restored_ctx.trace_id == outer_ctx.trace_id + + provider.shutdown() + + def test_invalid_traceparent_does_not_crash(self): + headers = {"traceparent": "not-a-valid-traceparent"} + with propagated_context(headers): + # Should execute without raising; span context may be invalid. + span = get_current_span() + assert span is not None + + def test_malformed_traceparent_does_not_crash(self): + headers = {"traceparent": "00-short-bad-01"} + with propagated_context(headers): + span = get_current_span() + assert span is not None + + +# --------------------------------------------------------------------------- +# extract_propagation_context +# --------------------------------------------------------------------------- + + +class TestExtractPropagationContext: + """Tests for extract_propagation_context().""" + + def test_finds_headers_in_metadata(self): + metadata = {"traceparent": _TRACEPARENT} + result = extract_propagation_context(metadata, {}, {}) + assert result == {"traceparent": _TRACEPARENT} + + def test_falls_back_to_inputs(self): + inputs = {"traceparent": _TRACEPARENT} + result = extract_propagation_context({}, inputs, {}) + assert result == {"traceparent": _TRACEPARENT} + + def test_falls_back_to_kwargs(self): + kwargs = {"traceparent": _TRACEPARENT} + result = extract_propagation_context({}, {}, kwargs) + assert result == {"traceparent": _TRACEPARENT} + + def test_returns_none_when_no_source_has_headers(self): + result = extract_propagation_context({}, {}, {}) + assert result is None + + def test_metadata_takes_precedence_over_inputs(self): + other_traceparent = ( + "00-11111111111111111111111111111111-2222222222222222-01" + ) + metadata = {"traceparent": _TRACEPARENT} + inputs = {"traceparent": other_traceparent} + result = extract_propagation_context(metadata, inputs, {}) + assert result["traceparent"] == _TRACEPARENT + + def test_none_sources_are_skipped(self): + kwargs = {"traceparent": 
_TRACEPARENT} + result = extract_propagation_context(None, None, kwargs) + assert result == {"traceparent": _TRACEPARENT} + + def test_non_dict_inputs_skipped(self): + result = extract_propagation_context( + None, "not a dict", {"traceparent": _TRACEPARENT} + ) + assert result == {"traceparent": _TRACEPARENT}