From e7e785f9411dac838062ebe282e5c4d77902b0b9 Mon Sep 17 00:00:00 2001 From: Evelyn Duesterwald Date: Wed, 17 Jun 2026 14:20:49 -0400 Subject: [PATCH] fix(sync): correct Phoenix span extraction for multi-span traces and tool-calling agents MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two independent fixes to `PhoenixSync` that affect guideline generation accuracy. Fix 1 — trace-level dedup and representative span selection: Each agent run emits one Phoenix span per LLM call, with every subsequent span re-including all prior messages (cumulative context window). Processing all spans caused the same conversation steps to be analysed multiple times with growing context, inflating guideline counts and skewing results. The fix deduplicates at trace level (`_get_processed_trace_ids`) and picks a single representative span per trace — the last by `start_time`, which holds the most complete message history (`_select_representative_spans`). Fix 2 — complete message and tool extraction from OpenInference spans: The Phoenix REST API returns OpenInference attributes as flat indexed keys (`llm.input_messages.{i}.message.*`, `llm.tools.{i}.tool.json_schema`) rather than nested lists. The previous extraction code missed `tool_call_id` on tool messages and `tool_calls` on assistant messages, producing incomplete trajectories where tool-calling steps appeared as `content: "None"` with no linkage between assistant calls and tool results. Guidelines generated from such trajectories lacked tool-use context. The fix adds: - `tool_call_id` extraction in both message loops - an indexed-attribute reader for `llm.input_messages.{i}.*` and `llm.output_messages.{i}.*` (including nested `tool_calls.{j}.*`) - `_extract_tools_from_span`: parses `llm.tools.{i}.tool.json_schema` and the list-mode `{"tool.json_schema": "..."}` format into OpenAI tool dicts - `_convert_openinference_tool_calls`: converts OpenInference tool_call dicts to OpenAI format (`tool_call.function.name` → `function.name`, etc.) - `_extract_trajectory`: propagates `tool_calls` and `tool_call_id` from the extracted message dict, choosing the right assembly path per message format Unit tests updated: entity metadata mocks now include `trace_id` to match the new trace-level dedup logic. Co-Authored-By: Claude Sonnet 4.6 --- altk_evolve/sync/phoenix_sync.py | 288 +++++++++++++++++++++++++++++-- tests/unit/test_phoenix_sync.py | 8 +- 2 files changed, 273 insertions(+), 23 deletions(-) diff --git a/altk_evolve/sync/phoenix_sync.py b/altk_evolve/sync/phoenix_sync.py index 65aaeda7..cc4d7d9b 100644 --- a/altk_evolve/sync/phoenix_sync.py +++ b/altk_evolve/sync/phoenix_sync.py @@ -105,6 +105,36 @@ def _get_processed_span_ids(self) -> set[str]: except NamespaceNotFoundException: return set() + def _get_processed_trace_ids(self) -> set[str]: + """Get trace_ids that have already been processed.""" + try: + entities = self.client.search_entities( + namespace_id=self.namespace_id, + filters={"type": "trajectory"}, + limit=10000, + ) + return {str(e.metadata.get("trace_id")) for e in entities if e.metadata and e.metadata.get("trace_id")} + except NamespaceNotFoundException: + return set() + + def _select_representative_spans(self, spans: list[dict]) -> list[dict]: + """For each trace, return only the last span by timestamp. + + Each LLM call in an agent run receives the full accumulated message history + as input, so the last span chronologically contains the most complete + conversation and is the right unit for guideline generation. Processing + all spans would analyse the same steps multiple times with growing context. + """ + by_trace: dict[str, list[dict]] = {} + for span in spans: + trace_id = span.get("context", {}).get("trace_id") + if trace_id: + by_trace.setdefault(trace_id, []).append(span) + return [ + max(trace_spans, key=lambda s: s.get("start_time") or "") + for trace_spans in by_trace.values() + ] + def _format_payload_summary(self, payload: Any) -> str: """Format a payload summary for secure logging (avoid PII).""" type_name = type(payload).__name__ @@ -180,6 +210,7 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]: role = msg.get("message.role") or msg.get("role") content = msg.get("message.content") or msg.get("content") tool_calls = msg.get("message.tool_calls") or msg.get("tool_calls") + tool_call_id = msg.get("message.tool_call_id") or msg.get("tool_call_id") if role: mapped_msg = { @@ -190,6 +221,8 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]: } if tool_calls: mapped_msg["tool_calls"] = tool_calls + if tool_call_id: + mapped_msg["tool_call_id"] = tool_call_id messages.append(mapped_msg) # Handle Output/Completion from OpenInference @@ -227,6 +260,7 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]: role = msg.get("message.role") or msg.get("role") content = msg.get("message.content") or msg.get("content") tool_calls = msg.get("message.tool_calls") or msg.get("tool_calls") + tool_call_id = msg.get("message.tool_call_id") or msg.get("tool_call_id") if role: mapped_msg = { @@ -237,11 +271,85 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]: } if tool_calls: mapped_msg["tool_calls"] = tool_calls + if tool_call_id: + mapped_msg["tool_call_id"] = tool_call_id messages.append(mapped_msg) if messages: return messages + # Indexed OpenInference format from Phoenix REST API: + # llm.input_messages.{i}.message.role / .content / .tool_calls.{j}.* / .tool_call_id + input_indices: set[int] = set() + output_indices: set[int] = set() + for key in attrs: + if key.startswith("llm.input_messages."): + parts = key.split(".") + if len(parts) >= 3 and parts[2].isdigit(): + input_indices.add(int(parts[2])) + elif key.startswith("llm.output_messages."): + parts = key.split(".") + if len(parts) >= 3 and parts[2].isdigit(): + output_indices.add(int(parts[2])) + + for i in sorted(input_indices): + role = attrs.get(f"llm.input_messages.{i}.message.role") + content = attrs.get(f"llm.input_messages.{i}.message.content") + tool_call_id = attrs.get(f"llm.input_messages.{i}.message.tool_call_id") + + tc_indices: set[int] = set() + prefix = f"llm.input_messages.{i}.message.tool_calls." + for key in attrs: + if key.startswith(prefix): + parts = key[len(prefix):].split(".") + if parts and parts[0].isdigit(): + tc_indices.add(int(parts[0])) + tool_calls = [] + for j in sorted(tc_indices): + tc_prefix = f"llm.input_messages.{i}.message.tool_calls.{j}.tool_call." + tool_calls.append({ + "tool_call.id": attrs.get(f"{tc_prefix}id", ""), + "tool_call.function.name": attrs.get(f"{tc_prefix}function.name", ""), + "tool_call.function.arguments": attrs.get(f"{tc_prefix}function.arguments", "{}"), + }) + + if role: + mapped_msg_in: dict = {"index": i, "type": "prompt", "role": role, "content": content} + if tool_calls: + mapped_msg_in["tool_calls"] = tool_calls + if tool_call_id: + mapped_msg_in["tool_call_id"] = tool_call_id + messages.append(mapped_msg_in) + + for i in sorted(output_indices): + role = attrs.get(f"llm.output_messages.{i}.message.role") + content = attrs.get(f"llm.output_messages.{i}.message.content") + + tc_indices_out: set[int] = set() + prefix_out = f"llm.output_messages.{i}.message.tool_calls." + for key in attrs: + if key.startswith(prefix_out): + parts = key[len(prefix_out):].split(".") + if parts and parts[0].isdigit(): + tc_indices_out.add(int(parts[0])) + tool_calls_out = [] + for j in sorted(tc_indices_out): + tc_prefix_out = f"llm.output_messages.{i}.message.tool_calls.{j}.tool_call." + tool_calls_out.append({ + "tool_call.id": attrs.get(f"{tc_prefix_out}id", ""), + "tool_call.function.name": attrs.get(f"{tc_prefix_out}function.name", ""), + "tool_call.function.arguments": attrs.get(f"{tc_prefix_out}function.arguments", "{}"), + }) + + if role: + mapped_msg_out: dict = {"index": i, "type": "completion", "role": role, "content": content} + if tool_calls_out: + mapped_msg_out["tool_calls"] = tool_calls_out + messages.append(mapped_msg_out) + + if messages: + return messages + # Fallback to GenAI semantic conventions (original code) # Extract prompt messages prompt_indices = set() @@ -355,6 +463,115 @@ def _convert_to_openai_format(self, content: Any, role: str) -> dict: content_text = "\n\n".join(text_parts) if text_parts else "" return {"role": role, "content": content_text} + def _extract_tools_from_span(self, span: dict) -> list[dict] | None: + """Extract tool definitions from a span's attributes in OpenAI tools format. + + Tries three attribute conventions in order: + 1. llm.invocation_parameters (LiteLLM GenAI convention) — JSON dict with a "tools" key + 2. llm.tools as a JSON array where each item is {"tool.json_schema": ""} + 3. Indexed llm.tools.{i}.tool.json_schema keys (Phoenix REST API flat format) + """ + attrs = span.get("attributes") or {} + + invocation_params = attrs.get("llm.invocation_parameters") + if invocation_params: + try: + params = json.loads(invocation_params) if isinstance(invocation_params, str) else invocation_params + if isinstance(params, dict): + tools = params.get("tools") + if isinstance(tools, list) and tools: + return tools + except (json.JSONDecodeError, Exception): + pass + + tools_attr = attrs.get("llm.tools") + if tools_attr is not None: + try: + tools = json.loads(tools_attr) if isinstance(tools_attr, str) else tools_attr + if isinstance(tools, list) and tools: + # OpenInference list format: each item is {"tool.json_schema": ""} + openai_tools = [] + for item in tools: + if isinstance(item, dict): + schema_str = item.get("tool.json_schema") + if schema_str is not None: + try: + schema = json.loads(schema_str) if isinstance(schema_str, str) else schema_str + openai_tools.append(schema) + continue + except (json.JSONDecodeError, Exception): + pass + openai_tools.append(item) + if openai_tools: + return openai_tools + except (json.JSONDecodeError, Exception): + pass + + # Indexed flat format from Phoenix REST API: llm.tools.{i}.tool.json_schema + tool_indices: set[int] = set() + for key in attrs: + if key.startswith("llm.tools."): + parts = key.split(".") + if len(parts) >= 3 and parts[2].isdigit(): + tool_indices.add(int(parts[2])) + + if tool_indices: + tools = [] + for i in sorted(tool_indices): + json_schema_str = attrs.get(f"llm.tools.{i}.tool.json_schema") + if json_schema_str: + try: + schema = json.loads(json_schema_str) if isinstance(json_schema_str, str) else json_schema_str + tools.append(schema) + continue + except (json.JSONDecodeError, Exception): + pass + # Fall back to building from name/description/parameters parts + name = attrs.get(f"llm.tools.{i}.tool.name") + if not name: + continue + tool: dict = {"type": "function", "function": {"name": name}} + description = attrs.get(f"llm.tools.{i}.tool.description") + if description: + tool["function"]["description"] = description + json_schema = attrs.get(f"llm.tools.{i}.tool.json_schema") + if json_schema: + try: + schema = json.loads(json_schema) if isinstance(json_schema, str) else json_schema + tool["function"]["parameters"] = schema + except (json.JSONDecodeError, Exception): + pass + tools.append(tool) + if tools: + return tools + + return None + + def _convert_openinference_tool_calls(self, tool_calls: list) -> list: + """Convert OpenInference tool_calls to OpenAI format. + + OpenInference: {"tool_call.function.name": ..., "tool_call.id": ..., "tool_call.function.arguments": ...} + OpenAI: {"id": ..., "type": "function", "function": {"name": ..., "arguments": ...}} + """ + result = [] + for tc in tool_calls: + if not isinstance(tc, dict): + continue + name = tc.get("tool_call.function.name") + if name is not None: + arguments = tc.get("tool_call.function.arguments", "{}") + result.append({ + "id": tc.get("tool_call.id", ""), + "type": "function", + "function": { + "name": name, + "arguments": arguments if isinstance(arguments, str) else json.dumps(arguments), + }, + }) + elif "id" in tc or "function" in tc: + result.append(tc) + return result + def _extract_trajectory(self, span: dict) -> dict: """Extract a complete trajectory from a span.""" attrs = span.get("attributes") or {} @@ -365,9 +582,12 @@ def _extract_trajectory(self, span: dict) -> dict: for msg in messages: role = msg["role"] content = msg["content"] + raw_tool_calls = msg.get("tool_calls") + tool_call_id = msg.get("tool_call_id") converted = self._convert_to_openai_format(content, role) if converted.get("role") == "tool" and "tool_results" in converted: + # Anthropic content-block format — tool_call_id already embedded for result in converted["tool_results"]: openai_messages.append( { @@ -376,15 +596,33 @@ def _extract_trajectory(self, span: dict) -> dict: "content": result["content"], } ) + elif role == "tool" and tool_call_id: + # OpenInference format — tool_call_id extracted from span attribute + openai_messages.append( + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": converted.get("content", ""), + } + ) + elif role == "assistant" and raw_tool_calls: + # OpenInference format — convert tool_calls from OpenInference to OpenAI + openai_tool_calls = self._convert_openinference_tool_calls(raw_tool_calls) + if openai_tool_calls: + converted["tool_calls"] = openai_tool_calls + if converted.get("content") in (None, "None", ""): + converted.pop("content", None) + openai_messages.append(converted) else: openai_messages.append(converted) return { "trace_id": span["context"]["trace_id"], "span_id": span["context"]["span_id"], - "model": attrs.get("gen_ai.request.model", "unknown"), + "model": attrs.get("gen_ai.request.model") or attrs.get("llm.model_name", "unknown"), "timestamp": span.get("start_time"), "messages": openai_messages, + "tools": self._extract_tools_from_span(span), "usage": { "prompt_tokens": next( ( @@ -532,38 +770,50 @@ def sync( spans = self._fetch_spans(limit) logger.info(f"Fetched {len(spans)} spans from Phoenix") - # Get already processed span IDs - processed_ids = self._get_processed_span_ids() - logger.info(f"Found {len(processed_ids)} already processed spans") + # Dedup at trace level: each agent run produces multiple spans (one per LLM call), + # each accumulating the full message history. Processing all would analyse the same + # steps multiple times. We skip any trace already stored, then pick one + # representative span per remaining trace (the latest, which has the complete context). + processed_trace_ids = self._get_processed_trace_ids() + logger.info(f"Found {len(processed_trace_ids)} already processed traces") processed = 0 skipped = 0 guidelines_generated = 0 errors = [] + # First pass: filter to candidate LLM spans, skipping already-processed traces + candidates = [] + skipped_trace_ids: set[str] = set() for span in spans: - # Filter to LLM request spans - accept any span with prompt attributes - # if span.get("name") != "litellm_request": - # continue - - # Filter errors if requested if not include_errors and span.get("status_code") == "ERROR": continue - # Check if already processed - span_id = span.get("context", {}).get("span_id") - if span_id in processed_ids: - skipped += 1 - continue - - # Only include spans with actual messages or GenAI/LLM prompt attributes attrs = span.get("attributes") or {} has_gen_ai = any(k.startswith("gen_ai.prompt.") for k in attrs) - has_llm_msgs = "llm.input_messages" in attrs or "input.value" in attrs - + has_llm_msgs = ( + "llm.input_messages" in attrs + or "input.value" in attrs + or any(k.startswith("llm.input_messages.") for k in attrs) + ) if not (has_gen_ai or has_llm_msgs): continue + trace_id = span.get("context", {}).get("trace_id") + if trace_id in processed_trace_ids: + skipped_trace_ids.add(trace_id) + continue + + candidates.append(span) + + skipped = len(skipped_trace_ids) + + # Second pass: one representative span per trace (latest start_time = most complete context) + representative_spans = self._select_representative_spans(candidates) + logger.info(f"Selected {len(representative_spans)} representative spans from {len(candidates)} candidates") + + for span in representative_spans: + span_id = span.get("context", {}).get("span_id") try: trajectory = self._extract_trajectory(span) trajectory = self._clean_trajectory(trajectory) @@ -572,7 +822,7 @@ def sync( guidelines_count = self._process_trajectory(trajectory) processed += 1 guidelines_generated += guidelines_count - logger.info(f"Processed span {span_id[:12]}... - generated {guidelines_count} guidelines") + logger.info(f"Processed trace {trajectory['trace_id'][:12]}... - generated {guidelines_count} guidelines") except Exception as e: error_msg = f"Error processing span {span_id}: {e}" logger.exception(error_msg) diff --git a/tests/unit/test_phoenix_sync.py b/tests/unit/test_phoenix_sync.py index 7fa3d47b..9d1e2a4d 100644 --- a/tests/unit/test_phoenix_sync.py +++ b/tests/unit/test_phoenix_sync.py @@ -454,9 +454,9 @@ def test_sync_skips_already_processed(self, mock_generate_guidelines, mock_urlop mock_response.__exit__ = Mock(return_value=False) mock_urlopen.return_value = mock_response - # Mock that this span was already processed + # Mock that this trace was already processed mock_entity = MagicMock() - mock_entity.metadata = {"span_id": "already_processed"} + mock_entity.metadata = {"span_id": "already_processed", "trace_id": "t1"} phoenix_sync.client.search_entities.return_value = [mock_entity] result = phoenix_sync.sync(limit=10) @@ -639,9 +639,9 @@ def test_sync_returns_correct_counts(self, mock_generate_guidelines, mock_urlope mock_response.__exit__ = Mock(return_value=False) mock_urlopen.return_value = mock_response - # old_span was already processed + # old_span / trace t2 was already processed mock_entity = MagicMock() - mock_entity.metadata = {"span_id": "old_span"} + mock_entity.metadata = {"span_id": "old_span", "trace_id": "t2"} phoenix_sync.client.search_entities.return_value = [mock_entity] # Create mock Guideline object with required attributes mock_guideline = MagicMock()