|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +from __future__ import annotations |
| 16 | + |
15 | 17 | import threading |
16 | 18 | from concurrent.futures import ( # pylint: disable=no-name-in-module; TODO #4199 |
| 19 | + Future, |
17 | 20 | ThreadPoolExecutor, |
18 | 21 | ) |
19 | 22 | from typing import List |
@@ -66,7 +69,7 @@ def test_trace_context_propagation_in_thread_pool_with_multiple_workers( |
66 | 69 | executor = ThreadPoolExecutor(max_workers=max_workers) |
67 | 70 |
|
68 | 71 | expected_span_contexts: List[trace.SpanContext] = [] |
69 | | - futures_list = [] |
| 72 | + futures_list: List[Future[trace.SpanContext]] = [] |
70 | 73 | for num in range(max_workers): |
71 | 74 | with self._tracer.start_as_current_span(f"trace_{num}") as span: |
72 | 75 | expected_span_context = span.get_span_context() |
@@ -125,23 +128,23 @@ def fake_func(self): |
125 | 128 | def get_current_span_context_for_test() -> trace.SpanContext: |
126 | 129 | return trace.get_current_span().get_span_context() |
127 | 130 |
|
128 | | - def print_square(self, num): |
| 131 | + def print_square(self, num: int | float) -> int | float: |
129 | 132 | with self._tracer.start_as_current_span("square"): |
130 | 133 | return num * num |
131 | 134 |
|
132 | | - def print_cube(self, num): |
| 135 | + def print_cube(self, num: int | float) -> int | float: |
133 | 136 | with self._tracer.start_as_current_span("cube"): |
134 | 137 | return num * num * num |
135 | 138 |
|
136 | | - def print_square_with_thread(self, num): |
| 139 | + def print_square_with_thread(self, num: int | float) -> int | float: |
137 | 140 | with self._tracer.start_as_current_span("square"): |
138 | 141 | cube_thread = threading.Thread(target=self.print_cube, args=(10,)) |
139 | 142 |
|
140 | 143 | cube_thread.start() |
141 | 144 | cube_thread.join() |
142 | 145 | return num * num |
143 | 146 |
|
144 | | - def calculate(self, num): |
| 147 | + def calculate(self, num: int | float) -> None: |
145 | 148 | with self._tracer.start_as_current_span("calculate"): |
146 | 149 | square_thread = threading.Thread( |
147 | 150 | target=self.print_square, args=(num,) |
@@ -294,3 +297,48 @@ def test_threadpool_with_valid_context_token(self, mock_detach: MagicMock): |
294 | 297 | future = executor.submit(self.get_current_span_context_for_test) |
295 | 298 | future.result() |
296 | 299 | mock_detach.assert_called_once() |
| 300 | + |
| 301 | + def test_threading_run_without_start(self): |
| 302 | + square_thread = threading.Thread(target=self.print_square, args=(10,)) |
| 303 | + with self._tracer.start_as_current_span("root"): |
| 304 | + square_thread.run() |
| 305 | + |
| 306 | + spans = self.memory_exporter.get_finished_spans() |
| 307 | + self.assertEqual(len(spans), 2) |
| 308 | + root_span = next(span for span in spans if span.name == "root") |
| 309 | + self.assertIsNotNone(root_span) |
| 310 | + self.assertIsNone(root_span.parent) |
| 311 | + square_span = next(span for span in spans if span.name == "square") |
| 312 | + self.assertIsNotNone(square_span) |
| 313 | + self.assertIs(square_span.parent, root_span.get_span_context()) |
| 314 | + |
| 315 | + def test_threading_run_with_custom_run(self): |
| 316 | + _tracer = self._tracer |
| 317 | + |
| 318 | + class ThreadWithCustomRun(threading.Thread): |
| 319 | + def run(self): |
| 320 | + # don't call super().run() on purpose |
| 321 | + # Thread.run() cannot be called twice |
| 322 | + with _tracer.start_as_current_span("square"): |
| 323 | + pass |
| 324 | + |
| 325 | + square_thread = ThreadWithCustomRun( |
| 326 | + target=self.print_square, args=(10,) |
| 327 | + ) |
| 328 | + with self._tracer.start_as_current_span("run_1"): |
| 329 | + square_thread.run() |
| 330 | + with self._tracer.start_as_current_span("run_2"): |
| 331 | + square_thread.run() |
| 332 | + |
| 333 | + spans = self.memory_exporter.get_finished_spans() |
| 334 | + self.assertEqual(len(spans), 4) |
| 335 | + run_1_span = next(span for span in spans if span.name == "run_1") |
| 336 | + run_2_span = next(span for span in spans if span.name == "run_2") |
| 337 | + square_spans = [span for span in spans if span.name == "square"] |
| 338 | + square_spans.sort(key=lambda x: x.start_time or 0) |
| 339 | + run_1_child_span = square_spans[0] |
| 340 | + run_2_child_span = square_spans[1] |
| 341 | + self.assertIs(run_1_child_span.parent, run_1_span.get_span_context()) |
| 342 | + self.assertIs(run_2_child_span.parent, run_2_span.get_span_context()) |
| 343 | + self.assertIsNone(run_1_span.parent) |
| 344 | + self.assertIsNone(run_2_span.parent) |
0 commit comments