Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 269 additions & 19 deletions altk_evolve/sync/phoenix_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,36 @@ def _get_processed_span_ids(self) -> set[str]:
except NamespaceNotFoundException:
return set()

def _get_processed_trace_ids(self) -> set[str]:
    """Return the trace IDs of trajectories already stored in the namespace.

    Searches the namespace for entities of type ``"trajectory"`` (at most
    10000 results) and collects every truthy ``trace_id`` found in their
    metadata, converted to ``str``.

    Returns:
        The set of stored trace IDs, or an empty set when the namespace does
        not exist.
    """
    try:
        stored = self.client.search_entities(
            namespace_id=self.namespace_id,
            filters={"type": "trajectory"},
            limit=10000,
        )
        seen: set[str] = set()
        for entity in stored:
            meta = entity.metadata
            if not meta:
                # Entities without metadata cannot carry a trace_id.
                continue
            trace_id = meta.get("trace_id")
            if trace_id:
                seen.add(str(trace_id))
        return seen
    except NamespaceNotFoundException:
        return set()

def _select_representative_spans(self, spans: list[dict]) -> list[dict]:
    """For each trace, return only the last span by timestamp.

    Each LLM call in an agent run receives the full accumulated message history
    as input, so the last span chronologically contains the most complete
    conversation and is the right unit for guideline generation. Processing
    all spans would analyse the same steps multiple times with growing context.

    Args:
        spans: Span dicts; each is expected to carry ``context.trace_id`` and
            a ``start_time`` value. Spans without a usable trace ID are dropped.

    Returns:
        One span per distinct trace ID, in order of first appearance of the
        trace. Within a trace the span with the greatest ``start_time`` wins;
        a missing/None ``start_time`` sorts as the empty string, and on ties
        the earliest-listed span is kept.
    """
    by_trace: dict[str, list[dict]] = {}
    for span in spans:
        # "context" may be absent *or* explicitly None; `.get("context", {})`
        # only covers the former and would raise AttributeError on the latter.
        trace_id = (span.get("context") or {}).get("trace_id")
        if trace_id:
            by_trace.setdefault(trace_id, []).append(span)
    return [
        max(trace_spans, key=lambda s: s.get("start_time") or "")
        for trace_spans in by_trace.values()
    ]

def _format_payload_summary(self, payload: Any) -> str:
"""Format a payload summary for secure logging (avoid PII)."""
type_name = type(payload).__name__
Expand Down Expand Up @@ -180,6 +210,7 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]:
role = msg.get("message.role") or msg.get("role")
content = msg.get("message.content") or msg.get("content")
tool_calls = msg.get("message.tool_calls") or msg.get("tool_calls")
tool_call_id = msg.get("message.tool_call_id") or msg.get("tool_call_id")

if role:
mapped_msg = {
Expand All @@ -190,6 +221,8 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]:
}
if tool_calls:
mapped_msg["tool_calls"] = tool_calls
if tool_call_id:
mapped_msg["tool_call_id"] = tool_call_id
messages.append(mapped_msg)

# Handle Output/Completion from OpenInference
Expand Down Expand Up @@ -227,6 +260,7 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]:
role = msg.get("message.role") or msg.get("role")
content = msg.get("message.content") or msg.get("content")
tool_calls = msg.get("message.tool_calls") or msg.get("tool_calls")
tool_call_id = msg.get("message.tool_call_id") or msg.get("tool_call_id")

if role:
mapped_msg = {
Expand All @@ -237,11 +271,85 @@ def _extract_messages_from_span(self, span: dict) -> list[dict]:
}
if tool_calls:
mapped_msg["tool_calls"] = tool_calls
if tool_call_id:
mapped_msg["tool_call_id"] = tool_call_id
messages.append(mapped_msg)

if messages:
return messages

# Indexed OpenInference format from Phoenix REST API:
# llm.input_messages.{i}.message.role / .content / .tool_calls.{j}.* / .tool_call_id
input_indices: set[int] = set()
output_indices: set[int] = set()
for key in attrs:
if key.startswith("llm.input_messages."):
parts = key.split(".")
if len(parts) >= 3 and parts[2].isdigit():
input_indices.add(int(parts[2]))
elif key.startswith("llm.output_messages."):
parts = key.split(".")
if len(parts) >= 3 and parts[2].isdigit():
output_indices.add(int(parts[2]))

for i in sorted(input_indices):
role = attrs.get(f"llm.input_messages.{i}.message.role")
content = attrs.get(f"llm.input_messages.{i}.message.content")
tool_call_id = attrs.get(f"llm.input_messages.{i}.message.tool_call_id")

tc_indices: set[int] = set()
prefix = f"llm.input_messages.{i}.message.tool_calls."
for key in attrs:
if key.startswith(prefix):
parts = key[len(prefix):].split(".")
if parts and parts[0].isdigit():
tc_indices.add(int(parts[0]))
tool_calls = []
for j in sorted(tc_indices):
tc_prefix = f"llm.input_messages.{i}.message.tool_calls.{j}.tool_call."
tool_calls.append({
"tool_call.id": attrs.get(f"{tc_prefix}id", ""),
"tool_call.function.name": attrs.get(f"{tc_prefix}function.name", ""),
"tool_call.function.arguments": attrs.get(f"{tc_prefix}function.arguments", "{}"),
})

if role:
mapped_msg_in: dict = {"index": i, "type": "prompt", "role": role, "content": content}
if tool_calls:
mapped_msg_in["tool_calls"] = tool_calls
if tool_call_id:
mapped_msg_in["tool_call_id"] = tool_call_id
messages.append(mapped_msg_in)

for i in sorted(output_indices):
role = attrs.get(f"llm.output_messages.{i}.message.role")
content = attrs.get(f"llm.output_messages.{i}.message.content")

tc_indices_out: set[int] = set()
prefix_out = f"llm.output_messages.{i}.message.tool_calls."
for key in attrs:
if key.startswith(prefix_out):
parts = key[len(prefix_out):].split(".")
if parts and parts[0].isdigit():
tc_indices_out.add(int(parts[0]))
tool_calls_out = []
for j in sorted(tc_indices_out):
tc_prefix_out = f"llm.output_messages.{i}.message.tool_calls.{j}.tool_call."
tool_calls_out.append({
"tool_call.id": attrs.get(f"{tc_prefix_out}id", ""),
"tool_call.function.name": attrs.get(f"{tc_prefix_out}function.name", ""),
"tool_call.function.arguments": attrs.get(f"{tc_prefix_out}function.arguments", "{}"),
})

if role:
mapped_msg_out: dict = {"index": i, "type": "completion", "role": role, "content": content}
if tool_calls_out:
mapped_msg_out["tool_calls"] = tool_calls_out
messages.append(mapped_msg_out)

if messages:
return messages

# Fallback to GenAI semantic conventions (original code)
# Extract prompt messages
prompt_indices = set()
Expand Down Expand Up @@ -355,6 +463,115 @@ def _convert_to_openai_format(self, content: Any, role: str) -> dict:
content_text = "\n\n".join(text_parts) if text_parts else ""
return {"role": role, "content": content_text}

def _extract_tools_from_span(self, span: dict) -> list[dict] | None:
"""Extract tool definitions from a span's attributes in OpenAI tools format.

Tries three attribute conventions in order:
1. llm.invocation_parameters (LiteLLM GenAI convention) — JSON dict with a "tools" key
2. llm.tools as a JSON array where each item is {"tool.json_schema": "<json>"}
3. Indexed llm.tools.{i}.tool.json_schema keys (Phoenix REST API flat format)
"""
attrs = span.get("attributes") or {}

invocation_params = attrs.get("llm.invocation_parameters")
if invocation_params:
try:
params = json.loads(invocation_params) if isinstance(invocation_params, str) else invocation_params
if isinstance(params, dict):
tools = params.get("tools")
if isinstance(tools, list) and tools:
return tools
except (json.JSONDecodeError, Exception):
pass

tools_attr = attrs.get("llm.tools")
if tools_attr is not None:
try:
tools = json.loads(tools_attr) if isinstance(tools_attr, str) else tools_attr
if isinstance(tools, list) and tools:
# OpenInference list format: each item is {"tool.json_schema": "<json string>"}
openai_tools = []
for item in tools:
if isinstance(item, dict):
schema_str = item.get("tool.json_schema")
if schema_str is not None:
try:
schema = json.loads(schema_str) if isinstance(schema_str, str) else schema_str
openai_tools.append(schema)
continue
except (json.JSONDecodeError, Exception):
pass
openai_tools.append(item)
if openai_tools:
return openai_tools
except (json.JSONDecodeError, Exception):
pass

# Indexed flat format from Phoenix REST API: llm.tools.{i}.tool.json_schema
tool_indices: set[int] = set()
for key in attrs:
if key.startswith("llm.tools."):
parts = key.split(".")
if len(parts) >= 3 and parts[2].isdigit():
tool_indices.add(int(parts[2]))

if tool_indices:
tools = []
for i in sorted(tool_indices):
json_schema_str = attrs.get(f"llm.tools.{i}.tool.json_schema")
if json_schema_str:
try:
schema = json.loads(json_schema_str) if isinstance(json_schema_str, str) else json_schema_str
tools.append(schema)
continue
except (json.JSONDecodeError, Exception):
pass
# Fall back to building from name/description/parameters parts
name = attrs.get(f"llm.tools.{i}.tool.name")
if not name:
continue
tool: dict = {"type": "function", "function": {"name": name}}
description = attrs.get(f"llm.tools.{i}.tool.description")
if description:
tool["function"]["description"] = description
json_schema = attrs.get(f"llm.tools.{i}.tool.json_schema")
if json_schema:
try:
schema = json.loads(json_schema) if isinstance(json_schema, str) else json_schema
tool["function"]["parameters"] = schema
except (json.JSONDecodeError, Exception):
pass
tools.append(tool)
if tools:
return tools

return None

def _convert_openinference_tool_calls(self, tool_calls: list) -> list:
    """Convert OpenInference tool_calls to OpenAI format.

    OpenInference: {"tool_call.function.name": ..., "tool_call.id": ..., "tool_call.function.arguments": ...}
    OpenAI: {"id": ..., "type": "function", "function": {"name": ..., "arguments": ...}}

    Args:
        tool_calls: Tool-call entries from a message. Non-dict entries are
            ignored; ``None`` is treated as an empty list.

    Returns:
        A list of OpenAI-style tool calls. Entries with a
        ``tool_call.function.name`` are converted; entries that already have an
        ``"id"`` or ``"function"`` key are passed through untouched; anything
        else is dropped. ``arguments`` is always a string: non-string values
        are JSON-encoded, and a missing or ``None`` value becomes ``"{}"``.
    """
    result = []
    for tc in tool_calls or []:
        if not isinstance(tc, dict):
            continue
        name = tc.get("tool_call.function.name")
        if name is not None:
            arguments = tc.get("tool_call.function.arguments")
            if arguments is None:
                # An explicit None would otherwise be dumped as the string
                # "null", which is not a JSON object of arguments.
                arguments = "{}"
            elif not isinstance(arguments, str):
                arguments = json.dumps(arguments)
            call_id = tc.get("tool_call.id")
            result.append({
                "id": "" if call_id is None else call_id,
                "type": "function",
                "function": {"name": name, "arguments": arguments},
            })
        elif "id" in tc or "function" in tc:
            # Already OpenAI-shaped: keep as-is.
            result.append(tc)
    return result

def _extract_trajectory(self, span: dict) -> dict:
"""Extract a complete trajectory from a span."""
attrs = span.get("attributes") or {}
Expand All @@ -365,9 +582,12 @@ def _extract_trajectory(self, span: dict) -> dict:
for msg in messages:
role = msg["role"]
content = msg["content"]
raw_tool_calls = msg.get("tool_calls")
tool_call_id = msg.get("tool_call_id")
converted = self._convert_to_openai_format(content, role)

if converted.get("role") == "tool" and "tool_results" in converted:
# Anthropic content-block format — tool_call_id already embedded
for result in converted["tool_results"]:
openai_messages.append(
{
Expand All @@ -376,15 +596,33 @@ def _extract_trajectory(self, span: dict) -> dict:
"content": result["content"],
}
)
elif role == "tool" and tool_call_id:
# OpenInference format — tool_call_id extracted from span attribute
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": converted.get("content", ""),
}
)
elif role == "assistant" and raw_tool_calls:
# OpenInference format — convert tool_calls from OpenInference to OpenAI
openai_tool_calls = self._convert_openinference_tool_calls(raw_tool_calls)
if openai_tool_calls:
converted["tool_calls"] = openai_tool_calls
if converted.get("content") in (None, "None", ""):
converted.pop("content", None)
openai_messages.append(converted)
else:
openai_messages.append(converted)

return {
"trace_id": span["context"]["trace_id"],
"span_id": span["context"]["span_id"],
"model": attrs.get("gen_ai.request.model", "unknown"),
"model": attrs.get("gen_ai.request.model") or attrs.get("llm.model_name", "unknown"),
"timestamp": span.get("start_time"),
"messages": openai_messages,
"tools": self._extract_tools_from_span(span),
"usage": {
"prompt_tokens": next(
(
Expand Down Expand Up @@ -532,38 +770,50 @@ def sync(
spans = self._fetch_spans(limit)
logger.info(f"Fetched {len(spans)} spans from Phoenix")

# Get already processed span IDs
processed_ids = self._get_processed_span_ids()
logger.info(f"Found {len(processed_ids)} already processed spans")
# Dedup at trace level: each agent run produces multiple spans (one per LLM call),
# each accumulating the full message history. Processing all would analyse the same
# steps multiple times. We skip any trace already stored, then pick one
# representative span per remaining trace (the latest, which has the complete context).
processed_trace_ids = self._get_processed_trace_ids()
logger.info(f"Found {len(processed_trace_ids)} already processed traces")

processed = 0
skipped = 0
guidelines_generated = 0
errors = []

# First pass: filter to candidate LLM spans, skipping already-processed traces
candidates = []
skipped_trace_ids: set[str] = set()
for span in spans:
# Filter to LLM request spans - accept any span with prompt attributes
# if span.get("name") != "litellm_request":
# continue

# Filter errors if requested
if not include_errors and span.get("status_code") == "ERROR":
continue

# Check if already processed
span_id = span.get("context", {}).get("span_id")
if span_id in processed_ids:
skipped += 1
continue

# Only include spans with actual messages or GenAI/LLM prompt attributes
attrs = span.get("attributes") or {}
has_gen_ai = any(k.startswith("gen_ai.prompt.") for k in attrs)
has_llm_msgs = "llm.input_messages" in attrs or "input.value" in attrs

has_llm_msgs = (
"llm.input_messages" in attrs
or "input.value" in attrs
or any(k.startswith("llm.input_messages.") for k in attrs)
)
if not (has_gen_ai or has_llm_msgs):
continue

trace_id = span.get("context", {}).get("trace_id")
if trace_id in processed_trace_ids:
skipped_trace_ids.add(trace_id)
continue

candidates.append(span)

skipped = len(skipped_trace_ids)

# Second pass: one representative span per trace (latest start_time = most complete context)
representative_spans = self._select_representative_spans(candidates)
logger.info(f"Selected {len(representative_spans)} representative spans from {len(candidates)} candidates")

for span in representative_spans:
span_id = span.get("context", {}).get("span_id")
try:
trajectory = self._extract_trajectory(span)
trajectory = self._clean_trajectory(trajectory)
Expand All @@ -572,7 +822,7 @@ def sync(
guidelines_count = self._process_trajectory(trajectory)
processed += 1
guidelines_generated += guidelines_count
logger.info(f"Processed span {span_id[:12]}... - generated {guidelines_count} guidelines")
logger.info(f"Processed trace {trajectory['trace_id'][:12]}... - generated {guidelines_count} guidelines")
except Exception as e:
error_msg = f"Error processing span {span_id}: {e}"
logger.exception(error_msg)
Expand Down
Loading
Loading