4343from unittest import TestCase
4444from unittest .mock import patch
4545
46- from langchain_core .outputs import ChatGeneration , LLMResult
4746from langchain_core .messages import AIMessage
47+ from langchain_core .outputs import ChatGeneration , LLMResult
4848
4949from opentelemetry import baggage
50- from opentelemetry import context as otel_context
5150from opentelemetry .instrumentation .langchain .callback_handler import (
5251 OpenTelemetryLangChainCallbackHandler ,
5352)
6362)
6463from opentelemetry .util .genai .handler import TelemetryHandler
6564
66-
6765# ---------------------------------------------------------------------------
6866# Shared helpers
6967# ---------------------------------------------------------------------------
7068
69+
7170def _make_serialized (name : str ) -> dict :
7271 """Minimal serialized dict that on_chain_start / on_chat_model_start expect."""
7372 return {"name" : name }
@@ -78,18 +77,26 @@ def _make_llm_result(content: str = "hello") -> LLMResult:
7877 msg = AIMessage (content = content )
7978 gen = ChatGeneration (message = msg , text = content )
8079 gen .generation_info = {"finish_reason" : "stop" }
81- return LLMResult (generations = [[gen ]], llm_output = {"model_name" : "gpt-3.5-turbo" })
80+ return LLMResult (
81+ generations = [[gen ]], llm_output = {"model_name" : "gpt-3.5-turbo" }
82+ )
8283
8384
8485def _make_chat_invocation_params (model_name : str = "gpt-3.5-turbo" ) -> dict :
8586 """kwargs dict that on_chat_model_start receives for a ChatOpenAI call."""
86- return {"invocation_params" : {"model_name" : model_name , "params" : {"model_name" : model_name }}}
87+ return {
88+ "invocation_params" : {
89+ "model_name" : model_name ,
90+ "params" : {"model_name" : model_name },
91+ }
92+ }
8793
8894
8995# ---------------------------------------------------------------------------
9096# Base test class
9197# ---------------------------------------------------------------------------
9298
99+
93100class _CallbackHandlerTestBase (TestCase ):
94101 def setUp (self ) -> None :
95102 self .span_exporter = InMemorySpanExporter ()
@@ -113,6 +120,7 @@ def _spans_by_kind(self, kind: SpanKind):
113120# Tests
114121# ---------------------------------------------------------------------------
115122
123+
116124class TestWorkflowSpanCreation (_CallbackHandlerTestBase ):
117125 """Verify that a workflow span is created for top-level chains."""
118126
@@ -243,7 +251,9 @@ def test_workflow_name_from_metadata_override(self) -> None:
243251
244252 internal_spans = self ._spans_by_kind (SpanKind .INTERNAL )
245253 self .assertEqual (len (internal_spans ), 1 )
246- self .assertEqual (internal_spans [0 ].name , "invoke_workflow my_custom_wf" )
254+ self .assertEqual (
255+ internal_spans [0 ].name , "invoke_workflow my_custom_wf"
256+ )
247257
248258 client_spans = self ._spans_by_kind (SpanKind .CLIENT )
249259 self .assertEqual (len (client_spans ), 1 )
@@ -257,7 +267,11 @@ class TestCSANotLeakedToBaggage(_CallbackHandlerTestBase):
257267 """Verify that gen_ai.workflow.name is NOT written to W3C Baggage by default."""
258268
259269 def test_csa_not_leaked_to_baggage (self ) -> None :
260- env = {k : v for k , v in os .environ .items () if k != "OTEL_PYTHON_GENAI_CAPTURE_BAGGAGE" }
270+ env = {
271+ k : v
272+ for k , v in os .environ .items ()
273+ if k != "OTEL_PYTHON_GENAI_CAPTURE_BAGGAGE"
274+ }
261275 with patch .dict (os .environ , env , clear = True ):
262276 chain_run_id = uuid .uuid4 ()
263277
@@ -458,9 +472,13 @@ def test_llm_error_inside_workflow_records_error_on_llm_span(self) -> None:
458472 # Workflow span still finishes (not in error state)
459473 internal_spans = self ._spans_by_kind (SpanKind .INTERNAL )
460474 self .assertEqual (len (internal_spans ), 1 )
461- self .assertNotEqual (internal_spans [0 ].status .status_code , StatusCode .ERROR )
475+ self .assertNotEqual (
476+ internal_spans [0 ].status .status_code , StatusCode .ERROR
477+ )
462478
463- def test_llm_error_inside_workflow_llm_span_is_child_of_workflow (self ) -> None :
479+ def test_llm_error_inside_workflow_llm_span_is_child_of_workflow (
480+ self ,
481+ ) -> None :
464482 chain_run_id = uuid .uuid4 ()
465483 llm_run_id = uuid .uuid4 ()
466484
@@ -520,7 +538,9 @@ def test_name_falls_back_to_id_list(self) -> None:
520538 run_id = chain_run_id ,
521539 parent_run_id = None ,
522540 )
523- self .handler .on_chain_end (outputs = {}, run_id = chain_run_id , parent_run_id = None )
541+ self .handler .on_chain_end (
542+ outputs = {}, run_id = chain_run_id , parent_run_id = None
543+ )
524544
525545 internal_spans = self ._spans_by_kind (SpanKind .INTERNAL )
526546 self .assertEqual (len (internal_spans ), 1 )
@@ -537,7 +557,9 @@ def test_name_falls_back_to_langgraph_node(self) -> None:
537557 parent_run_id = None ,
538558 metadata = {"langgraph_node" : "my_node" },
539559 )
540- self .handler .on_chain_end (outputs = {}, run_id = chain_run_id , parent_run_id = None )
560+ self .handler .on_chain_end (
561+ outputs = {}, run_id = chain_run_id , parent_run_id = None
562+ )
541563
542564 internal_spans = self ._spans_by_kind (SpanKind .INTERNAL )
543565 self .assertEqual (len (internal_spans ), 1 )
@@ -552,7 +574,9 @@ def test_name_defaults_to_chain_when_nothing_provided(self) -> None:
552574 run_id = chain_run_id ,
553575 parent_run_id = None ,
554576 )
555- self .handler .on_chain_end (outputs = {}, run_id = chain_run_id , parent_run_id = None )
577+ self .handler .on_chain_end (
578+ outputs = {}, run_id = chain_run_id , parent_run_id = None
579+ )
556580
557581 internal_spans = self ._spans_by_kind (SpanKind .INTERNAL )
558582 self .assertEqual (len (internal_spans ), 1 )
0 commit comments