@@ -335,3 +335,48 @@ def test_fields(self, mock_trace):
335335 self .assertEqual (
336336 AwsXRayPropagatorTest .XRAY_PROPAGATOR .fields , inject_fields
337337 )
338+
339+ def test_extract_trace_state_from_context (self ):
340+ """Test that extract properly propagates the trace state extracted by other propagators."""
341+ context_with_extracted = AwsXRayPropagatorTest .XRAY_PROPAGATOR .extract (
342+ CaseInsensitiveDict (
343+ {
344+ TRACE_HEADER_KEY : "Root=1-8a3c60f7-d188f8fa79d48a391a778fa6;Parent=53995c3f42cd8ad8;Sampled=0" ,
345+ }
346+ ),
347+ context = set_span_in_context (
348+ trace_api .NonRecordingSpan (
349+ SpanContext (
350+ int (TRACE_ID_BASE16 , 16 ),
351+ int (SPAN_ID_BASE16 , 16 ),
352+ True ,
353+ DEFAULT_TRACE_OPTIONS ,
354+ TraceState ([("foo" , "bar" ), ("baz" , "qux" )]),
355+ )
356+ )
357+ ),
358+ )
359+
360+ extracted_span_context = get_nested_span_context (
361+ context_with_extracted
362+ )
363+ expected_trace_state = TraceState ([("foo" , "bar" ), ("baz" , "qux" )])
364+
365+ self .assertEqual (
366+ extracted_span_context .trace_state , expected_trace_state
367+ )
368+
369+ def test_extract_no_trace_state_from_context (self ):
370+ """Test that extract defaults to an empty trace state correctly."""
371+ context_with_extracted = AwsXRayPropagatorTest .XRAY_PROPAGATOR .extract (
372+ CaseInsensitiveDict (
373+ {
374+ TRACE_HEADER_KEY : "Root=1-8a3c60f7-d188f8fa79d48a391a778fa6;Parent=53995c3f42cd8ad8;Sampled=0" ,
375+ }
376+ )
377+ )
378+
379+ extracted_span_context = get_nested_span_context (
380+ context_with_extracted
381+ )
382+ self .assertEqual (extracted_span_context .trace_state , TraceState ([]))
0 commit comments