Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/gaia/agents/base/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,14 @@ def _get_mixin_prompts(self) -> list[str]:
fragment = getattr(self, attr_name)()
if fragment:
prompts.append(fragment)
except Exception:
pass
except Exception as e:
# A raising fragment is dropped from the composed prompt; surface it
# so a silently degraded system prompt is diagnosable.
logger.warning(
"system-prompt fragment %s() raised, skipping it: %s",
attr_name,
e,
)

return prompts

Expand Down
6 changes: 4 additions & 2 deletions src/gaia/agents/base/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def _load_memory_settings() -> Dict:
try:
if _MEMORY_SETTINGS_PATH.exists():
return json.loads(_MEMORY_SETTINGS_PATH.read_text(encoding="utf-8"))
except Exception:
pass
except (OSError, json.JSONDecodeError) as e:
logger.warning(
"failed to load memory settings from %s: %s", _MEMORY_SETTINGS_PATH, e
)
return {}


Expand Down
4 changes: 2 additions & 2 deletions src/gaia/mcp/mcp_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def _parse_cd(self, value: str):
name = p.split("=", 1)[1].strip().strip('"')
elif pl.startswith("filename="):
filename = p.split("=", 1)[1].strip().strip('"')
except Exception:
pass
except (AttributeError, IndexError, ValueError) as e:
logger.debug("Failed to parse Content-Disposition %r: %s", value, e)
return name, filename

def on_part_begin(self):
Expand Down
7 changes: 6 additions & 1 deletion src/gaia/messaging/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,12 @@ def start(self, token: str, background: bool = False) -> None:
with open(pid_path, "w", encoding="utf-8") as f:
f.write(str(os.getpid()))
except OSError as e:
log.exception("Failed to write PID file for telegram adapter: %s", e)
# The PID file is load-bearing in background mode: without it a
# supervisor cannot find or kill the process. Fail loudly.
raise RuntimeError(
f"Failed to write telegram PID file at {pid_path}: {e}. "
f"Ensure ~/.gaia is writable, or start without --background."
) from e

log_path = os.path.join(pid_dir, "telegram.log")
try:
Expand Down
4 changes: 2 additions & 2 deletions src/gaia/rag/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,8 @@ def _load_embedder(self):
# Force fresh load - must unload first
try:
self.llm_client.unload_model()
except Exception:
pass # Ignore if nothing to unload
except Exception as e:
self.log.warning("unload_model failed (continuing): %s", e)

try:
self.llm_client.load_model(
Expand Down
90 changes: 82 additions & 8 deletions src/gaia/web/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,42 @@
ALLOWED_SCHEMES = {"http", "https"}
BLOCKED_PORTS = {22, 23, 25, 445, 3306, 5432, 6379, 27017}


def _is_blocked_ip(ip: "ipaddress._BaseAddress") -> bool:
"""Return True if ``ip`` points at a private/internal range we must not fetch."""
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_reserved
or ip.is_multicast
)


def _assert_ip_allowed(ip_str: str, hostname: str) -> None:
"""Raise ValueError if ``ip_str`` is a private/reserved address.

The single authority for "is this IP safe to connect to". Both
``WebClient.validate_url`` (pre-flight DNS check) and
``PinnedIPAdapter`` (the IP it actually connects to) route through this,
so a DNS rebind that slips a private IP past the pre-flight lookup is
still caught at connect time on the *exact* address being dialed.
"""
try:
ip = ipaddress.ip_address(ip_str)
except ValueError:
# Not parseable as an IP — treat as unsafe rather than letting an
# unvalidated value reach the socket layer (fail loudly).
raise ValueError(
f"Blocked: {hostname} resolved to unparseable address {ip_str!r}."
)
if _is_blocked_ip(ip):
raise ValueError(
f"Blocked: {hostname} resolves to private/reserved IP {ip}. "
"Cannot fetch internal network addresses."
)


# Tags to remove during text extraction
REMOVE_TAGS = [
"script",
Expand Down Expand Up @@ -72,12 +108,35 @@ class PinnedIPAdapter(HTTPAdapter):
DNS-rebind attacks between ``WebClient.validate_url`` and the actual
TCP connect.

Crucially, the pinned IP is itself validated (``_assert_ip_allowed``)
before it is cached or connected to. ``validate_url`` runs a *separate*
pre-flight ``getaddrinfo``; an attacker controlling DNS could answer that
lookup with a public IP and answer the adapter's lookup with a private
one. Validating the exact address the adapter is about to dial closes
that residual rebind window for BOTH http and https.

For HTTPS, the original hostname is encoded in the URL's userinfo
section (``originalhostname@pinnedip:port``) so that urllib3 creates
separate connection-pool keys per original hostname. This avoids a
race where two threads requesting different hostnames that resolve to
the same IP would overwrite each other's ``assert_hostname`` on a
shared pool.

Residual HTTPS limitation (documented, not silently ignored):
Because ``requests`` derives the urllib3 pool host — and therefore the
TLS SNI ``server_hostname`` — from the request URL's hostname (which we
rewrote to the pinned IP), the ClientHello SNI is sent as the IP, not the
original hostname. ``assert_hostname`` still forces certificate-name
verification against the real hostname (so verification is NOT disabled
and we never trust a cert for the bare IP), but servers that rely on SNI
for virtual hosting (most CDNs / shared hosts) may return the wrong
certificate or reject the handshake, surfacing as a TLS error rather than
a silent downgrade. This affects whether legitimate HTTPS *succeeds* — it
does not weaken the SSRF block, which fires on the validated IP before any
bytes are sent. Fixing SNI cleanly requires a custom urllib3
``PoolManager``/``HTTPSConnection`` that decouples ``server_hostname``
from the connect address; that is intentionally out of scope here in
favour of a correct, narrower guarantee.
"""

def __init__(self, *args, **kwargs):
Expand All @@ -94,6 +153,12 @@ def _resolve_first_ip(self, host: str, port: int) -> str:
raise OSError(f"getaddrinfo returned no addresses for {host}:{port}")

ip = infos[0][4][0] # sockaddr[0] of the first result
# Validate the EXACT IP we are about to pin & connect to. validate_url
# did a pre-flight getaddrinfo, but that was a separate lookup — a DNS
# rebind could hand validate_url a public IP and hand us a private one.
# Re-check here so the address actually dialed is always safe; cache
# only after it passes so a poisoned answer is never reused.
_assert_ip_allowed(ip, host)
self._pinned_cache[key] = ip
return ip

Expand Down Expand Up @@ -134,7 +199,22 @@ def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response
pinned_ip = self._resolve_first_ip(host, port)

if parsed.scheme == "https":
# Encode original hostname in userinfo for unique pool keys
# Encode original hostname in userinfo for unique pool keys.
# See class docstring: SNI is sent as the pinned IP, so
# SNI-vhosted servers may fail the handshake. Cert-name
# verification still binds to the real hostname.
if not getattr(self, "_warned_https_sni", False):
log.warning(
"PinnedIPAdapter: HTTPS request to %s is pinned to %s; "
"TLS SNI will be sent as the IP. Servers using "
"SNI-based virtual hosting may return the wrong "
"certificate or reject the handshake. Certificate-name "
"verification still validates against %s.",
host,
pinned_ip,
host,
)
self._warned_https_sni = True
new_netloc = f"{host}@{pinned_ip}:{port}"
else:
new_netloc = f"{pinned_ip}:{port}"
Expand Down Expand Up @@ -270,13 +350,7 @@ def _validate_host_ip(self, hostname: str) -> None:
except ValueError:
continue

if (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_reserved
or ip.is_multicast
):
if _is_blocked_ip(ip):
raise ValueError(
f"Blocked: {hostname} resolves to private/reserved IP {ip}. "
"Cannot fetch internal network addresses."
Expand Down
130 changes: 89 additions & 41 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,95 @@ def setup(self):
pytest.skip(f"API dependencies not available: {IMPORT_ERROR}")
self.client = TestClient(app)

# -------------------------------------------------------------------------
# Non-streaming happy path (mocked agent backend — no Lemonade required)
# -------------------------------------------------------------------------

def test_basic_completion_with_mocked_agent(self, mocker):
"""Non-streaming POST returns a schema-valid OpenAI completion.

The agent/Lemonade backend is mocked: registry.get_agent yields a stub
whose process_query returns a canned result dict, so the handler's
non-streaming branch runs end-to-end without a live LLM server.
"""
# Stub agent: NOT an ApiAgent, so the handler uses the len//4 token
# estimate path (deterministic, no tokenizer needed).
fake_agent = mocker.MagicMock()
fake_agent.process_query.return_value = {
"status": "success",
"result": "def hello():\n return 'hello world'",
}

from gaia.api.openai_server import registry as server_registry

mocker.patch.object(server_registry, "get_agent", return_value=fake_agent)

payload = {
"model": "gaia-code",
"messages": [
{"role": "user", "content": "Write a hello world function in Python"}
],
"stream": False,
}
response = self.client.post("/v1/chat/completions", json=payload)

assert response.status_code == 200, response.text
data = response.json()

# Top-level OpenAI-compatible structure.
assert data["object"] == "chat.completion"
assert data["id"].startswith("chatcmpl-")
assert isinstance(data["created"], int)
assert data["model"] == "gaia-code"

# The agent was invoked with the extracted user message.
fake_agent.process_query.assert_called_once()
call_args, _ = fake_agent.process_query.call_args
assert call_args[0] == "Write a hello world function in Python"

# Choices.
assert len(data["choices"]) == 1
choice = data["choices"][0]
assert choice["index"] == 0
assert choice["message"]["role"] == "assistant"
assert choice["message"]["content"] == (
"def hello():\n return 'hello world'"
)
assert choice["finish_reason"] == "stop"

# Usage accounting.
usage = data["usage"]
assert usage["prompt_tokens"] > 0
assert usage["completion_tokens"] > 0
assert usage["total_tokens"] == (
usage["prompt_tokens"] + usage["completion_tokens"]
)

def test_completion_uses_last_user_message(self, mocker):
"""The handler passes the LAST user message (not system/assistant) to the agent."""
fake_agent = mocker.MagicMock()
fake_agent.process_query.return_value = {"result": "ok"}

from gaia.api.openai_server import registry as server_registry

mocker.patch.object(server_registry, "get_agent", return_value=fake_agent)

payload = {
"model": "gaia-code",
"messages": [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "first question"},
{"role": "assistant", "content": "an earlier answer"},
{"role": "user", "content": "second question"},
],
"stream": False,
}
response = self.client.post("/v1/chat/completions", json=payload)

assert response.status_code == 200, response.text
call_args, _ = fake_agent.process_query.call_args
assert call_args[0] == "second question"

# -------------------------------------------------------------------------
# Model Validation Tests
# -------------------------------------------------------------------------
Expand Down Expand Up @@ -308,47 +397,6 @@ def test_422_error_has_detail_field(self):
class TestChatCompletionsNonStreaming:
"""Test POST /v1/chat/completions without streaming"""

@pytest.mark.skip(reason="Skipped: API server returns 500 - see issue for fix")
def test_basic_completion_with_code_agent(self, api_server, api_client):
"""Test that gaia-code returns valid OpenAI-compatible completion"""
payload = {
"model": "gaia-code",
"messages": [
{"role": "user", "content": "Write a hello world function in Python"}
],
"stream": False,
}
response = api_client.post(f"{api_server}/v1/chat/completions", json=payload)

assert response.status_code == 200
data = response.json()

# Verify OpenAI-compatible structure
assert data["object"] == "chat.completion"
assert "id" in data
assert data["id"].startswith("chatcmpl-")
assert "created" in data
assert isinstance(data["created"], int)
assert data["model"] == "gaia-code"

# Verify choices
assert "choices" in data
assert len(data["choices"]) == 1
choice = data["choices"][0]
assert choice["index"] == 0
assert choice["message"]["role"] == "assistant"
assert isinstance(choice["message"]["content"], str)
assert len(choice["message"]["content"]) > 0
assert choice["finish_reason"] in ["stop", "length"]

# Verify token usage
assert "usage" in data
assert data["usage"]["prompt_tokens"] > 0
assert data["usage"]["completion_tokens"] > 0
assert data["usage"]["total_tokens"] == (
data["usage"]["prompt_tokens"] + data["usage"]["completion_tokens"]
)

def test_invalid_model_returns_404(self, api_server, api_client):
"""Test that invalid model returns 404 error"""
payload = {
Expand Down
Loading
Loading