Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from agents.items import TResponseStreamEvent
from agents.tool import (
ApplyPatchTool,
CustomTool,
LocalShellTool,
ShellTool,
ShellToolEnvironment,
Expand All @@ -39,6 +40,7 @@
APIStatusError,
AsyncOpenAI,
)
from openai.types.responses import CustomToolParam
from openai.types.responses.tool_param import Mcp
from typing_extensions import Required, TypedDict

Expand Down Expand Up @@ -112,6 +114,15 @@ class ApplyPatchToolInput:
name: str = "apply_patch"


@dataclass
class CustomToolInput:
"""Data conversion friendly representation of a CustomTool. Contains only the fields which are needed by the model
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
"""

tool_config: CustomToolParam


ToolInput = (
FunctionToolInput
| FileSearchTool
Expand All @@ -122,6 +133,7 @@ class ApplyPatchToolInput:
| ShellToolInput
| LocalShellTool
| ApplyPatchToolInput
| CustomToolInput
| ToolSearchTool
)

Expand Down Expand Up @@ -235,6 +247,14 @@ def _build_tool(tool: ToolInput) -> Tool:
return ApplyPatchTool(name=tool.name, editor=_NoopApplyPatchEditor())
elif isinstance(tool, HostedMCPToolInput):
return HostedMCPTool(tool_config=tool.tool_config)
elif isinstance(tool, CustomToolInput):
return CustomTool(
name=tool.tool_config["name"],
description=tool.tool_config.get("description", ""),
on_invoke_tool=_empty_on_invoke_tool,
format=tool.tool_config.get("format"),
defer_loading=tool.tool_config.get("defer_loading", False),
)
elif isinstance(tool, FunctionToolInput):
return FunctionTool(
name=tool.name,
Expand Down
11 changes: 10 additions & 1 deletion temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,21 @@
WebSearchTool,
)
from agents.items import TResponseStreamEvent
from agents.tool import ApplyPatchTool, LocalShellTool, ShellTool, ToolSearchTool
from agents.tool import (
ApplyPatchTool,
CustomTool,
LocalShellTool,
ShellTool,
ToolSearchTool,
)
from openai.types.responses.response_prompt_param import ResponsePromptParam

from temporalio import workflow
from temporalio.contrib.openai_agents._invoke_model_activity import (
ActivityModelInput,
AgentOutputSchemaInput,
ApplyPatchToolInput,
CustomToolInput,
FunctionToolInput,
HandoffInput,
HostedMCPToolInput,
Expand Down Expand Up @@ -92,6 +99,8 @@ def make_tool_info(tool: Tool) -> ToolInput:
return ApplyPatchToolInput(name=tool.name)
elif isinstance(tool, HostedMCPTool):
return HostedMCPToolInput(tool_config=tool.tool_config)
elif isinstance(tool, CustomTool):
return CustomToolInput(tool_config=tool.tool_config)
elif isinstance(tool, FunctionTool):
return FunctionToolInput(
name=tool.name,
Expand Down
86 changes: 86 additions & 0 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2538,6 +2538,92 @@ async def test_model_conversion_loops():
assert isinstance(triage_agent.model, _TemporalModelStub)


def test_sandbox_apply_patch_tool_round_trips_through_activity_input():
from agents.sandbox.capabilities.tools import SandboxApplyPatchTool
from agents.tool import CustomTool

from temporalio.contrib.openai_agents._invoke_model_activity import (
_build_tool,
)
Comment thread
xumaple marked this conversation as resolved.
Outdated

class FakeSandboxSession:
pass

tool = SandboxApplyPatchTool(session=FakeSandboxSession()) # type: ignore[arg-type]

stub = _TemporalModelStub(
model_name="gpt-5",
model_params=ModelActivityParameters(),
agent=None,
)

activity_input, _summary = stub._build_activity_input(
system_instructions=None,
input="hi",
model_settings=ModelSettings(),
tools=[tool],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
conversation_id=None,
prompt=None,
)

tool_inputs = activity_input.get("tools") or []
assert len(tool_inputs) == 1
rebuilt = _build_tool(tool_inputs[0])
assert isinstance(rebuilt, CustomTool)
assert rebuilt.name == tool.name
assert rebuilt.description == tool.description
assert rebuilt.format == tool.format
assert rebuilt.tool_config == tool.tool_config


def test_custom_tool_with_defer_loading_round_trips_through_activity_input():
from agents.tool import CustomTool

from temporalio.contrib.openai_agents._invoke_model_activity import (
_build_tool,
)

async def stub(_ctx: Any, _payload: str) -> str:
return ""

tool = CustomTool(
name="deferred_tool",
description="A custom tool with defer_loading enabled",
on_invoke_tool=stub,
defer_loading=True,
)

stub_model = _TemporalModelStub(
model_name="gpt-5",
model_params=ModelActivityParameters(),
agent=None,
)

activity_input, _summary = stub_model._build_activity_input(
system_instructions=None,
input="hi",
model_settings=ModelSettings(),
tools=[tool],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
conversation_id=None,
prompt=None,
)

tool_inputs = activity_input.get("tools") or []
assert len(tool_inputs) == 1
rebuilt = _build_tool(tool_inputs[0])
assert isinstance(rebuilt, CustomTool)
assert rebuilt.tool_config == tool.tool_config
assert rebuilt.defer_loading is True


Comment thread
xumaple marked this conversation as resolved.
async def test_local_hello_world_agent(client: Client):
async with AgentEnvironment(
model=hello_mock_model(),
Expand Down
Loading