class TestSyntaxValidator:
    """Exercise SyntaxValidator: validate, validate_dict and the lint-style helpers."""

    @pytest.fixture()
    def validator(self):
        """Provide a fresh SyntaxValidator for every test."""
        return SyntaxValidator()

    # -- validate --

    def test_valid_code(self, validator):
        """Well-formed source is reported valid with an empty error list."""
        result = validator.validate("x = 1\nprint(x)\n")
        assert result.is_valid is True
        assert result.errors == []

    def test_empty_code(self, validator):
        """An empty string is accepted as valid source."""
        result = validator.validate("")
        assert result.is_valid is True

    def test_syntax_error_detected(self, validator):
        """An unterminated signature is invalid and yields at least one error."""
        result = validator.validate("def foo(\n")
        assert result.is_valid is False
        assert len(result.errors) > 0

    def test_syntax_error_includes_line_number(self, validator):
        """At least one reported error message contains the word 'Line'."""
        result = validator.validate("x = 1\ndef foo(\n")
        assert any("Line" in message for message in result.errors)

    # -- validate_dict --

    def test_validate_dict_valid(self, validator):
        """The dict form reports success, validity and the fixed success message."""
        report = validator.validate_dict("a = 1")
        assert report["status"] == "success"
        assert report["is_valid"] is True
        assert report["message"] == "Syntax is valid"

    def test_validate_dict_invalid(self, validator):
        """The dict form reports an error status and a non-empty error list."""
        report = validator.validate_dict("def (")
        assert report["status"] == "error"
        assert report["is_valid"] is False
        assert len(report["errors"]) > 0

    # -- get_syntax_errors --

    def test_get_syntax_errors_none(self, validator):
        """Valid source produces an empty list of syntax errors."""
        assert validator.get_syntax_errors("x = 1") == []

    def test_get_syntax_errors_returns_syntax_error(self, validator):
        """Invalid source produces exactly one SyntaxError instance."""
        errors = validator.get_syntax_errors("def (")
        assert len(errors) == 1
        assert isinstance(errors[0], SyntaxError)

    # -- check_indentation --

    def test_indentation_clean(self, validator):
        """A four-space indented body produces no indentation warnings."""
        # The body must use the conventional four-space indent here. With the
        # same one-space fixture as test_indentation_non_standard the two tests
        # would assert opposite results for an identical input.
        code = "def f():\n    pass\n"
        assert validator.check_indentation(code) == []

    def test_indentation_mixed_tabs_spaces(self, validator):
        """Leading whitespace that mixes a space and a tab is warned about."""
        code = "def f():\n \tpass\n"
        warnings = validator.check_indentation(code)
        assert any("Mixed tabs and spaces" in w for w in warnings)

    def test_indentation_non_standard(self, validator):
        """An indent that is not the conventional width is warned about."""
        # NOTE(review): presumably the validator treats any indent that is not
        # a multiple of four as non-standard -- confirm against check_indentation.
        code = "def f():\n pass\n"
        warnings = validator.check_indentation(code)
        assert any("Non-standard indentation" in w for w in warnings)

    # -- validate_imports --

    def test_validate_imports_clean(self, validator):
        """Two distinct plain imports produce no warnings."""
        code = "import os\nimport sys\n"
        assert validator.validate_imports(code) == []

    def test_validate_imports_wildcard(self, validator):
        """A star import is reported as a wildcard import."""
        code = "from os import *\n"
        warnings = validator.validate_imports(code)
        assert any("Wildcard import" in w for w in warnings)

    def test_validate_imports_duplicate(self, validator):
        """Importing the same module twice is reported as a duplicate."""
        code = "import os\nimport os\n"
        warnings = validator.validate_imports(code)
        assert any("Duplicate import" in w for w in warnings)

    # -- check_line_length --

    def test_line_length_ok(self, validator):
        """A short line produces no length warnings."""
        assert validator.check_line_length("x = 1") == []

    def test_line_length_exceeded(self, validator):
        """A 94-character line exceeds max_length=88 and yields one warning."""
        long_line = "x = " + "a" * 90
        warnings = validator.check_line_length(long_line, max_length=88)
        assert len(warnings) == 1
        assert "Line too long" in warnings[0]
class TestAntipatternChecker:
    """Cover AntipatternChecker: check, check_dict, naming rules and complexity rules."""

    @pytest.fixture()
    def checker(self):
        """Build a new AntipatternChecker for each test."""
        return AntipatternChecker()

    @staticmethod
    def _first_function(source):
        """Parse *source* and return the first FunctionDef that ast.walk yields."""
        return next(
            node
            for node in ast.walk(ast.parse(source))
            if isinstance(node, ast.FunctionDef)
        )

    # -- check --

    def test_clean_code(self, checker):
        """A small, tidy function produces neither errors nor warnings."""
        source = textwrap.dedent("""\
            def greet(name):
                print(f"hi {name}")
            """)
        report = checker.check(Path("clean.py"), source)
        assert report["errors"] == []
        assert report["warnings"] == []

    def test_excessive_function_name(self, checker):
        """An 81-character function name is reported as an error mentioning 'chars'."""
        long_name = "a" * 81
        report = checker.check(Path("long.py"), f"def {long_name}():\n pass\n")
        assert any("chars" in err for err in report["errors"])

    def test_combinatorial_naming(self, checker):
        """A name chaining several '_and_' parts is flagged as combinatorial."""
        source = "def get_and_process_and_validate_and_transform():\n pass\n"
        report = checker.check(Path("combo.py"), source)
        assert any("Combinatorial" in err for err in report["errors"])

    def test_excessive_parameters(self, checker):
        """Eight positional parameters trigger a 'parameters' warning."""
        signature = ", ".join(f"p{i}" for i in range(8))
        report = checker.check(Path("params.py"), f"def func({signature}):\n pass\n")
        assert any("parameters" in warn for warn in report["warnings"])

    def test_long_function_warns(self, checker):
        """A function with 55 body statements triggers a 'lines long' warning."""
        statements = "\n".join(f" x{i} = {i}" for i in range(55))
        report = checker.check(Path("long_func.py"), f"def long_func():\n{statements}\n")
        assert any("lines long" in warn for warn in report["warnings"])

    def test_duplicate_class_definitions(self, checker):
        """Defining the same class name twice is reported as an error."""
        source = "class Foo:\n pass\nclass Foo:\n pass\n"
        report = checker.check(Path("dup.py"), source)
        assert any("Duplicate class" in err for err in report["errors"])

    def test_excessive_file_length(self, checker):
        """A 1010-line module triggers a warning that suggests splitting."""
        source = "\n".join(f"x{i} = {i}" for i in range(1010))
        report = checker.check(Path("big.py"), source)
        assert any(
            "lines" in warn and "splitting" in warn for warn in report["warnings"]
        )

    def test_syntax_error_ignored(self, checker):
        """Unparseable source yields empty error and warning lists."""
        report = checker.check(Path("bad.py"), "def (")
        assert report["errors"] == []
        assert report["warnings"] == []

    # -- check_dict --

    def test_check_dict_delegates(self, checker):
        """check_dict returns a mapping that has 'errors' and 'warnings' keys."""
        report = checker.check_dict("def greet():\n pass\n")
        assert "errors" in report
        assert "warnings" in report

    # -- check_naming_patterns --

    def test_naming_long_function(self, checker):
        """A 45-character function name is reported as a 'long name'."""
        module = ast.parse(f"def {'a' * 45}():\n pass\n")
        findings = checker.check_naming_patterns(module)
        assert any("long name" in item for item in findings)

    def test_naming_too_many_underscores(self, checker):
        """A function name with six underscores is reported."""
        module = ast.parse("def a_b_c_d_e_f_g():\n pass\n")
        findings = checker.check_naming_patterns(module)
        assert any("underscores" in item for item in findings)

    def test_class_name_lowercase(self, checker):
        """A class name starting in lowercase is reported ('uppercase' hint)."""
        module = ast.parse("class myclass:\n pass\n")
        findings = checker.check_naming_patterns(module)
        assert any("uppercase" in item for item in findings)

    def test_class_name_too_long(self, checker):
        """A 35-character class name is reported as a 'long name'."""
        module = ast.parse(f"class {'A' * 35}:\n pass\n")
        findings = checker.check_naming_patterns(module)
        assert any("long name" in item for item in findings)

    # -- check_function_complexity --

    def test_deep_nesting(self, checker):
        """Five nested compound statements trigger a 'nesting' finding."""
        source = textwrap.dedent("""\
            def deep():
                if True:
                    for i in range(1):
                        while True:
                            if True:
                                with open("x"):
                                    pass
            """)
        findings = checker.check_function_complexity(self._first_function(source))
        assert any("nesting" in item for item in findings)

    def test_many_branches(self, checker):
        """Twelve sequential if-statements trigger a 'branches' finding."""
        conditionals = "\n".join(f" if x == {i}: pass" for i in range(12))
        function_node = self._first_function(f"def branchy(x):\n{conditionals}\n")
        findings = checker.check_function_complexity(function_node)
        assert any("branches" in item for item in findings)

    def test_many_loops(self, checker):
        """Five sequential for-loops trigger a 'loops' finding."""
        for_loops = "\n".join(f" for i{n} in range(1): pass" for n in range(5))
        function_node = self._first_function(f"def loopy():\n{for_loops}\n")
        findings = checker.check_function_complexity(function_node)
        assert any("loops" in item for item in findings)

    def test_simple_function_no_issues(self, checker):
        """A one-line function produces no complexity findings."""
        function_node = self._first_function("def ok(x):\n return x + 1\n")
        assert checker.check_function_complexity(function_node) == []
def test_get_docstring_from_function(self, analyzer):
    """get_docstring returns the docstring text of a parsed function node."""
    source = 'def f():\n """Hello."""\n pass\n'
    # ast.walk yields the single FunctionDef of this module; take the first hit.
    function_node = next(
        node
        for node in ast.walk(ast.parse(source))
        if isinstance(node, ast.FunctionDef)
    )
    assert analyzer.get_docstring(function_node) == "Hello."
class TestRequirementsValidator:
    """Cover RequirementsValidator: validate, check_package_validity, suggest_common_packages."""

    @pytest.fixture()
    def validator(self):
        """Build a new RequirementsValidator for each test."""
        return RequirementsValidator()

    @staticmethod
    def _requirements_file(directory, content):
        """Write *content* to ``requirements.txt`` inside *directory* and return its path."""
        target = directory / "requirements.txt"
        target.write_text(content)
        return target

    # -- validate --

    def test_valid_requirements(self, validator, tmp_path):
        """Two pinned packages validate cleanly and are both counted."""
        path = self._requirements_file(tmp_path, "flask==2.3.0\nrequests>=2.28\n")
        outcome = validator.validate(path)
        assert outcome["is_valid"] is True
        assert outcome["packages"] == 2
        assert outcome["errors"] == []

    def test_hallucinated_package_detected(self, validator, tmp_path):
        """A long hyphen-chained name is rejected with a 'Hallucinated' error."""
        path = self._requirements_file(tmp_path, "flask-graphql-a-b-c-d-e\n")
        outcome = validator.validate(path)
        assert outcome["is_valid"] is False
        assert any("Hallucinated" in err for err in outcome["errors"])

    def test_recursive_ibm_pattern(self, validator, tmp_path):
        """A name repeating the 'ibm-cloud' segment is rejected."""
        path = self._requirements_file(tmp_path, "some-ibm-cloud-ibm-cloud-sdk\n")
        outcome = validator.validate(path)
        assert outcome["is_valid"] is False

    def test_package_name_too_long(self, validator, tmp_path):
        """A 65-character package name is rejected with a 'too long' error."""
        path = self._requirements_file(tmp_path, "a" * 65 + "\n")
        outcome = validator.validate(path)
        assert outcome["is_valid"] is False
        assert any("too long" in err for err in outcome["errors"])

    def test_duplicate_package_warning(self, validator, tmp_path):
        """Listing the same package twice produces a 'Duplicate' warning."""
        path = self._requirements_file(tmp_path, "flask\nflask\n")
        outcome = validator.validate(path)
        assert any("Duplicate" in warn for warn in outcome["warnings"])

    def test_comment_lines_ignored(self, validator, tmp_path):
        """A '#' comment line is not counted as a package."""
        path = self._requirements_file(tmp_path, "# comment\nflask\n")
        outcome = validator.validate(path)
        assert outcome["is_valid"] is True
        assert outcome["packages"] == 1

    def test_empty_lines_ignored(self, validator, tmp_path):
        """Blank lines between entries are not counted as packages."""
        path = self._requirements_file(tmp_path, "flask\n\nrequests\n")
        outcome = validator.validate(path)
        assert outcome["packages"] == 2

    def test_many_packages_warning(self, validator, tmp_path):
        """Thirty-five packages produce a 'Many packages' warning."""
        entries = [f"pkg{i}" for i in range(35)]
        path = self._requirements_file(tmp_path, "\n".join(entries))
        outcome = validator.validate(path)
        assert any("Many packages" in warn for warn in outcome["warnings"])

    def test_too_many_packages_error(self, validator, tmp_path):
        """Fifty-five packages produce a 'Too many' error."""
        entries = [f"pkg{i}" for i in range(55)]
        path = self._requirements_file(tmp_path, "\n".join(entries))
        outcome = validator.validate(path)
        assert any("Too many" in err for err in outcome["errors"])

    def test_auto_fix_removes_bad_packages(self, validator, tmp_path):
        """With fix=True the rejected entry is dropped and the good ones are kept."""
        path = self._requirements_file(
            tmp_path, "flask\nflask-graphql-a-b-c-d-e\nrequests\n"
        )
        outcome = validator.validate(path, fix=True)
        repaired = outcome["fixed_content"]
        assert repaired is not None
        assert "flask-graphql" not in repaired
        assert "flask" in repaired
        assert "requests" in repaired

    # -- check_package_validity --

    def test_valid_package_name(self, validator):
        """Common real package names, including hyphenated ones, are accepted."""
        assert validator.check_package_validity("flask") is True
        assert validator.check_package_validity("scikit-learn") is True
        assert validator.check_package_validity("python-dotenv") is True

    def test_hallucinated_package_invalid(self, validator):
        """A name repeating the 'ibm-cloud' segment is rejected."""
        assert validator.check_package_validity("x-ibm-cloud-ibm-cloud-y") is False

    def test_too_long_package_invalid(self, validator):
        """A 61-character name is rejected."""
        assert validator.check_package_validity("a" * 61) is False

    def test_invalid_chars_package(self, validator):
        """A name containing '@' is rejected."""
        assert validator.check_package_validity("flask@latest") is False

    def test_package_starting_with_hyphen(self, validator):
        """A name with a leading hyphen is rejected."""
        assert validator.check_package_validity("-flask") is False

    # -- suggest_common_packages --

    def test_suggest_web(self, validator):
        """The 'web' category suggests both flask and django."""
        suggestions = validator.suggest_common_packages("web")
        assert "flask" in suggestions
        assert "django" in suggestions

    def test_suggest_ml(self, validator):
        """The 'ml' category suggests torch."""
        suggestions = validator.suggest_common_packages("ml")
        assert "torch" in suggestions

    def test_suggest_unknown_falls_back(self, validator):
        """An unrecognised category returns the same list as 'general'."""
        suggestions = validator.suggest_common_packages("gaming")
        assert suggestions == validator.suggest_common_packages("general")
class TestRoutingAgentInit:
    """Constructor and configuration of RoutingAgent."""

    def test_import_and_exposes_process_query(self):
        """The class can be imported and exposes a process_query attribute."""
        from gaia.agents.routing.agent import RoutingAgent

        assert hasattr(RoutingAgent, "process_query")

    def test_default_routing_model(self, _patch_create_client, monkeypatch):
        """Without AGENT_ROUTING_MODEL set, the built-in default model is used."""
        # The constructor honours AGENT_ROUTING_MODEL (see
        # test_custom_routing_model_via_env), so clear it before building the
        # agent; otherwise a value exported in the developer's shell or CI job
        # would make this assertion fail.
        monkeypatch.delenv("AGENT_ROUTING_MODEL", raising=False)
        from gaia.agents.routing.agent import RoutingAgent

        r = RoutingAgent(api_mode=True)
        assert r.routing_model == "Qwen3.5-35B-A3B-GGUF"

    def test_custom_routing_model_via_env(self, _patch_create_client, monkeypatch):
        """AGENT_ROUTING_MODEL overrides the default routing model."""
        monkeypatch.setenv("AGENT_ROUTING_MODEL", "custom-model")
        from gaia.agents.routing.agent import RoutingAgent

        r = RoutingAgent(api_mode=True)
        assert r.routing_model == "custom-model"

    def test_api_mode_stored(self, _patch_create_client):
        """api_mode=True is kept on the instance."""
        from gaia.agents.routing.agent import RoutingAgent

        r = RoutingAgent(api_mode=True)
        assert r.api_mode is True

    def test_cli_mode_default(self, _patch_create_client):
        """api_mode is False when the argument is omitted."""
        from gaia.agents.routing.agent import RoutingAgent

        r = RoutingAgent()
        assert r.api_mode is False

    def test_agent_kwargs_stored(self, _patch_create_client):
        """Extra keyword arguments are kept in agent_kwargs."""
        from gaia.agents.routing.agent import RoutingAgent

        r = RoutingAgent(api_mode=True, foo="bar")
        assert r.agent_kwargs["foo"] == "bar"
class TestHasUnknowns:
    """_has_unknowns flags 'unknown' parameters and confidence below 0.9."""

    @staticmethod
    def _analysis(language, project_type, confidence):
        """Assemble the analysis mapping that is passed to _has_unknowns."""
        return {
            "parameters": {"language": language, "project_type": project_type},
            "confidence": confidence,
        }

    def test_no_unknowns_high_confidence(self, router):
        """Known parameters with confidence 0.95 are not flagged."""
        payload = self._analysis("typescript", "fullstack", 0.95)
        assert router._has_unknowns(payload) is False

    def test_unknown_language(self, router):
        """An 'unknown' language is flagged even at high confidence."""
        payload = self._analysis("unknown", "fullstack", 0.95)
        assert router._has_unknowns(payload) is True

    def test_unknown_project_type(self, router):
        """An 'unknown' project type is flagged even at high confidence."""
        payload = self._analysis("python", "unknown", 0.95)
        assert router._has_unknowns(payload) is True

    def test_low_confidence_triggers_unknowns(self, router):
        """Known parameters with confidence 0.5 are flagged."""
        payload = self._analysis("python", "api", 0.5)
        assert router._has_unknowns(payload) is True

    def test_boundary_confidence_0_9_not_unknown(self, router):
        """Confidence of exactly 0.9 is not flagged."""
        payload = self._analysis("python", "api", 0.9)
        assert router._has_unknowns(payload) is False

    def test_boundary_confidence_0_89_is_unknown(self, router):
        """Confidence of 0.89 is flagged."""
        payload = self._analysis("python", "api", 0.89)
        assert router._has_unknowns(payload) is True
@pytest.mark.parametrize(
    ("query", "expected_lang"),
    [
        ("Build a Next.js blog", "typescript"),
        ("Create a React dashboard", "typescript"),
        ("Express REST API", "typescript"),
        ("Angular admin panel", "typescript"),
        ("Svelte app", "typescript"),
    ],
)
def test_typescript_keywords(self, router, query, expected_lang):
    """Each listed framework keyword makes the fallback report the expected language."""
    detection = router._fallback_keyword_detection(query)
    detected_language = detection["parameters"]["language"]
    assert detected_language == expected_lang
class TestEnforceTypescriptOnly:
    """_enforce_typescript_only accepts TypeScript and exits for other languages."""

    def test_typescript_fullstack_passes(self, router):
        """A TypeScript fullstack request is returned unchanged."""
        console = MagicMock()
        lang, pt = router._enforce_typescript_only("typescript", "fullstack", console)
        assert lang == "typescript"
        assert pt == "fullstack"

    def test_python_raises_system_exit(self, router):
        """A Python request raises SystemExit and reports one error on the console."""
        console = MagicMock()
        with pytest.raises(SystemExit):
            router._enforce_typescript_only("python", "script", console)
        # Deliberately outside the ``with`` body: a statement placed after the
        # raising call inside the block would never execute, so the console
        # check would silently be skipped.
        console.print_error.assert_called_once()
class TestProcessQueryCLIMode:
    """process_query in CLI mode prompts through input() when the analysis is unclear."""

    @staticmethod
    def _llm_reply(project_type, confidence, reasoning):
        """Serialise a TypeScript 'code' routing answer as the mocked LLM would return it."""
        return json.dumps(
            {
                "agent": "code",
                "parameters": {
                    "language": "typescript",
                    "project_type": project_type,
                },
                "confidence": confidence,
                "reasoning": reasoning,
            }
        )

    def test_cli_mode_asks_clarification_then_resolves(
        self, _patch_create_client, mock_llm_client, _patch_code_agent
    ):
        """One clarification prompt is shown, then the patched CodeAgent is returned."""
        from gaia.agents.routing.agent import RoutingAgent

        cli_router = RoutingAgent(api_mode=False)

        # Reply 1 has a known language but an unknown project_type and low
        # confidence. _default_unknown_language_to_typescript leaves it alone
        # (language != "unknown"), so _has_unknowns fires and the clarification
        # branch runs.
        # Reply 2 (after the user answers) is fully resolved.
        mock_llm_client.generate.side_effect = [
            self._llm_reply("unknown", 0.4, "ambiguous"),
            self._llm_reply("fullstack", 0.95, "user clarified"),
        ]

        with patch("builtins.input", return_value="Next.js blog") as prompt:
            created = cli_router.process_query("Build something", execute=False)

        prompt.assert_called_once()
        assert created is _patch_code_agent.return_value

    def test_cli_mode_empty_response_uses_defaults(
        self, _patch_create_client, mock_llm_client, _patch_code_agent
    ):
        """An empty answer to the prompt still yields the patched CodeAgent."""
        from gaia.agents.routing.agent import RoutingAgent

        cli_router = RoutingAgent(api_mode=False)
        mock_llm_client.generate.return_value = self._llm_reply(
            "unknown", 0.4, "unclear"
        )

        with patch("builtins.input", return_value=""):
            created = cli_router.process_query("Build something", execute=False)

        assert created is _patch_code_agent.return_value
# SPDX-License-Identifier: MIT
"""Tests for OpenAIProvider: construction, chat (plain and streamed),
generate, embed, and the calls it reports as unsupported."""

from unittest.mock import MagicMock, patch

import pytest

from gaia.llm.exceptions import NotSupportedError

# ---------------------------------------------------------------------------
# Local builders for mocked SDK objects
# ---------------------------------------------------------------------------


def _user_turn(text="Hi"):
    """Return a fresh one-message conversation (a new list on every call)."""
    return [{"role": "user", "content": text}]


def _completion_with(text):
    """Build a mock chat completion whose first choice carries *text*."""
    reply = MagicMock()
    reply.choices = [MagicMock()]
    reply.choices[0].message.content = text
    return reply


def _stream_chunk(text):
    """Build a mock streaming chunk whose first choice's delta is *text*."""
    piece = MagicMock()
    piece.choices = [MagicMock()]
    piece.choices[0].delta.content = text
    return piece


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def mock_openai_module():
    """Install a MagicMock as ``sys.modules["openai"]`` for the test's duration.

    Yields ``(module_mock, client_mock)`` where ``client_mock`` is what
    ``module_mock.OpenAI(...)`` returns.
    """
    fake_module = MagicMock()
    fake_client = MagicMock()
    fake_module.OpenAI.return_value = fake_client
    with patch.dict("sys.modules", {"openai": fake_module}):
        yield fake_module, fake_client


@pytest.fixture()
def provider(mock_openai_module):
    """An OpenAIProvider constructed while the mocked ``openai`` is installed."""
    from gaia.llm.providers.openai_provider import OpenAIProvider

    return OpenAIProvider(api_key="sk-test", model="gpt-4o")


@pytest.fixture()
def client(mock_openai_module):
    """The mocked ``openai.OpenAI()`` instance from ``mock_openai_module``."""
    _, fake_client = mock_openai_module
    return fake_client


# ---------------------------------------------------------------------------
# Initialization
# ---------------------------------------------------------------------------


class TestOpenAIProviderInit:
    """Constructor arguments and the attributes they populate."""

    def test_provider_name(self, provider):
        assert provider.provider_name == "OpenAI"

    def test_default_model(self, mock_openai_module):
        from gaia.llm.providers.openai_provider import OpenAIProvider

        built = OpenAIProvider(api_key="sk-test")
        assert built._model == "gpt-4o"

    def test_custom_model(self, mock_openai_module):
        from gaia.llm.providers.openai_provider import OpenAIProvider

        built = OpenAIProvider(api_key="sk-test", model="gpt-4-turbo")
        assert built._model == "gpt-4-turbo"

    def test_system_prompt_stored(self, mock_openai_module):
        from gaia.llm.providers.openai_provider import OpenAIProvider

        built = OpenAIProvider(api_key="sk-test", system_prompt="You are helpful.")
        assert built._system_prompt == "You are helpful."

    def test_extra_kwargs_not_passed_to_openai_client(self, mock_openai_module):
        from gaia.llm.providers.openai_provider import OpenAIProvider

        fake_module = mock_openai_module[0]
        OpenAIProvider(api_key="sk-test", base_url="http://x", unknown_arg=True)
        # Only api_key is expected to reach the SDK constructor.
        fake_module.OpenAI.assert_called_once_with(api_key="sk-test")


# ---------------------------------------------------------------------------
# chat() — non-streaming
# ---------------------------------------------------------------------------


class TestChat:
    """Non-streaming ``chat()`` against the mocked SDK client."""

    def test_returns_message_content(self, provider, client):
        client.chat.completions.create.return_value = _completion_with("Hello!")

        reply = provider.chat(_user_turn(), stream=False)

        assert reply == "Hello!"

    def test_uses_default_model(self, provider, client):
        create = client.chat.completions.create
        create.return_value = _completion_with("ok")

        provider.chat(_user_turn())

        assert create.call_args.kwargs["model"] == "gpt-4o"

    def test_model_override(self, provider, client):
        create = client.chat.completions.create
        create.return_value = _completion_with("ok")

        provider.chat(_user_turn(), model="gpt-4-turbo")

        assert create.call_args.kwargs["model"] == "gpt-4-turbo"

    def test_system_prompt_prepended(self, mock_openai_module):
        from gaia.llm.providers.openai_provider import OpenAIProvider

        prompted = OpenAIProvider(api_key="sk-test", system_prompt="Be concise.")
        create = mock_openai_module[1].chat.completions.create
        create.return_value = _completion_with("ok")

        prompted.chat(_user_turn())

        sent = create.call_args.kwargs["messages"]
        system_msg, user_msg = sent[0], sent[1]
        assert system_msg["role"] == "system"
        assert system_msg["content"] == "Be concise."
        assert user_msg["role"] == "user"

    def test_no_system_prompt_by_default(self, provider, client):
        create = client.chat.completions.create
        create.return_value = _completion_with("ok")

        provider.chat(_user_turn())

        sent = create.call_args.kwargs["messages"]
        assert sent[0]["role"] == "user"

    def test_extra_kwargs_passed_through(self, provider, client):
        create = client.chat.completions.create
        create.return_value = _completion_with("ok")

        provider.chat(_user_turn(), temperature=0.5, max_tokens=100)

        forwarded = create.call_args.kwargs
        assert forwarded["temperature"] == 0.5
        assert forwarded["max_tokens"] == 100

    def test_sdk_exception_propagates(self, provider, client):
        client.chat.completions.create.side_effect = Exception("rate limited")
        with pytest.raises(Exception, match="rate limited"):
            provider.chat(_user_turn())


# ---------------------------------------------------------------------------
# chat() — streaming
# ---------------------------------------------------------------------------


class TestChatStreaming:
    """``chat(stream=True)`` consumed as an iterable of text pieces."""

    def _make_chunk(self, content):
        return _stream_chunk(content)

    def _make_empty_chunk(self):
        # A chunk that has a choice but whose delta content is None.
        return _stream_chunk(None)

    def _make_no_choices_chunk(self):
        # A chunk whose ``choices`` list is empty.
        bare = MagicMock()
        bare.choices = []
        return bare

    def test_stream_yields_content(self, provider, client):
        feed = [self._make_chunk(text) for text in ("Hello", " world")]
        client.chat.completions.create.return_value = iter(feed)

        stream = provider.chat(_user_turn(), stream=True)

        assert list(stream) == ["Hello", " world"]

    def test_stream_skips_empty_deltas(self, provider, client):
        feed = [
            self._make_chunk("A"),
            self._make_empty_chunk(),
            self._make_chunk("B"),
        ]
        client.chat.completions.create.return_value = iter(feed)

        collected = list(provider.chat(_user_turn(), stream=True))

        assert collected == ["A", "B"]

    def test_stream_skips_no_choices(self, provider, client):
        feed = [
            self._make_chunk("X"),
            self._make_no_choices_chunk(),
            self._make_chunk("Y"),
        ]
        client.chat.completions.create.return_value = iter(feed)

        collected = list(provider.chat(_user_turn(), stream=True))

        assert collected == ["X", "Y"]


# ---------------------------------------------------------------------------
# generate()
# ---------------------------------------------------------------------------


class TestGenerate:
    """``generate()`` observed through the chat-completions mock."""

    def test_generate_non_streaming(self, provider, client):
        create = client.chat.completions.create
        create.return_value = _completion_with("42")

        answer = provider.generate("What is 6*7?")
        assert answer == "42"

        # The prompt should arrive as a single user message.
        sent = create.call_args.kwargs["messages"]
        assert len(sent) == 1
        only = sent[0]
        assert only["role"] == "user"
        assert only["content"] == "What is 6*7?"

    def test_generate_with_model_override(self, provider, client):
        create = client.chat.completions.create
        create.return_value = _completion_with("ok")

        provider.generate("test", model="gpt-3.5-turbo")

        assert create.call_args.kwargs["model"] == "gpt-3.5-turbo"

    def test_generate_streaming(self, provider, client):
        client.chat.completions.create.return_value = iter(
            [_stream_chunk("streamed")]
        )

        stream = provider.generate("test", stream=True)

        assert list(stream) == ["streamed"]


# ---------------------------------------------------------------------------
# embed()
# ---------------------------------------------------------------------------


class TestEmbed:
    """``embed()`` against the mocked embeddings endpoint."""

    def test_embed_returns_vectors(self, provider, client):
        first = MagicMock()
        first.embedding = [0.1, 0.2, 0.3]
        second = MagicMock()
        second.embedding = [0.4, 0.5, 0.6]
        sdk_reply = MagicMock()
        sdk_reply.data = [first, second]
        create = client.embeddings.create
        create.return_value = sdk_reply

        vectors = provider.embed(["hello", "world"])

        assert vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        assert create.call_args.kwargs["input"] == ["hello", "world"]

    def test_embed_uses_default_model(self, provider, client):
        sdk_reply = MagicMock()
        sdk_reply.data = []
        create = client.embeddings.create
        create.return_value = sdk_reply

        provider.embed(["text"])

        assert create.call_args.kwargs["model"] == "text-embedding-3-small"

    def test_embed_custom_model(self, provider, client):
        sdk_reply = MagicMock()
        sdk_reply.data = []
        create = client.embeddings.create
        create.return_value = sdk_reply

        provider.embed(["text"], model="text-embedding-ada-002")

        assert create.call_args.kwargs["model"] == "text-embedding-ada-002"


# ---------------------------------------------------------------------------
# Unsupported methods (inherited from LLMClient)
# ---------------------------------------------------------------------------


class TestUnsupportedMethods:
    """Calls expected to raise NotSupportedError naming OpenAI and the method."""

    def test_vision_not_supported(self, provider):
        with pytest.raises(NotSupportedError, match="OpenAI.*vision"):
            provider.vision([b"img"], "describe")

    def test_get_performance_stats_not_supported(self, provider):
        with pytest.raises(NotSupportedError, match="OpenAI.*get_performance_stats"):
            provider.get_performance_stats()

    def test_load_model_not_supported(self, provider):
        with pytest.raises(NotSupportedError, match="OpenAI.*load_model"):
            provider.load_model("gpt-4o")

    def test_unload_model_not_supported(self, provider):
        with pytest.raises(NotSupportedError, match="OpenAI.*unload_model"):
            provider.unload_model()