diff --git a/app/modules/proxy/http_bridge_forwarding.py b/app/modules/proxy/http_bridge_forwarding.py index db5b16aa5..d0d2174ea 100644 --- a/app/modules/proxy/http_bridge_forwarding.py +++ b/app/modules/proxy/http_bridge_forwarding.py @@ -11,7 +11,7 @@ import aiohttp -from app.core.clients.proxy import ProxyResponseError +from app.core.clients.proxy import ProxyResponseError, filter_inbound_headers from app.core.config.settings import get_settings from app.core.crypto import get_or_create_key from app.core.errors import OpenAIErrorEnvelope, openai_error, response_failed_event @@ -22,6 +22,26 @@ from app.modules.api_keys.service import ApiKeyUsageReservationData from app.modules.proxy._service.http_bridge.helpers import _http_bridge_request_budget_seconds +# HTTP-only and hop-by-hop headers that must not be forwarded through the +# internal bridge. These headers are either illegal in WebSocket handshakes or +# carry HTTP framing semantics that the aiohttp upstream session manages itself. +# Applies on top of filter_inbound_headers (which already strips authorization, +# host, content-length, and x-forwarded-* / cf-* headers). +_BRIDGE_UNSAFE_HEADER_NAMES = frozenset( + { + "accept", + "accept-encoding", + "connection", + "content-type", + "cookie", + "keep-alive", + "te", + "trailer", + "transfer-encoding", + "upgrade", + } +) + HTTP_BRIDGE_INTERNAL_FORWARD_PATH = "/internal/bridge/responses" HTTP_BRIDGE_FORWARDED_HEADER = "x-codex-bridge-forwarded" HTTP_BRIDGE_ORIGIN_INSTANCE_HEADER = "x-codex-bridge-origin-instance" @@ -136,9 +156,27 @@ def build_owner_forward_headers( payload: ResponsesRequest, context: HTTPBridgeForwardContext, ) -> dict[str, str]: - forwarded = dict(headers) - forwarded.pop("host", None) - forwarded.pop("content-length", None) + filtered = filter_inbound_headers(headers) + # Per the hop-by-hop contract, also drop any header named by the inbound + # Connection header in addition to the fixed unsafe set. + connection_value = next( + (value for key, value in headers.items() if key.lower() == "connection"), + "", + ) + connection_named = {token.strip().lower() for token in connection_value.split(",") if token.strip()} + drop = _BRIDGE_UNSAFE_HEADER_NAMES | connection_named + forwarded = {key: value for key, value in filtered.items() if key.lower() not in drop} + # filter_inbound_headers strips Authorization, but the owner instance + # re-validates the client API key from this header (see + # _validate_internal_bridge_api_key) before swapping in its own upstream + # access token. Preserve it so api_key_auth_enabled deployments still + # authenticate forwarded bridge requests. + authorization = next( + (value for key, value in headers.items() if key.lower() == "authorization"), + None, + ) + if authorization is not None: + forwarded["authorization"] = authorization forwarded[HTTP_BRIDGE_FORWARDED_HEADER] = "1" forwarded[HTTP_BRIDGE_ORIGIN_INSTANCE_HEADER] = context.origin_instance forwarded[HTTP_BRIDGE_TARGET_INSTANCE_HEADER] = context.target_instance diff --git a/tests/unit/test_http_bridge_forwarding.py b/tests/unit/test_http_bridge_forwarding.py index e5b247bd8..538da2f76 100644 --- a/tests/unit/test_http_bridge_forwarding.py +++ b/tests/unit/test_http_bridge_forwarding.py @@ -338,3 +338,87 @@ def post(self, url: str, **kwargs: object) -> FakeResponse: assert '"type":"response.failed"' in events[0] assert '"code":"stream_incomplete"' in events[0] assert captured["trust_env"] is False + + +def test_build_owner_forward_headers_strips_hop_by_hop_headers() -> None: + payload = _payload() + context = HTTPBridgeForwardContext( + origin_instance="instance-a", + target_instance="instance-b", + codex_session_affinity=False, + downstream_turn_state=None, + ) + inbound = { + "Accept": "application/json", + "Accept-Encoding": "gzip, deflate", + "Connection": "keep-alive", + "Content-Type": "application/json", + "Cookie": "session=abc", + "x-request-id": "req-123", + } + + headers = build_owner_forward_headers(headers=inbound, payload=payload, context=context) + + assert "Accept" not in headers + assert "accept" not in headers + assert "Accept-Encoding" not in headers + assert "accept-encoding" not in headers + assert "Connection" not in headers + assert "connection" not in headers + assert "Content-Type" not in headers + assert "content-type" not in headers + assert "Cookie" not in headers + assert "cookie" not in headers + assert headers.get("x-request-id") == "req-123" + assert HTTP_BRIDGE_FORWARDED_HEADER in headers + assert HTTP_BRIDGE_TARGET_INSTANCE_HEADER in headers + + +def test_build_owner_forward_headers_preserves_authorization_strips_host() -> None: + payload = _payload() + context = HTTPBridgeForwardContext( + origin_instance="instance-a", + target_instance="instance-b", + codex_session_affinity=False, + downstream_turn_state=None, + ) + inbound = { + "Authorization": "Bearer downstream-key", + "Host": "client.example.com", + "content-length": "42", + "x-openai-client-version": "1.2.3", + } + + headers = build_owner_forward_headers(headers=inbound, payload=payload, context=context) + + # The owner instance re-validates the client API key from Authorization + # (see _validate_internal_bridge_api_key) before swapping in its own + # upstream token, so the header must survive the forward. + assert headers.get("authorization") == "Bearer downstream-key" + assert "Host" not in headers + assert "host" not in headers + assert "content-length" not in headers + assert headers.get("x-openai-client-version") == "1.2.3" + + +def test_build_owner_forward_headers_drops_connection_named_headers() -> None: + payload = _payload() + context = HTTPBridgeForwardContext( + origin_instance="instance-a", + target_instance="instance-b", + codex_session_affinity=False, + downstream_turn_state=None, + ) + inbound = { + "Connection": "keep-alive, X-Custom-Hop", + "X-Custom-Hop": "drop-me", + "x-request-id": "req-123", + } + + headers = build_owner_forward_headers(headers=inbound, payload=payload, context=context) + + assert "X-Custom-Hop" not in headers + assert "x-custom-hop" not in headers + assert "Connection" not in headers + assert "connection" not in headers + assert headers.get("x-request-id") == "req-123"