diff --git a/altk_evolve/sync/phoenix_sync.py b/altk_evolve/sync/phoenix_sync.py index 65aaeda7..cc4d7d9b 100644 --- a/altk_evolve/sync/phoenix_sync.py +++ b/altk_evolve/sync/phoenix_sync.py @@ -105,6 +105,36 @@ def _get_processed_span_ids(self) -> set[str]: except NamespaceNotFoundException: return set() + def _get_processed_trace_ids(self) -> set[str]: + """Get trace_ids that have already been processed.""" + try: + entities = self.client.search_entities( + namespace_id=self.namespace_id, + filters={"type": "trajectory"}, + limit=10000, + ) + return {str(e.metadata.get("trace_id")) for e in entities if e.metadata and e.metadata.get("trace_id")} + except NamespaceNotFoundException: + return set() + + def _select_representative_spans(self, spans: list[dict]) -> list[dict]: + """For each trace, return only the last span by timestamp. + + Each LLM call in an agent run receives the full accumulated message history + as input, so the last span chronologically contains the most complete + conversation and is the right unit for guideline generation. Processing + all spans would analyse the same steps multiple times with growing context. + """ + by_trace: dict[str, list[dict]] = {} + for span in spans: + trace_id = span.get("context", {}).get("trace_id") + if trace_id: + by_trace.setdefault(trace_id, []).append(span) + return [ + max(trace_spans, key=lambda s: s.get("start_time") or "") + for trace_spans in by_trace.values() + ] + def _format_payload_summary(self, payload: Any) -> str: """Format a payload summary for secure logging (avoid PII).""" type_name = type(payload).__name__ @@ -180,6 +210,7 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]: role = msg.get("message.role") or msg.get("role") content = msg.get("message.content") or msg.get("content") tool_calls = msg.get("message.tool_calls") or msg.get("tool_calls") + tool_call_id = msg.get("message.tool_call_id") or msg.get("tool_call_id") if role: mapped_msg = { @@ -190,6 +221,8 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]: } if tool_calls: mapped_msg["tool_calls"] = tool_calls + if tool_call_id: + mapped_msg["tool_call_id"] = tool_call_id messages.append(mapped_msg) # Handle Output/Completion from OpenInference @@ -227,6 +260,7 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]: role = msg.get("message.role") or msg.get("role") content = msg.get("message.content") or msg.get("content") tool_calls = msg.get("message.tool_calls") or msg.get("tool_calls") + tool_call_id = msg.get("message.tool_call_id") or msg.get("tool_call_id") if role: mapped_msg = { @@ -237,11 +271,85 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]: } if tool_calls: mapped_msg["tool_calls"] = tool_calls + if tool_call_id: + mapped_msg["tool_call_id"] = tool_call_id messages.append(mapped_msg) if messages: return messages + # Indexed OpenInference format from Phoenix REST API: + # llm.input_messages.{i}.message.role / .content / .tool_calls.{j}.* / .tool_call_id + input_indices: set[int] = set() + output_indices: set[int] = set() + for key in attrs: + if key.startswith("llm.input_messages."): + parts = key.split(".") + if len(parts) >= 3 and parts[2].isdigit(): + input_indices.add(int(parts[2])) + elif key.startswith("llm.output_messages."): + parts = key.split(".") + if len(parts) >= 3 and parts[2].isdigit(): + output_indices.add(int(parts[2])) + + for i in sorted(input_indices): + role = attrs.get(f"llm.input_messages.{i}.message.role") + content = attrs.get(f"llm.input_messages.{i}.message.content") + tool_call_id = attrs.get(f"llm.input_messages.{i}.message.tool_call_id") + + tc_indices: set[int] = set() + prefix = f"llm.input_messages.{i}.message.tool_calls." + for key in attrs: + if key.startswith(prefix): + parts = key[len(prefix):].split(".") + if parts and parts[0].isdigit(): + tc_indices.add(int(parts[0])) + tool_calls = [] + for j in sorted(tc_indices): + tc_prefix = f"llm.input_messages.{i}.message.tool_calls.{j}.tool_call." + tool_calls.append({ + "tool_call.id": attrs.get(f"{tc_prefix}id", ""), + "tool_call.function.name": attrs.get(f"{tc_prefix}function.name", ""), + "tool_call.function.arguments": attrs.get(f"{tc_prefix}function.arguments", "{}"), + }) + + if role: + mapped_msg_in: dict = {"index": i, "type": "prompt", "role": role, "content": content} + if tool_calls: + mapped_msg_in["tool_calls"] = tool_calls + if tool_call_id: + mapped_msg_in["tool_call_id"] = tool_call_id + messages.append(mapped_msg_in) + + for i in sorted(output_indices): + role = attrs.get(f"llm.output_messages.{i}.message.role") + content = attrs.get(f"llm.output_messages.{i}.message.content") + + tc_indices_out: set[int] = set() + prefix_out = f"llm.output_messages.{i}.message.tool_calls." + for key in attrs: + if key.startswith(prefix_out): + parts = key[len(prefix_out):].split(".") + if parts and parts[0].isdigit(): + tc_indices_out.add(int(parts[0])) + tool_calls_out = [] + for j in sorted(tc_indices_out): + tc_prefix_out = f"llm.output_messages.{i}.message.tool_calls.{j}.tool_call." + tool_calls_out.append({ + "tool_call.id": attrs.get(f"{tc_prefix_out}id", ""), + "tool_call.function.name": attrs.get(f"{tc_prefix_out}function.name", ""), + "tool_call.function.arguments": attrs.get(f"{tc_prefix_out}function.arguments", "{}"), + }) + + if role: + mapped_msg_out: dict = {"index": i, "type": "completion", "role": role, "content": content} + if tool_calls_out: + mapped_msg_out["tool_calls"] = tool_calls_out + messages.append(mapped_msg_out) + + if messages: + return messages + # Fallback to GenAI semantic conventions (original code) # Extract prompt messages prompt_indices = set() @@ -355,6 +463,115 @@ def _convert_to_openai_format(self, content: Any, role: str) -> dict: content_text = "\n\n".join(text_parts) if text_parts else "" return {"role": role, "content": content_text} + def _extract_tools_from_span(self, span: dict) -> list[dict] | None: + """Extract tool definitions from a span's attributes in OpenAI tools format. + + Tries three attribute conventions in order: + 1. llm.invocation_parameters (LiteLLM GenAI convention) — JSON dict with a "tools" key + 2. llm.tools as a JSON array where each item is {"tool.json_schema": ""} + 3. Indexed llm.tools.{i}.tool.json_schema keys (Phoenix REST API flat format) + """ + attrs = span.get("attributes") or {} + + invocation_params = attrs.get("llm.invocation_parameters") + if invocation_params: + try: + params = json.loads(invocation_params) if isinstance(invocation_params, str) else invocation_params + if isinstance(params, dict): + tools = params.get("tools") + if isinstance(tools, list) and tools: + return tools + except (json.JSONDecodeError, Exception): + pass + + tools_attr = attrs.get("llm.tools") + if tools_attr is not None: + try: + tools = json.loads(tools_attr) if isinstance(tools_attr, str) else tools_attr + if isinstance(tools, list) and tools: + # OpenInference list format: each item is {"tool.json_schema": ""} + openai_tools = [] + for item in tools: + if isinstance(item, dict): + schema_str = item.get("tool.json_schema") + if schema_str is not None: + try: + schema = json.loads(schema_str) if isinstance(schema_str, str) else schema_str + openai_tools.append(schema) + continue + except (json.JSONDecodeError, Exception): + pass + openai_tools.append(item) + if openai_tools: + return openai_tools + except (json.JSONDecodeError, Exception): + pass + + # Indexed flat format from Phoenix REST API: llm.tools.{i}.tool.json_schema + tool_indices: set[int] = set() + for key in attrs: + if key.startswith("llm.tools."): + parts = key.split(".") + if len(parts) >= 3 and parts[2].isdigit(): + tool_indices.add(int(parts[2])) + + if tool_indices: + tools = [] + for i in sorted(tool_indices): + json_schema_str = attrs.get(f"llm.tools.{i}.tool.json_schema") + if json_schema_str: + try: + schema = json.loads(json_schema_str) if isinstance(json_schema_str, str) else json_schema_str + tools.append(schema) + continue + except (json.JSONDecodeError, Exception): + pass + # Fall back to building from name/description/parameters parts + name = attrs.get(f"llm.tools.{i}.tool.name") + if not name: + continue + tool: dict = {"type": "function", "function": {"name": name}} + description = attrs.get(f"llm.tools.{i}.tool.description") + if description: + tool["function"]["description"] = description + json_schema = attrs.get(f"llm.tools.{i}.tool.json_schema") + if json_schema: + try: + schema = json.loads(json_schema) if isinstance(json_schema, str) else json_schema + tool["function"]["parameters"] = schema + except (json.JSONDecodeError, Exception): + pass + tools.append(tool) + if tools: + return tools + + return None + + def _convert_openinference_tool_calls(self, tool_calls: list) -> list: + """Convert OpenInference tool_calls to OpenAI format. + + OpenInference: {"tool_call.function.name": ..., "tool_call.id": ..., "tool_call.function.arguments": ...} + OpenAI: {"id": ..., "type": "function", "function": {"name": ..., "arguments": ...}} + """ + result = [] + for tc in tool_calls: + if not isinstance(tc, dict): + continue + name = tc.get("tool_call.function.name") + if name is not None: + arguments = tc.get("tool_call.function.arguments", "{}") + result.append({ + "id": tc.get("tool_call.id", ""), + "type": "function", + "function": { + "name": name, + "arguments": arguments if isinstance(arguments, str) else json.dumps(arguments), + }, + }) + elif "id" in tc or "function" in tc: + result.append(tc) + return result + def _extract_trajectory(self, span: dict) -> dict: """Extract a complete trajectory from a span.""" attrs = span.get("attributes") or {} @@ -365,9 +582,12 @@ def _extract_trajectory(self, span: dict) -> dict: for msg in messages: role = msg["role"] content = msg["content"] + raw_tool_calls = msg.get("tool_calls") + tool_call_id = msg.get("tool_call_id") converted = self._convert_to_openai_format(content, role) if converted.get("role") == "tool" and "tool_results" in converted: + # Anthropic content-block format — tool_call_id already embedded for result in converted["tool_results"]: openai_messages.append( { @@ -376,15 +596,33 @@ def _extract_trajectory(self, span: dict) -> dict: "content": result["content"], } ) + elif role == "tool" and tool_call_id: + # OpenInference format — tool_call_id extracted from span attribute + openai_messages.append( + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": converted.get("content", ""), + } + ) + elif role == "assistant" and raw_tool_calls: + # OpenInference format — convert tool_calls from OpenInference to OpenAI + openai_tool_calls = self._convert_openinference_tool_calls(raw_tool_calls) + if openai_tool_calls: + converted["tool_calls"] = openai_tool_calls + if converted.get("content") in (None, "None", ""): + converted.pop("content", None) + openai_messages.append(converted) else: openai_messages.append(converted) return { "trace_id": span["context"]["trace_id"], "span_id": span["context"]["span_id"], - "model": attrs.get("gen_ai.request.model", "unknown"), + "model": attrs.get("gen_ai.request.model") or attrs.get("llm.model_name", "unknown"), "timestamp": span.get("start_time"), "messages": openai_messages, + "tools": self._extract_tools_from_span(span), "usage": { "prompt_tokens": next( ( @@ -532,38 +770,50 @@ def sync( spans = self._fetch_spans(limit) logger.info(f"Fetched {len(spans)} spans from Phoenix") - # Get already processed span IDs - processed_ids = self._get_processed_span_ids() - logger.info(f"Found {len(processed_ids)} already processed spans") + # Dedup at trace level: each agent run produces multiple spans (one per LLM call), + # each accumulating the full message history. Processing all would analyse the same + # steps multiple times. We skip any trace already stored, then pick one + # representative span per remaining trace (the latest, which has the complete context). + processed_trace_ids = self._get_processed_trace_ids() + logger.info(f"Found {len(processed_trace_ids)} already processed traces") processed = 0 skipped = 0 guidelines_generated = 0 errors = [] + # First pass: filter to candidate LLM spans, skipping already-processed traces + candidates = [] + skipped_trace_ids: set[str] = set() for span in spans: - # Filter to LLM request spans - accept any span with prompt attributes - # if span.get("name") != "litellm_request": - # continue - - # Filter errors if requested if not include_errors and span.get("status_code") == "ERROR": continue - # Check if already processed - span_id = span.get("context", {}).get("span_id") - if span_id in processed_ids: - skipped += 1 - continue - - # Only include spans with actual messages or GenAI/LLM prompt attributes attrs = span.get("attributes") or {} has_gen_ai = any(k.startswith("gen_ai.prompt.") for k in attrs) - has_llm_msgs = "llm.input_messages" in attrs or "input.value" in attrs - + has_llm_msgs = ( + "llm.input_messages" in attrs + or "input.value" in attrs + or any(k.startswith("llm.input_messages.") for k in attrs) + ) if not (has_gen_ai or has_llm_msgs): continue + trace_id = span.get("context", {}).get("trace_id") + if trace_id in processed_trace_ids: + skipped_trace_ids.add(trace_id) + continue + + candidates.append(span) + + skipped = len(skipped_trace_ids) + + # Second pass: one representative span per trace (latest start_time = most complete context) + representative_spans = self._select_representative_spans(candidates) + logger.info(f"Selected {len(representative_spans)} representative spans from {len(candidates)} candidates") + + for span in representative_spans: + span_id = span.get("context", {}).get("span_id") try: trajectory = self._extract_trajectory(span) trajectory = self._clean_trajectory(trajectory) @@ -572,7 +822,7 @@ def sync( guidelines_count = self._process_trajectory(trajectory) processed += 1 guidelines_generated += guidelines_count - logger.info(f"Processed span {span_id[:12]}... - generated {guidelines_count} guidelines") + logger.info(f"Processed trace {trajectory['trace_id'][:12]}... - generated {guidelines_count} guidelines") except Exception as e: error_msg = f"Error processing span {span_id}: {e}" logger.exception(error_msg) diff --git a/tests/unit/test_phoenix_sync.py b/tests/unit/test_phoenix_sync.py index 7fa3d47b..9d1e2a4d 100644 --- a/tests/unit/test_phoenix_sync.py +++ b/tests/unit/test_phoenix_sync.py @@ -454,9 +454,9 @@ def test_sync_skips_already_processed(self, mock_generate_guidelines, mock_urlop mock_response.__exit__ = Mock(return_value=False) mock_urlopen.return_value = mock_response - # Mock that this span was already processed + # Mock that this trace was already processed mock_entity = MagicMock() - mock_entity.metadata = {"span_id": "already_processed"} + mock_entity.metadata = {"span_id": "already_processed", "trace_id": "t1"} phoenix_sync.client.search_entities.return_value = [mock_entity] result = phoenix_sync.sync(limit=10) @@ -639,9 +639,9 @@ def test_sync_returns_correct_counts(self, mock_generate_guidelines, mock_urlope mock_response.__exit__ = Mock(return_value=False) mock_urlopen.return_value = mock_response - # old_span was already processed + # old_span / trace t2 was already processed mock_entity = MagicMock() - mock_entity.metadata = {"span_id": "old_span"} + mock_entity.metadata = {"span_id": "old_span", "trace_id": "t2"} phoenix_sync.client.search_entities.return_value = [mock_entity] # Create mock Guideline object with required attributes mock_guideline = MagicMock()