From 54b4c6fa7764d7ce4473c5ea3b4105b51e1b8e1f Mon Sep 17 00:00:00 2001 From: Ovtcharov Date: Thu, 28 May 2026 19:53:47 -0700 Subject: [PATCH 1/3] test(agents,llm): cover RoutingAgent, code validators, OpenAI provider (#880) RoutingAgent, code validators (syntax, antipattern, AST, requirements), and OpenAIProvider had zero dedicated unit tests. Adds 129 tests covering init, routing logic, LLM JSON parsing, disambiguation, keyword fallback, all four validator classes, and the full OpenAI provider surface (chat, stream, generate, embed, unsupported methods). --- tests/unit/agents/test_code_validators.py | 478 ++++++++++++++++++++++ tests/unit/agents/test_routing_agent.py | 412 +++++++++++++++++++ tests/unit/test_openai_provider.py | 309 ++++++++++++++ 3 files changed, 1199 insertions(+) create mode 100644 tests/unit/agents/test_code_validators.py create mode 100644 tests/unit/agents/test_routing_agent.py create mode 100644 tests/unit/test_openai_provider.py diff --git a/tests/unit/agents/test_code_validators.py b/tests/unit/agents/test_code_validators.py new file mode 100644 index 000000000..a93b18c24 --- /dev/null +++ b/tests/unit/agents/test_code_validators.py @@ -0,0 +1,478 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Unit tests for code validators — syntax, antipattern, AST, and requirements.""" + +import ast +import textwrap +from pathlib import Path + +import pytest + +from gaia.agents.code.validators.antipattern_checker import AntipatternChecker +from gaia.agents.code.validators.ast_analyzer import ASTAnalyzer +from gaia.agents.code.validators.requirements_validator import RequirementsValidator +from gaia.agents.code.validators.syntax_validator import SyntaxValidator + +# =================================================================== +# SyntaxValidator +# =================================================================== + + +class TestSyntaxValidator: + """SyntaxValidator.validate / validate_dict / helpers.""" + + @pytest.fixture() + def validator(self): + return SyntaxValidator() + + # -- validate -- + + def test_valid_code(self, validator): + result = validator.validate("x = 1\nprint(x)\n") + assert result.is_valid is True + assert result.errors == [] + + def test_empty_code(self, validator): + result = validator.validate("") + assert result.is_valid is True + + def test_syntax_error_detected(self, validator): + result = validator.validate("def foo(\n") + assert result.is_valid is False + assert len(result.errors) > 0 + + def test_syntax_error_includes_line_number(self, validator): + result = validator.validate("x = 1\ndef foo(\n") + assert any("Line" in e for e in result.errors) + + # -- validate_dict -- + + def test_validate_dict_valid(self, validator): + d = validator.validate_dict("a = 1") + assert d["status"] == "success" + assert d["is_valid"] is True + assert d["message"] == "Syntax is valid" + + def test_validate_dict_invalid(self, validator): + d = validator.validate_dict("def (") + assert d["status"] == "error" + assert d["is_valid"] is False + assert len(d["errors"]) > 0 + + # -- get_syntax_errors -- + + def test_get_syntax_errors_none(self, validator): + assert validator.get_syntax_errors("x = 1") == [] + + def test_get_syntax_errors_returns_syntax_error(self, validator): + errors = validator.get_syntax_errors("def (") + assert len(errors) == 1 + assert isinstance(errors[0], SyntaxError) + + # -- check_indentation -- + + def test_indentation_clean(self, validator): + code = "def f():\n pass\n" + assert validator.check_indentation(code) == [] + + def test_indentation_mixed_tabs_spaces(self, validator): + code = "def f():\n \tpass\n" + warnings = validator.check_indentation(code) + assert any("Mixed tabs and spaces" in w for w in warnings) + + def test_indentation_non_standard(self, validator): + code = "def f():\n pass\n" + warnings = validator.check_indentation(code) + assert any("Non-standard indentation" in w for w in warnings) + + # -- validate_imports -- + + def test_validate_imports_clean(self, validator): + code = "import os\nimport sys\n" + assert validator.validate_imports(code) == [] + + def test_validate_imports_wildcard(self, validator): + code = "from os import *\n" + warnings = validator.validate_imports(code) + assert any("Wildcard import" in w for w in warnings) + + def test_validate_imports_duplicate(self, validator): + code = "import os\nimport os\n" + warnings = validator.validate_imports(code) + assert any("Duplicate import" in w for w in warnings) + + # -- check_line_length -- + + def test_line_length_ok(self, validator): + assert validator.check_line_length("x = 1") == [] + + def test_line_length_exceeded(self, validator): + long_line = "x = " + "a" * 90 + warnings = validator.check_line_length(long_line, max_length=88) + assert len(warnings) == 1 + assert "Line too long" in warnings[0] + + +# =================================================================== +# AntipatternChecker +# =================================================================== + + +class TestAntipatternChecker: + """AntipatternChecker.check / check_dict / naming / complexity.""" + + @pytest.fixture() + def checker(self): + return AntipatternChecker() + + # -- check -- + + def test_clean_code(self, checker): + code = textwrap.dedent("""\ + def greet(name): + print(f"hi {name}") + """) + result = checker.check(Path("clean.py"), code) + assert result["errors"] == [] + assert result["warnings"] == [] + + def test_excessive_function_name(self, checker): + name = "a" * 81 + code = f"def {name}():\n pass\n" + result = checker.check(Path("long.py"), code) + assert any("chars" in e for e in result["errors"]) + + def test_combinatorial_naming(self, checker): + code = "def get_and_process_and_validate_and_transform():\n pass\n" + result = checker.check(Path("combo.py"), code) + assert any("Combinatorial" in e for e in result["errors"]) + + def test_excessive_parameters(self, checker): + params = ", ".join(f"p{i}" for i in range(8)) + code = f"def func({params}):\n pass\n" + result = checker.check(Path("params.py"), code) + assert any("parameters" in w for w in result["warnings"]) + + def test_long_function_warns(self, checker): + body = "\n".join(f" x{i} = {i}" for i in range(55)) + code = f"def long_func():\n{body}\n" + result = checker.check(Path("long_func.py"), code) + assert any("lines long" in w for w in result["warnings"]) + + def test_duplicate_class_definitions(self, checker): + code = "class Foo:\n pass\nclass Foo:\n pass\n" + result = checker.check(Path("dup.py"), code) + assert any("Duplicate class" in e for e in result["errors"]) + + def test_excessive_file_length(self, checker): + code = "\n".join(f"x{i} = {i}" for i in range(1010)) + result = checker.check(Path("big.py"), code) + assert any("lines" in w and "splitting" in w for w in result["warnings"]) + + def test_syntax_error_ignored(self, checker): + result = checker.check(Path("bad.py"), "def (") + assert result["errors"] == [] + assert result["warnings"] == [] + + # -- check_dict -- + + def test_check_dict_delegates(self, checker): + code = "def greet():\n pass\n" + result = checker.check_dict(code) + assert "errors" in result + assert "warnings" in result + + # -- check_naming_patterns -- + + def test_naming_long_function(self, checker): + code = f"def {'a' * 45}():\n pass\n" + tree = ast.parse(code) + issues = checker.check_naming_patterns(tree) + assert any("long name" in i for i in issues) + + def test_naming_too_many_underscores(self, checker): + code = "def a_b_c_d_e_f_g():\n pass\n" + tree = ast.parse(code) + issues = checker.check_naming_patterns(tree) + assert any("underscores" in i for i in issues) + + def test_class_name_lowercase(self, checker): + code = "class myclass:\n pass\n" + tree = ast.parse(code) + issues = checker.check_naming_patterns(tree) + assert any("uppercase" in i for i in issues) + + def test_class_name_too_long(self, checker): + code = f"class {'A' * 35}:\n pass\n" + tree = ast.parse(code) + issues = checker.check_naming_patterns(tree) + assert any("long name" in i for i in issues) + + # -- check_function_complexity -- + + def test_deep_nesting(self, checker): + code = textwrap.dedent("""\ + def deep(): + if True: + for i in range(1): + while True: + if True: + with open("x"): + pass + """) + tree = ast.parse(code) + func = [n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)][0] + issues = checker.check_function_complexity(func) + assert any("nesting" in i for i in issues) + + def test_many_branches(self, checker): + ifs = "\n".join(f" if x == {i}: pass" for i in range(12)) + code = f"def branchy(x):\n{ifs}\n" + tree = ast.parse(code) + func = [n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)][0] + issues = checker.check_function_complexity(func) + assert any("branches" in i for i in issues) + + def test_many_loops(self, checker): + loops = "\n".join(f" for i{n} in range(1): pass" for n in range(5)) + code = f"def loopy():\n{loops}\n" + tree = ast.parse(code) + func = [n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)][0] + issues = checker.check_function_complexity(func) + assert any("loops" in i for i in issues) + + def test_simple_function_no_issues(self, checker): + code = "def ok(x):\n return x + 1\n" + tree = ast.parse(code) + func = [n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)][0] + assert checker.check_function_complexity(func) == [] + + +# =================================================================== +# ASTAnalyzer +# =================================================================== + + +class TestASTAnalyzer: + """ASTAnalyzer.parse_code, extract_*, get_docstring.""" + + @pytest.fixture() + def analyzer(self): + return ASTAnalyzer() + + # -- parse_code -- + + def test_parse_valid_code(self, analyzer): + code = textwrap.dedent("""\ + import os + + X = 42 + + def greet(name: str) -> str: + \"\"\"Say hello.\"\"\" + return f"hi {name}" + + class Foo: + \"\"\"A class.\"\"\" + pass + """) + parsed = analyzer.parse_code(code) + assert parsed.is_valid is True + assert parsed.errors == [] + + names = {s.name for s in parsed.symbols} + assert "greet" in names + assert "Foo" in names + assert "os" in names + assert "X" in names + + def test_parse_invalid_code(self, analyzer): + parsed = analyzer.parse_code("def (") + assert parsed.is_valid is False + assert len(parsed.errors) > 0 + + def test_parse_extracts_imports(self, analyzer): + code = "import os\nfrom pathlib import Path\n" + parsed = analyzer.parse_code(code) + assert "import os" in parsed.imports + assert "from pathlib import Path" in parsed.imports + + def test_function_signature_with_types(self, analyzer): + code = "def add(a: int, b: int) -> int:\n return a + b\n" + parsed = analyzer.parse_code(code) + func_sym = [s for s in parsed.symbols if s.name == "add"][0] + assert "a: int" in func_sym.signature + assert "-> int" in func_sym.signature + + def test_function_signature_varargs(self, analyzer): + code = "def f(*args, **kwargs):\n pass\n" + parsed = analyzer.parse_code(code) + func_sym = [s for s in parsed.symbols if s.name == "f"][0] + assert "*args" in func_sym.signature + assert "**kwargs" in func_sym.signature + + def test_async_function_detected(self, analyzer): + code = "async def fetch():\n pass\n" + parsed = analyzer.parse_code(code) + names = {s.name for s in parsed.symbols if s.type == "function"} + assert "fetch" in names + + def test_class_docstring_extracted(self, analyzer): + code = 'class Foo:\n """Foo docs."""\n pass\n' + parsed = analyzer.parse_code(code) + cls = [s for s in parsed.symbols if s.name == "Foo"][0] + assert cls.docstring == "Foo docs." + + def test_module_level_variable(self, analyzer): + code = "MY_CONST = 42\n" + parsed = analyzer.parse_code(code) + var = [s for s in parsed.symbols if s.name == "MY_CONST"] + assert len(var) == 1 + assert var[0].type == "variable" + + # -- extract_functions / extract_classes -- + + def test_extract_functions(self, analyzer): + code = "def a():\n pass\ndef b():\n pass\n" + tree = ast.parse(code) + funcs = analyzer.extract_functions(tree) + assert len(funcs) == 2 + + def test_extract_classes(self, analyzer): + code = "class A:\n pass\nclass B:\n pass\n" + tree = ast.parse(code) + classes = analyzer.extract_classes(tree) + assert len(classes) == 2 + + # -- get_docstring -- + + def test_get_docstring_from_function(self, analyzer): + code = 'def f():\n """Hello."""\n pass\n' + tree = ast.parse(code) + func = [n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)][0] + assert analyzer.get_docstring(func) == "Hello." + + def test_get_docstring_none(self, analyzer): + code = "def f():\n pass\n" + tree = ast.parse(code) + func = [n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)][0] + assert analyzer.get_docstring(func) is None + + +# =================================================================== +# RequirementsValidator +# =================================================================== + + +class TestRequirementsValidator: + """RequirementsValidator.validate, check_package_validity, suggest_common_packages.""" + + @pytest.fixture() + def validator(self): + return RequirementsValidator() + + # -- validate -- + + def test_valid_requirements(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + req.write_text("flask==2.3.0\nrequests>=2.28\n") + result = validator.validate(req) + assert result["is_valid"] is True + assert result["packages"] == 2 + assert result["errors"] == [] + + def test_hallucinated_package_detected(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + req.write_text("flask-graphql-a-b-c-d-e\n") + result = validator.validate(req) + assert result["is_valid"] is False + assert any("Hallucinated" in e for e in result["errors"]) + + def test_recursive_ibm_pattern(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + req.write_text("some-ibm-cloud-ibm-cloud-sdk\n") + result = validator.validate(req) + assert result["is_valid"] is False + + def test_package_name_too_long(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + req.write_text("a" * 65 + "\n") + result = validator.validate(req) + assert result["is_valid"] is False + assert any("too long" in e for e in result["errors"]) + + def test_duplicate_package_warning(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + req.write_text("flask\nflask\n") + result = validator.validate(req) + assert any("Duplicate" in w for w in result["warnings"]) + + def test_comment_lines_ignored(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + req.write_text("# comment\nflask\n") + result = validator.validate(req) + assert result["is_valid"] is True + assert result["packages"] == 1 + + def test_empty_lines_ignored(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + req.write_text("flask\n\nrequests\n") + result = validator.validate(req) + assert result["packages"] == 2 + + def test_many_packages_warning(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + lines = [f"pkg{i}" for i in range(35)] + req.write_text("\n".join(lines)) + result = validator.validate(req) + assert any("Many packages" in w for w in result["warnings"]) + + def test_too_many_packages_error(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + lines = [f"pkg{i}" for i in range(55)] + req.write_text("\n".join(lines)) + result = validator.validate(req) + assert any("Too many" in e for e in result["errors"]) + + def test_auto_fix_removes_bad_packages(self, validator, tmp_path): + req = tmp_path / "requirements.txt" + req.write_text("flask\nflask-graphql-a-b-c-d-e\nrequests\n") + result = validator.validate(req, fix=True) + assert result["fixed_content"] is not None + assert "flask-graphql" not in result["fixed_content"] + assert "flask" in result["fixed_content"] + assert "requests" in result["fixed_content"] + + # -- check_package_validity -- + + def test_valid_package_name(self, validator): + assert validator.check_package_validity("flask") is True + assert validator.check_package_validity("scikit-learn") is True + assert validator.check_package_validity("python-dotenv") is True + + def test_hallucinated_package_invalid(self, validator): + assert validator.check_package_validity("x-ibm-cloud-ibm-cloud-y") is False + + def test_too_long_package_invalid(self, validator): + assert validator.check_package_validity("a" * 61) is False + + def test_invalid_chars_package(self, validator): + assert validator.check_package_validity("flask@latest") is False + + def test_package_starting_with_hyphen(self, validator): + assert validator.check_package_validity("-flask") is False + + # -- suggest_common_packages -- + + def test_suggest_web(self, validator): + pkgs = validator.suggest_common_packages("web") + assert "flask" in pkgs + assert "django" in pkgs + + def test_suggest_ml(self, validator): + pkgs = validator.suggest_common_packages("ml") + assert "torch" in pkgs + + def test_suggest_unknown_falls_back(self, validator): + pkgs = validator.suggest_common_packages("gaming") + assert pkgs == validator.suggest_common_packages("general") diff --git a/tests/unit/agents/test_routing_agent.py b/tests/unit/agents/test_routing_agent.py new file mode 100644 index 000000000..3c14517fa --- /dev/null +++ b/tests/unit/agents/test_routing_agent.py @@ -0,0 +1,412 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Unit tests for RoutingAgent — routing logic, disambiguation, and agent creation.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def mock_llm_client(): + """Return a mocked LLM client that responds with valid routing JSON.""" + client = MagicMock() + client.generate.return_value = json.dumps( + { + "agent": "code", + "parameters": {"language": "typescript", "project_type": "fullstack"}, + "confidence": 0.95, + "reasoning": "Next.js detected", + } + ) + return client + + +@pytest.fixture() +def _patch_create_client(mock_llm_client): + """Patch create_client so RoutingAgent.__init__ uses the mock LLM.""" + with patch("gaia.agents.routing.agent.create_client", return_value=mock_llm_client): + yield + + +@pytest.fixture() +def _patch_code_agent(): + """Patch CodeAgent so _create_agent never touches the real agent stack.""" + with patch("gaia.agents.code.agent.CodeAgent") as cls: + cls.return_value = MagicMock() + yield cls + + +@pytest.fixture() +def router(_patch_create_client): + """Return a RoutingAgent wired to the mock LLM.""" + from gaia.agents.routing.agent import RoutingAgent + + return RoutingAgent(api_mode=True) + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +class TestRoutingAgentInit: + """Constructor and configuration.""" + + def test_import_and_exposes_process_query(self): + from gaia.agents.routing.agent import RoutingAgent + + assert hasattr(RoutingAgent, "process_query") + + def test_default_routing_model(self, router): + assert router.routing_model == "Qwen3.5-35B-A3B-GGUF" + + def test_custom_routing_model_via_env(self, _patch_create_client, monkeypatch): + monkeypatch.setenv("AGENT_ROUTING_MODEL", "custom-model") + from gaia.agents.routing.agent import RoutingAgent + + r = RoutingAgent(api_mode=True) + assert r.routing_model == "custom-model" + + def test_api_mode_stored(self, _patch_create_client): + from gaia.agents.routing.agent import RoutingAgent + + r = RoutingAgent(api_mode=True) + assert r.api_mode is True + + def test_cli_mode_default(self, _patch_create_client): + from gaia.agents.routing.agent import RoutingAgent + + r = RoutingAgent() + assert r.api_mode is False + + def test_agent_kwargs_stored(self, _patch_create_client): + from gaia.agents.routing.agent import RoutingAgent + + r = RoutingAgent(api_mode=True, foo="bar") + assert r.agent_kwargs["foo"] == "bar" + + +# --------------------------------------------------------------------------- +# LLM analysis +# --------------------------------------------------------------------------- + + +class TestAnalyzeWithLLM: + """_analyze_with_llm parses LLM JSON correctly.""" + + def test_parses_clean_json(self, router, mock_llm_client): + result = router._analyze_with_llm( + [{"role": "user", "content": "Create a Next.js app"}] + ) + assert result["agent"] == "code" + assert result["parameters"]["language"] == "typescript" + assert result["confidence"] == 0.95 + + def test_parses_json_in_markdown_code_block(self, router, mock_llm_client): + mock_llm_client.generate.return_value = ( + '```json\n{"agent":"code","parameters":{"language":"python",' + '"project_type":"api"},"confidence":0.9,"reasoning":"Flask"}\n```' + ) + result = router._analyze_with_llm( + [{"role": "user", "content": "Build a Flask API"}] + ) + assert result["parameters"]["language"] == "python" + + def test_parses_json_in_generic_code_block(self, router, mock_llm_client): + mock_llm_client.generate.return_value = ( + '```\n{"agent":"code","parameters":{"language":"python",' + '"project_type":"script"},"confidence":0.8,"reasoning":"generic"}\n```' + ) + result = router._analyze_with_llm( + [{"role": "user", "content": "Write a script"}] + ) + assert result["parameters"]["project_type"] == "script" + + def test_json_parse_failure_returns_fallback(self, router, mock_llm_client): + mock_llm_client.generate.return_value = "NOT VALID JSON AT ALL" + result = router._analyze_with_llm([{"role": "user", "content": "whatever"}]) + assert result["confidence"] == 0.0 + assert result["parameters"]["language"] == "unknown" + + def test_llm_exception_propagates(self, router, mock_llm_client): + mock_llm_client.generate.side_effect = ConnectionError("offline") + with pytest.raises(RuntimeError, match="Failed to analyze query"): + router._analyze_with_llm([{"role": "user", "content": "x"}]) + + +# --------------------------------------------------------------------------- +# has_unknowns +# --------------------------------------------------------------------------- + + +class TestHasUnknowns: + """_has_unknowns detects missing or low-confidence parameters.""" + + def test_no_unknowns_high_confidence(self, router): + analysis = { + "parameters": {"language": "typescript", "project_type": "fullstack"}, + "confidence": 0.95, + } + assert router._has_unknowns(analysis) is False + + def test_unknown_language(self, router): + analysis = { + "parameters": {"language": "unknown", "project_type": "fullstack"}, + "confidence": 0.95, + } + assert router._has_unknowns(analysis) is True + + def test_unknown_project_type(self, router): + analysis = { + "parameters": {"language": "python", "project_type": "unknown"}, + "confidence": 0.95, + } + assert router._has_unknowns(analysis) is True + + def test_low_confidence_triggers_unknowns(self, router): + analysis = { + "parameters": {"language": "python", "project_type": "api"}, + "confidence": 0.5, + } + assert router._has_unknowns(analysis) is True + + def test_boundary_confidence_0_9(self, router): + analysis = { + "parameters": {"language": "python", "project_type": "api"}, + "confidence": 0.9, + } + assert router._has_unknowns(analysis) is False + + +# --------------------------------------------------------------------------- +# Clarification questions +# --------------------------------------------------------------------------- + + +class TestClarificationQuestions: + """_generate_clarification_question returns context-appropriate prompts.""" + + def test_both_unknown(self, router): + analysis = { + "parameters": {"language": "unknown", "project_type": "unknown"}, + } + q = router._generate_clarification_question(analysis) + assert "What kind of application" in q + + def test_language_unknown_fullstack(self, router): + analysis = { + "parameters": {"language": "unknown", "project_type": "fullstack"}, + } + q = router._generate_clarification_question(analysis) + assert "language" in q.lower() or "framework" in q.lower() + + def test_language_unknown_script(self, router): + analysis = { + "parameters": {"language": "unknown", "project_type": "script"}, + } + q = router._generate_clarification_question(analysis) + assert "language" in q.lower() + + def test_project_type_unknown_typescript(self, router): + analysis = { + "parameters": {"language": "typescript", "project_type": "unknown"}, + } + q = router._generate_clarification_question(analysis) + assert "TypeScript" in q + + def test_project_type_unknown_python(self, router): + analysis = { + "parameters": {"language": "python", "project_type": "unknown"}, + } + q = router._generate_clarification_question(analysis) + assert "Python" in q + + +# --------------------------------------------------------------------------- +# Keyword fallback detection +# --------------------------------------------------------------------------- + + +class TestFallbackKeywordDetection: + """_fallback_keyword_detection finds language from framework keywords.""" + + @pytest.mark.parametrize( + "query, expected_lang", + [ + ("Build a Next.js blog", "typescript"), + ("Create a React dashboard", "typescript"), + ("Express REST API", "typescript"), + ("Angular admin panel", "typescript"), + ("Svelte app", "typescript"), + ], + ) + def test_typescript_keywords(self, router, query, expected_lang): + result = router._fallback_keyword_detection(query) + assert result["parameters"]["language"] == expected_lang + + @pytest.mark.parametrize( + "query, expected_lang", + [ + ("Django REST API", "python"), + ("Flask microservice", "python"), + ("FastAPI server", "python"), + ("Pandas data analysis", "python"), + ], + ) + def test_python_keywords(self, router, query, expected_lang): + result = router._fallback_keyword_detection(query) + assert result["parameters"]["language"] == expected_lang + + def test_no_keywords_returns_unknown(self, router): + result = router._fallback_keyword_detection("Build something cool") + assert result["parameters"]["language"] == "unknown" + + def test_typescript_cli_is_script(self, router): + result = router._fallback_keyword_detection("Node.js CLI tool") + assert result["parameters"]["project_type"] == "script" + + def test_python_api_project_type(self, router): + result = router._fallback_keyword_detection("FastAPI REST backend") + assert result["parameters"]["project_type"] == "api" + + +# --------------------------------------------------------------------------- +# Default-to-TypeScript logic +# --------------------------------------------------------------------------- + + +class TestDefaultUnknownLanguage: + """_default_unknown_language_to_typescript fills unknowns.""" + + def test_unknown_language_becomes_typescript(self, router): + analysis = { + "parameters": {"language": "unknown", "project_type": "unknown"}, + "confidence": 0.5, + "reasoning": "ambiguous", + } + result = router._default_unknown_language_to_typescript(analysis) + assert result["parameters"]["language"] == "typescript" + assert result["parameters"]["project_type"] == "fullstack" + assert result["confidence"] == 1.0 + + def test_known_language_unchanged(self, router): + analysis = { + "parameters": {"language": "python", "project_type": "api"}, + "confidence": 0.9, + "reasoning": "clear", + } + result = router._default_unknown_language_to_typescript(analysis) + assert result["parameters"]["language"] == "python" + assert result["confidence"] == 0.9 + + +# --------------------------------------------------------------------------- +# enforce_typescript_only +# --------------------------------------------------------------------------- + + +class TestEnforceTypescriptOnly: + """_enforce_typescript_only rejects non-TS languages.""" + + def test_typescript_fullstack_passes(self, router): + console = MagicMock() + lang, pt = router._enforce_typescript_only("typescript", "fullstack", console) + assert lang == "typescript" + assert pt == "fullstack" + + def test_python_raises_system_exit(self, router): + console = MagicMock() + with pytest.raises(SystemExit): + router._enforce_typescript_only("python", "script", console) + console.print_error.assert_called_once() + + +# --------------------------------------------------------------------------- +# Agent creation +# --------------------------------------------------------------------------- + + +class TestCreateAgent: + """_create_agent produces a CodeAgent with correct params.""" + + def test_creates_code_agent(self, router, _patch_code_agent): + analysis = { + "agent": "code", + "parameters": {"language": "typescript", "project_type": "fullstack"}, + } + agent = router._create_agent(analysis) + assert agent is not None + _patch_code_agent.assert_called_once() + call_kwargs = _patch_code_agent.call_args + assert ( + call_kwargs.kwargs.get("language") == "typescript" + or call_kwargs[1].get("language") == "typescript" + ) + + def test_unknown_agent_type_raises(self, router): + analysis = { + "agent": "unknown_agent", + "parameters": {}, + } + with pytest.raises(ValueError, match="Unknown agent type"): + router._create_agent(analysis) + + +class TestCreateAgentWithDefaults: + """_create_agent_with_defaults fills unknowns before creating.""" + + def test_unknown_typescript_defaults_to_fullstack(self, router, _patch_code_agent): + analysis = { + "parameters": {"language": "typescript", "project_type": "unknown"}, + } + router._create_agent_with_defaults(analysis) + call_kwargs = _patch_code_agent.call_args + assert ( + call_kwargs.kwargs.get("project_type") == "fullstack" + or call_kwargs[1].get("project_type") == "fullstack" + ) + + +# --------------------------------------------------------------------------- +# process_query end-to-end (API mode) +# --------------------------------------------------------------------------- + + +class TestProcessQueryAPIMode: + """process_query in API mode auto-executes the agent.""" + + def test_resolved_query_executes(self, router, mock_llm_client, _patch_code_agent): + mock_agent = _patch_code_agent.return_value + mock_agent.process_query.return_value = "done" + + result = router.process_query("Create a Next.js blog") + assert result == "done" + mock_agent.process_query.assert_called_once() + + def test_execute_false_returns_agent( + self, router, mock_llm_client, _patch_code_agent + ): + agent = router.process_query("Create a Next.js blog", execute=False) + assert agent is _patch_code_agent.return_value + + def test_low_confidence_api_mode_uses_defaults( + self, router, mock_llm_client, _patch_code_agent + ): + mock_llm_client.generate.return_value = json.dumps( + { + "agent": "code", + "parameters": {"language": "typescript", "project_type": "unknown"}, + "confidence": 0.4, + "reasoning": "unclear", + } + ) + mock_agent = _patch_code_agent.return_value + mock_agent.process_query.return_value = "ok" + result = router.process_query("Build something") + assert result == "ok" diff --git a/tests/unit/test_openai_provider.py b/tests/unit/test_openai_provider.py new file mode 100644 index 000000000..9ccf94d75 --- /dev/null +++ b/tests/unit/test_openai_provider.py @@ -0,0 +1,309 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Unit tests for OpenAIProvider — chat, generate, embed, stream, errors.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from gaia.llm.exceptions import NotSupportedError + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def mock_openai_module(): + """Patch the openai module so OpenAIProvider never hits the network.""" + mock_mod = MagicMock() + mock_client_instance = MagicMock() + mock_mod.OpenAI.return_value = mock_client_instance + with patch.dict("sys.modules", {"openai": mock_mod}): + yield mock_mod, mock_client_instance + + +@pytest.fixture() +def provider(mock_openai_module): + """Return an OpenAIProvider backed by the mocked openai module.""" + from gaia.llm.providers.openai_provider import OpenAIProvider + + return OpenAIProvider(api_key="sk-test", model="gpt-4o") + + +@pytest.fixture() +def client(mock_openai_module): + """Shortcut to the mocked openai.OpenAI() instance.""" + return mock_openai_module[1] + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +class TestOpenAIProviderInit: + """Constructor and basic properties.""" + + def test_provider_name(self, provider): + assert provider.provider_name == "OpenAI" + + def test_default_model(self, mock_openai_module): + from gaia.llm.providers.openai_provider import OpenAIProvider + + p = OpenAIProvider(api_key="sk-test") + assert p._model == "gpt-4o" + + def test_custom_model(self, mock_openai_module): + from gaia.llm.providers.openai_provider import OpenAIProvider + + p = OpenAIProvider(api_key="sk-test", model="gpt-4-turbo") + assert p._model == "gpt-4-turbo" + + def test_system_prompt_stored(self, mock_openai_module): + from gaia.llm.providers.openai_provider import OpenAIProvider + + p = OpenAIProvider(api_key="sk-test", system_prompt="You are helpful.") + assert p._system_prompt == "You are helpful." + + def test_extra_kwargs_ignored(self, mock_openai_module): + from gaia.llm.providers.openai_provider import OpenAIProvider + + p = OpenAIProvider(api_key="sk-test", base_url="http://x", unknown_arg=True) + assert p._model == "gpt-4o" + + +# --------------------------------------------------------------------------- +# chat() — non-streaming +# --------------------------------------------------------------------------- + + +class TestChat: + """chat() delegates to OpenAI SDK and returns content.""" + + def test_returns_message_content(self, provider, client): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Hello!" + client.chat.completions.create.return_value = mock_response + + result = provider.chat([{"role": "user", "content": "Hi"}], stream=False) + assert result == "Hello!" + + def test_uses_default_model(self, provider, client): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + client.chat.completions.create.return_value = mock_response + + provider.chat([{"role": "user", "content": "Hi"}]) + call_kwargs = client.chat.completions.create.call_args + assert call_kwargs.kwargs["model"] == "gpt-4o" + + def test_model_override(self, provider, client): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + client.chat.completions.create.return_value = mock_response + + provider.chat([{"role": "user", "content": "Hi"}], model="gpt-4-turbo") + call_kwargs = client.chat.completions.create.call_args + assert call_kwargs.kwargs["model"] == "gpt-4-turbo" + + def test_system_prompt_prepended(self, mock_openai_module): + from gaia.llm.providers.openai_provider import OpenAIProvider + + p = OpenAIProvider(api_key="sk-test", system_prompt="Be concise.") + _, mock_client = mock_openai_module + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + mock_client.chat.completions.create.return_value = mock_response + + p.chat([{"role": "user", "content": "Hi"}]) + call_kwargs = mock_client.chat.completions.create.call_args + messages = call_kwargs.kwargs["messages"] + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "Be concise." + assert messages[1]["role"] == "user" + + def test_no_system_prompt_by_default(self, provider, client): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + client.chat.completions.create.return_value = mock_response + + provider.chat([{"role": "user", "content": "Hi"}]) + call_kwargs = client.chat.completions.create.call_args + messages = call_kwargs.kwargs["messages"] + assert messages[0]["role"] == "user" + + def test_extra_kwargs_passed_through(self, provider, client): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + client.chat.completions.create.return_value = mock_response + + provider.chat( + [{"role": "user", "content": "Hi"}], + temperature=0.5, + max_tokens=100, + ) + call_kwargs = client.chat.completions.create.call_args + assert call_kwargs.kwargs["temperature"] == 0.5 + assert call_kwargs.kwargs["max_tokens"] == 100 + + +# --------------------------------------------------------------------------- +# chat() — streaming +# --------------------------------------------------------------------------- + + +class TestChatStreaming: + """chat(stream=True) returns an iterator of text chunks.""" + + def _make_chunk(self, content): + chunk = MagicMock() + chunk.choices = [MagicMock()] + chunk.choices[0].delta.content = content + return chunk + + def _make_empty_chunk(self): + chunk = MagicMock() + chunk.choices = [MagicMock()] + chunk.choices[0].delta.content = None + return chunk + + def _make_no_choices_chunk(self): + chunk = MagicMock() + chunk.choices = [] + return chunk + + def test_stream_yields_content(self, provider, client): + chunks = [self._make_chunk("Hello"), self._make_chunk(" world")] + client.chat.completions.create.return_value = iter(chunks) + + result = provider.chat([{"role": "user", "content": "Hi"}], stream=True) + pieces = list(result) + assert pieces == ["Hello", " world"] + + def test_stream_skips_empty_deltas(self, provider, client): + chunks = [ + self._make_chunk("A"), + self._make_empty_chunk(), + self._make_chunk("B"), + ] + client.chat.completions.create.return_value = iter(chunks) + + pieces = list(provider.chat([{"role": "user", "content": "Hi"}], stream=True)) + assert pieces == ["A", "B"] + + def test_stream_skips_no_choices(self, provider, client): + chunks = [ + self._make_chunk("X"), + self._make_no_choices_chunk(), + self._make_chunk("Y"), + ] + client.chat.completions.create.return_value = iter(chunks) + + pieces = list(provider.chat([{"role": "user", "content": "Hi"}], stream=True)) + assert pieces == ["X", "Y"] + + +# --------------------------------------------------------------------------- +# generate() +# --------------------------------------------------------------------------- + + +class TestGenerate: + """generate() wraps prompt into a user message and delegates to chat().""" + + def test_generate_non_streaming(self, provider, client): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "42" + client.chat.completions.create.return_value = mock_response + + result = provider.generate("What is 6*7?") + assert result == "42" + + call_kwargs = client.chat.completions.create.call_args + messages = call_kwargs.kwargs["messages"] + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == "What is 6*7?" + + def test_generate_with_model_override(self, provider, client): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + client.chat.completions.create.return_value = mock_response + + provider.generate("test", model="gpt-3.5-turbo") + call_kwargs = client.chat.completions.create.call_args + assert call_kwargs.kwargs["model"] == "gpt-3.5-turbo" + + +# --------------------------------------------------------------------------- +# embed() +# --------------------------------------------------------------------------- + + +class TestEmbed: + """embed() returns a list of embedding vectors.""" + + def test_embed_returns_vectors(self, provider, client): + item1 = MagicMock() + item1.embedding = [0.1, 0.2, 0.3] + item2 = MagicMock() + item2.embedding = [0.4, 0.5, 0.6] + mock_response = MagicMock() + mock_response.data = [item1, item2] + client.embeddings.create.return_value = mock_response + + result = provider.embed(["hello", "world"]) + assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + + def test_embed_uses_default_model(self, provider, client): + mock_response = MagicMock() + mock_response.data = [] + client.embeddings.create.return_value = mock_response + + provider.embed(["text"]) + call_kwargs = client.embeddings.create.call_args + assert call_kwargs.kwargs["model"] == "text-embedding-3-small" + + def test_embed_custom_model(self, provider, client): + mock_response = MagicMock() + mock_response.data = [] + client.embeddings.create.return_value = mock_response + + provider.embed(["text"], model="text-embedding-ada-002") + call_kwargs = client.embeddings.create.call_args + assert call_kwargs.kwargs["model"] == "text-embedding-ada-002" + + +# --------------------------------------------------------------------------- +# Unsupported methods (inherited from LLMClient) +# --------------------------------------------------------------------------- + + +class TestUnsupportedMethods: + """Methods not implemented by OpenAIProvider raise NotSupportedError.""" + + def test_vision_not_supported(self, provider): + with pytest.raises(NotSupportedError, match="OpenAI.*vision"): + provider.vision([b"img"], "describe") + + def test_get_performance_stats_not_supported(self, provider): + with pytest.raises(NotSupportedError, match="OpenAI.*get_performance_stats"): + provider.get_performance_stats() + + def test_load_model_not_supported(self, provider): + with pytest.raises(NotSupportedError, match="OpenAI.*load_model"): + provider.load_model("gpt-4o") + + def test_unload_model_not_supported(self, provider): + with pytest.raises(NotSupportedError, match="OpenAI.*unload_model"): + provider.unload_model() From dca78fadd9b343cec737cbc2ca0f63f24e8f6f31 Mon Sep 17 00:00:00 2001 From: Ovtcharov Date: Thu, 28 May 2026 20:01:13 -0700 Subject: [PATCH 2/3] test: strengthen assertions after review Adds boundary test at 0.89 confidence, CLI-mode interactive path coverage (mock input()), console.print_info verification, generate(stream=True), SDK exception propagation, and tighter embed/init assertions. --- tests/unit/agents/test_routing_agent.py | 93 ++++++++++++++++++++++++- tests/unit/test_openai_provider.py | 23 +++++- 2 files changed, 112 insertions(+), 4 deletions(-) diff --git a/tests/unit/agents/test_routing_agent.py b/tests/unit/agents/test_routing_agent.py index 3c14517fa..3f2ff90e6 100644 --- a/tests/unit/agents/test_routing_agent.py +++ b/tests/unit/agents/test_routing_agent.py @@ -176,13 +176,20 @@ def test_low_confidence_triggers_unknowns(self, router): } assert router._has_unknowns(analysis) is True - def test_boundary_confidence_0_9(self, router): + def test_boundary_confidence_0_9_not_unknown(self, router): analysis = { "parameters": {"language": "python", "project_type": "api"}, "confidence": 0.9, } assert router._has_unknowns(analysis) is False + def test_boundary_confidence_0_89_is_unknown(self, router): + analysis = { + "parameters": {"language": "python", "project_type": "api"}, + "confidence": 0.89, + } + assert router._has_unknowns(analysis) is True + # --------------------------------------------------------------------------- # Clarification questions @@ -305,6 +312,18 @@ def test_known_language_unchanged(self, router): assert result["parameters"]["language"] == "python" assert result["confidence"] == 0.9 + def test_prints_info_when_defaulting(self, router): + with patch.object(router, "_get_console") as mock_get: + mock_console = MagicMock() + mock_get.return_value = mock_console + analysis = { + "parameters": {"language": "unknown", "project_type": "unknown"}, + "confidence": 0.5, + "reasoning": "", + } + router._default_unknown_language_to_typescript(analysis) + mock_console.print_info.assert_called_once() + # --------------------------------------------------------------------------- # enforce_typescript_only @@ -410,3 +429,75 @@ def test_low_confidence_api_mode_uses_defaults( mock_agent.process_query.return_value = "ok" result = router.process_query("Build something") assert result == "ok" + + +# --------------------------------------------------------------------------- +# process_query — CLI (interactive) mode +# --------------------------------------------------------------------------- + + +class TestProcessQueryCLIMode: + """process_query in CLI mode asks for clarification via input().""" + + def test_cli_mode_asks_clarification_then_resolves( + self, _patch_create_client, mock_llm_client, _patch_code_agent + ): + from gaia.agents.routing.agent import RoutingAgent + + router = RoutingAgent(api_mode=False) + + # First call: low confidence → triggers clarification + # Second call (recursive): high confidence → resolves + mock_llm_client.generate.side_effect = [ + json.dumps( + { + "agent": "code", + "parameters": { + "language": "unknown", + "project_type": "unknown", + }, + "confidence": 0.3, + "reasoning": "ambiguous", + } + ), + json.dumps( + { + "agent": "code", + "parameters": { + "language": "typescript", + "project_type": "fullstack", + }, + "confidence": 0.95, + "reasoning": "user clarified", + } + ), + ] + + with patch("builtins.input", return_value="Next.js blog"): + agent = router.process_query("Build something", execute=False) + + assert agent is _patch_code_agent.return_value + + def test_cli_mode_empty_response_uses_defaults( + self, _patch_create_client, mock_llm_client, _patch_code_agent + ): + from gaia.agents.routing.agent import RoutingAgent + + router = RoutingAgent(api_mode=False) + + mock_llm_client.generate.return_value = json.dumps( + { + "agent": "code", + "parameters": { + "language": "typescript", + "project_type": "unknown", + }, + "confidence": 0.4, + "reasoning": "unclear", + } + ) + + with patch("builtins.input", return_value=""): + agent = router.process_query("Build something", execute=False) + + assert agent is _patch_code_agent.return_value diff --git a/tests/unit/test_openai_provider.py b/tests/unit/test_openai_provider.py index 9ccf94d75..a0b8c290b 100644 --- a/tests/unit/test_openai_provider.py +++ b/tests/unit/test_openai_provider.py @@ -66,11 +66,12 @@ def test_system_prompt_stored(self, mock_openai_module): p = OpenAIProvider(api_key="sk-test", system_prompt="You are helpful.") assert p._system_prompt == "You are helpful." - def test_extra_kwargs_ignored(self, mock_openai_module): + def test_extra_kwargs_not_passed_to_openai_client(self, mock_openai_module): from gaia.llm.providers.openai_provider import OpenAIProvider - p = OpenAIProvider(api_key="sk-test", base_url="http://x", unknown_arg=True) - assert p._model == "gpt-4o" + mock_mod, _ = mock_openai_module + OpenAIProvider(api_key="sk-test", base_url="http://x", unknown_arg=True) + mock_mod.OpenAI.assert_called_once_with(api_key="sk-test") # --------------------------------------------------------------------------- @@ -154,6 +155,11 @@ def test_extra_kwargs_passed_through(self, provider, client): assert call_kwargs.kwargs["temperature"] == 0.5 assert call_kwargs.kwargs["max_tokens"] == 100 + def test_sdk_exception_propagates(self, provider, client): + client.chat.completions.create.side_effect = Exception("rate limited") + with pytest.raises(Exception, match="rate limited"): + provider.chat([{"role": "user", "content": "Hi"}]) + # --------------------------------------------------------------------------- # chat() — streaming @@ -244,6 +250,15 @@ def test_generate_with_model_override(self, provider, client): call_kwargs = client.chat.completions.create.call_args assert call_kwargs.kwargs["model"] == "gpt-3.5-turbo" + def test_generate_streaming(self, provider, client): + chunk = MagicMock() + chunk.choices = [MagicMock()] + chunk.choices[0].delta.content = "streamed" + client.chat.completions.create.return_value = iter([chunk]) + + result = provider.generate("test", stream=True) + assert list(result) == ["streamed"] + # --------------------------------------------------------------------------- # embed() @@ -264,6 +279,8 @@ def test_embed_returns_vectors(self, provider, client): result = provider.embed(["hello", "world"]) assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + call_kwargs = client.embeddings.create.call_args + assert call_kwargs.kwargs["input"] == ["hello", "world"] def test_embed_uses_default_model(self, provider, client): mock_response = MagicMock() From 06250c8b866d9f93054f57ab0e719c04fb64ffa5 Mon Sep 17 00:00:00 2001 From: Ovtcharov Date: Thu, 28 May 2026 23:45:04 -0700 Subject: [PATCH 3/3] =?UTF-8?q?fix(test):=20address=20PR=20#1244=20review?= =?UTF-8?q?=20=E2=80=94=20clarification=20loop=20+=20style?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The CLI-mode clarification test was passing via `_create_agent` (the default-to-typescript shortcut), never entering the `input()` branch. First LLM response now uses known language + unknown project_type + low confidence so `_default_unknown_language_to_typescript` leaves it alone and `_has_unknowns` fires. Asserts `mock_input.assert_called_once()`. Also drops redundant `call_kwargs[1].get()` fallbacks — `.kwargs` is the modern accessor and sufficient here. --- tests/unit/agents/test_routing_agent.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/unit/agents/test_routing_agent.py b/tests/unit/agents/test_routing_agent.py index 3f2ff90e6..6d4172342 100644 --- a/tests/unit/agents/test_routing_agent.py +++ b/tests/unit/agents/test_routing_agent.py @@ -363,10 +363,7 @@ def test_creates_code_agent(self, router, _patch_code_agent): assert agent is not None _patch_code_agent.assert_called_once() call_kwargs = _patch_code_agent.call_args - assert ( - call_kwargs.kwargs.get("language") == "typescript" - or call_kwargs[1].get("language") == "typescript" - ) + assert call_kwargs.kwargs.get("language") == "typescript" def test_unknown_agent_type_raises(self, router): analysis = { @@ -386,10 +383,7 @@ def test_unknown_typescript_defaults_to_fullstack(self, router, _patch_code_agen } router._create_agent_with_defaults(analysis) call_kwargs = _patch_code_agent.call_args - assert ( - call_kwargs.kwargs.get("project_type") == "fullstack" - or call_kwargs[1].get("project_type") == "fullstack" - ) + assert call_kwargs.kwargs.get("project_type") == "fullstack" # --------------------------------------------------------------------------- @@ -446,17 +440,20 @@ def test_cli_mode_asks_clarification_then_resolves( router = RoutingAgent(api_mode=False) - # First call: low confidence → triggers clarification - # Second call (recursive): high confidence → resolves + # First response: known language but unknown project_type + low + # confidence. _default_unknown_language_to_typescript leaves this + # unchanged (language != "unknown"), so _has_unknowns fires and the + # clarification branch runs. + # Second response (after user answers): fully resolved. mock_llm_client.generate.side_effect = [ json.dumps( { "agent": "code", "parameters": { - "language": "unknown", + "language": "typescript", "project_type": "unknown", }, - "confidence": 0.3, + "confidence": 0.4, "reasoning": "ambiguous", } ), @@ -473,9 +470,10 @@ def test_cli_mode_asks_clarification_then_resolves( ), ] - with patch("builtins.input", return_value="Next.js blog"): + with patch("builtins.input", return_value="Next.js blog") as mock_input: agent = router.process_query("Build something", execute=False) + mock_input.assert_called_once() assert agent is _patch_code_agent.return_value def test_cli_mode_empty_response_uses_defaults(