diff --git a/app/modules/proxy/_service/websocket/helpers.py b/app/modules/proxy/_service/websocket/helpers.py index b8888439b..3cf601fcd 100644 --- a/app/modules/proxy/_service/websocket/helpers.py +++ b/app/modules/proxy/_service/websocket/helpers.py @@ -396,6 +396,34 @@ def _websocket_continuity_anchor_for_payload( ) +_WEBSOCKET_TOOL_CALL_ITEM_TYPES_BY_OUTPUT_TYPE = { + "function_call_output": "function_call", + "custom_tool_call_output": "custom_tool_call", + "apply_patch_call_output": "apply_patch_call", +} +_WEBSOCKET_TOOL_CALL_ITEM_TYPES = frozenset(_WEBSOCKET_TOOL_CALL_ITEM_TYPES_BY_OUTPUT_TYPE.values()) + + +def _websocket_input_items_are_self_contained_fresh_replay(input_items: list[JsonValue]) -> bool: + seen_call_ids_by_type: dict[str, set[str]] = {item_type: set() for item_type in _WEBSOCKET_TOOL_CALL_ITEM_TYPES} + for item in input_items: + if not isinstance(item, dict): + continue + item_type = _websocket_input_item_type(item) + call_id_value = item.get("call_id") + call_id = call_id_value if isinstance(call_id_value, str) and call_id_value else None + if item_type in _WEBSOCKET_TOOL_CALL_ITEM_TYPES: + if call_id is not None: + seen_call_ids_by_type[item_type].add(call_id) + continue + call_item_type = _WEBSOCKET_TOOL_CALL_ITEM_TYPES_BY_OUTPUT_TYPE.get(item_type or "") + if call_item_type is None: + continue + if call_id is None or call_id not in seen_call_ids_by_type[call_item_type]: + return False + return True + + def _websocket_client_previous_response_full_resend_is_retry_safe( *, previous_response_id: str | None, @@ -407,6 +435,8 @@ def _websocket_client_previous_response_full_resend_is_retry_safe( input_items = cast(list[JsonValue], input_value) if len(input_items) <= 1: return False + if not _websocket_input_items_are_self_contained_fresh_replay(input_items): + return False if ( continuity_state is not None and continuity_state.last_completed_response_id == previous_response_id diff --git a/openspec/specs/responses-api-compat/spec.md b/openspec/specs/responses-api-compat/spec.md index 82a72059f..795326340 100644 --- a/openspec/specs/responses-api-compat/spec.md +++ b/openspec/specs/responses-api-compat/spec.md @@ -103,7 +103,7 @@ When an HTTP bridge session receives an anonymous upstream `previous_response_no - **AND** the downstream error code is not `previous_response_not_found` ### Requirement: WebSocket full-resend previous-response misses retry without stale anchor -When a direct WebSocket `response.create` request includes both `previous_response_id` and a full resend payload, the service MUST retain a safe replay body without `previous_response_id`. If upstream rejects the anchor with `previous_response_not_found` before `response.created`, the service MUST reconnect and replay the retained full payload as a fresh turn instead of forwarding the raw upstream invalid-request error. +When a direct WebSocket `response.create` request includes both `previous_response_id` and a self-contained full resend payload, the service MUST retain a safe replay body without `previous_response_id`. If upstream rejects the anchor with `previous_response_not_found` before `response.created`, the service MUST reconnect and replay the retained full payload as a fresh turn instead of forwarding the raw upstream invalid-request error. A payload that only carries incremental tool outputs for tool calls that are not also present in the same request is not self-contained and MUST NOT be replayed as a fresh turn without `previous_response_id`. #### Scenario: full-resend WebSocket follow-up loses just-completed anchor - **WHEN** a WebSocket `/v1/responses` or `/backend-api/codex/responses` follow-up has `previous_response_id` @@ -113,6 +113,13 @@ When a direct WebSocket `response.create` request includes both `previous_respon - **AND** it replays the same request without `previous_response_id` - **AND** the downstream client receives the recovered response events, not the raw `previous_response_not_found` error +#### Scenario: output-only WebSocket tool delta is not replayed as a fresh turn +- **WHEN** a WebSocket `/v1/responses` or `/backend-api/codex/responses` follow-up has `previous_response_id` +- **AND** the request payload carries `function_call_output`, `custom_tool_call_output`, or `apply_patch_call_output` items without their matching tool-call items in the same payload +- **AND** upstream emits `previous_response_not_found` before assigning a response id +- **THEN** the service MUST NOT replay that payload as a fresh turn without `previous_response_id` +- **AND** the downstream client receives a retryable continuity failure rather than a fabricated fresh turn + ### Requirement: Public Responses errors mask previous-response misses Public Responses endpoints MUST NOT return an OpenAI-shaped `previous_response_not_found` error to clients. If a lower layer still raises or collects that error, the API layer MUST rewrite it to a retryable `stream_incomplete` continuity failure and remove the missing response id from the public payload. diff --git a/tests/unit/test_proxy_utils.py b/tests/unit/test_proxy_utils.py index d0f25fb55..a8427ad46 100644 --- a/tests/unit/test_proxy_utils.py +++ b/tests/unit/test_proxy_utils.py @@ -10179,6 +10179,69 @@ class Settings: assert fresh_payload["input"] == full_resend_input +@pytest.mark.asyncio +async def test_prepare_websocket_response_create_request_does_not_fresh_retry_tool_output_delta(monkeypatch): + request_logs = _RequestLogsRecorder() + service = proxy_service.ProxyService(_repo_factory(request_logs)) + reserve_usage = AsyncMock(return_value=None) + api_key = ApiKeyData( + id="key_ws_tool_delta", + name="ws-tool-delta", + key_prefix="sk-ws-tool-delta", + allowed_models=["gpt-5.1"], + enforced_model=None, + enforced_reasoning_effort=None, + enforced_service_tier=None, + expires_at=None, + is_active=True, + created_at=utcnow(), + last_used_at=None, + ) + + class Settings: + log_proxy_request_payload = False + log_proxy_request_shape = False + log_proxy_request_shape_raw_cache_key = False + log_proxy_service_tier_trace = False + openai_prompt_cache_key_derivation_enabled = True + + tool_output_delta: list[JsonValue] = [ + {"type": "function_call_output", "call_id": "call_delta_a", "output": "ok"}, + {"type": "function_call_output", "call_id": "call_delta_b", "output": "ok"}, + {"type": "function_call_output", "call_id": "call_delta_c", "output": "ok"}, + ] + + monkeypatch.setattr(proxy_service, "get_settings", lambda: Settings()) + monkeypatch.setattr(service, "_reserve_websocket_api_key_usage", reserve_usage) + monkeypatch.setattr(service, "_refresh_websocket_api_key_policy", AsyncMock(return_value=api_key)) + + prepared = await service._prepare_websocket_response_create_request( + cast( + dict[str, JsonValue], + { + "type": "response.create", + "model": "gpt-5.1", + "previous_response_id": "resp_client_anchor", + "input": tool_output_delta, + }, + ), + headers={"session_id": "turn_ws_tool_delta"}, + codex_session_affinity=True, + openai_cache_affinity=True, + sticky_threads_enabled=False, + openai_cache_affinity_max_age_seconds=300, + api_key=api_key, + continuity_state=None, + ) + + upstream_payload = json.loads(prepared.text_data) + assert upstream_payload["previous_response_id"] == "resp_client_anchor" + assert upstream_payload["input"] == tool_output_delta + assert prepared.request_state.previous_response_id == "resp_client_anchor" + assert prepared.request_state.fresh_upstream_request_is_retry_safe is False + assert prepared.request_state.fresh_upstream_request_text is None + + def test_websocket_client_previous_response_full_resend_retry_requires_matching_prefix() -> None: stored_prefix: list[JsonValue] = [{"role": "user", "content": [{"type": "input_text", "text": "old question"}]}] continuity_state = proxy_service._WebSocketContinuityState( @@ -10202,6 +10265,56 @@ def test_websocket_client_previous_response_full_resend_retry_requires_matching_ ) +def test_websocket_client_previous_response_full_resend_retry_rejects_tool_output_delta() -> None: + tool_output_delta: list[JsonValue] = [ + {"type": "function_call_output", "call_id": "call_a", "output": "ok"}, + {"type": "function_call_output", "call_id": "call_b", "output": "ok"}, + ] + + assert ( + proxy_service._websocket_client_previous_response_full_resend_is_retry_safe( + previous_response_id="resp_client_anchor", + input_value=tool_output_delta, + continuity_state=None, + ) + is False + ) + + +def test_websocket_client_previous_response_full_resend_retry_rejects_output_before_call() -> None: + reordered_tool_history: list[JsonValue] = [ + {"type": "function_call_output", "call_id": "call_late", "output": "ok"}, + {"type": "function_call", "name": "shell_command", "call_id": "call_late", "arguments": "{}"}, + ] + + assert ( + proxy_service._websocket_client_previous_response_full_resend_is_retry_safe( + previous_response_id="resp_client_anchor", + input_value=reordered_tool_history, + continuity_state=None, + ) + is False + ) + + +def test_websocket_client_previous_response_full_resend_retry_allows_self_contained_tool_history() -> None: + self_contained_tool_history: list[JsonValue] = [ + {"role": "user", "content": [{"type": "input_text", "text": "run a command"}]}, + {"type": "function_call", "name": "shell_command", "call_id": "call_ok", "arguments": "{}"}, + {"type": "function_call_output", "call_id": "call_ok", "output": "ok"}, + {"role": "user", "content": [{"type": "input_text", "text": "continue"}]}, + ] + + assert ( + proxy_service._websocket_client_previous_response_full_resend_is_retry_safe( + previous_response_id="resp_client_anchor", + input_value=self_contained_tool_history, + continuity_state=None, + ) + is True + ) + + @pytest.mark.asyncio async def test_prepare_websocket_response_create_request_fills_interrupted_pending_tool_outputs(monkeypatch): request_logs = _RequestLogsRecorder()