diff --git a/docs/connectors/tavily.mdx b/docs/connectors/tavily.mdx new file mode 100644 index 000000000..4f613839f --- /dev/null +++ b/docs/connectors/tavily.mdx @@ -0,0 +1,107 @@ +--- +title: "Tavily" +icon: "globe" +description: "Give GAIA agents web search and content extraction via Tavily." +--- + + + **Connector ID:** `mcp-tavily` · **Type:** `mcp_server` · **Catalog entry:** [`src/gaia/connectors/catalog/mcp_servers.py`](https://github.com/amd/gaia/blob/main/src/gaia/connectors/catalog/mcp_servers.py) + + +## What you'll need + +[Tavily](https://tavily.com) is a web-search API built for AI agents. The +connector is an **MCP server** — GAIA spawns the +[`tavily-mcp`](https://github.com/tavily-ai/tavily-mcp) process on demand via +`npx` and routes tool calls (`tavily-search`, `tavily-extract`) through it, so +the tools become available to **all** GAIA agents. + +It needs a single secret: a **Tavily API key**. You'll create one, paste it +into GAIA once, and you're done. The key lives encrypted in your OS keyring; +the MCP server reads it via a `$keyring` reference at launch. + +## Step 1 — Get an API key + +1. Sign in at app.tavily.com. +2. Copy your API key from the dashboard. It starts with `tvly-` followed by a + string of characters (e.g. `tvly-AbCd…`). If your key doesn't start with + `tvly-`, you're looking at the wrong value. + +The free tier includes a monthly credit allowance; a basic search costs 1 +credit and an advanced search costs 2. + +## Step 2 — Configure GAIA + +**From the CLI:** + +```bash +gaia connectors configure mcp-tavily --set TAVILY_API_KEY=tvly-... +``` + +**From the Agent UI:** + +1. Launch the Agent UI: `gaia chat --ui`. +2. **Settings** (gear) → **Connections** → click the **Tavily** tile. +3. Paste the key into the **Tavily API Key** field and click **Save**. + +Either path stores the key in your OS keyring (a single slot, distinct from +other connectors) and writes a `$keyring` reference into +`~/.gaia/mcp_servers.json` — the key never lives in plaintext on disk. + +## Step 3 — Use it + +Once configured, the `tavily-search` / `tavily-extract` MCP tools are available +to any agent you grant them to: + +```bash +gaia connectors grants grant mcp-tavily builtin:chat --scopes "*" +``` + +GAIA also ships a Python wrapper (`gaia.web.tavily`) used by web-research +workflows, with response caching, a credit budget, and a CLI: + +```bash +gaia knowledge search "AMD ROCm latest release" --max-results 5 +gaia knowledge usage # show credits spent +``` + + + If the connector isn't configured, `gaia knowledge search` and the wrapper + fall back to a keyless DuckDuckGo search — so search works out of the box, + and Tavily simply upgrades its quality and adds `extract`/`crawl`. + + +## Common issues + +### `Unauthorized` / `401` from the MCP server + +The key in your keyring is wrong or revoked. Click **Disconnect** on the tile +(or `gaia connectors disconnect mcp-tavily`) and re-add a fresh key. + +### `npx: command not found` + +`tavily-mcp` is launched via `npx`. Install Node 18+ and ensure `npx` is on +your `PATH`: + +```bash +node --version # must be >= 18 +which npx # must resolve to a real path +``` + +### Budget exceeded + +`gaia knowledge` warns (if nearing the budget) then blocks once a session passes its `--budget` credit cap. Raise the cap, or pass `--no-block` to warn and proceed instead of blocking. + +## Revoking access + +- **From GAIA:** Settings → Connections → Tavily → **Disconnect** (or + `gaia connectors disconnect mcp-tavily`). The key is removed from the keyring + and the entry is dropped from `mcp_servers.json`. +- **From Tavily:** rotate or delete the key in your + [Tavily dashboard](https://app.tavily.com/). + +## See also + +- [Connectors overview](/connectors) +- [Tavily documentation](https://docs.tavily.com/) +- [Connectors security model](/security/connections) diff --git a/docs/docs.json b/docs/docs.json index e913b6502..29ab8a364 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -89,6 +89,7 @@ "connectors/index", "connectors/google", "connectors/github", + "connectors/tavily", "security/connections" ] }, diff --git a/docs/reference/cli.mdx b/docs/reference/cli.mdx index 0f05d12bc..323a967a6 100644 --- a/docs/reference/cli.mdx +++ b/docs/reference/cli.mdx @@ -1677,6 +1677,55 @@ gaia cache clear --all --- +### Knowledge Command + +Web research via [Tavily](https://tavily.com), with SQLite result caching, a per-session +credit budget, and an automatic keyless DuckDuckGo fallback when the `mcp-tavily` +connector isn't configured. See the [Tavily connector](/connectors/tavily). + +```bash +gaia knowledge {search,extract,usage} [OPTIONS] +``` + +**Actions:** + +| Action | Description | +|--------|-------------| +| `search` | Web search for a query. Falls back to DuckDuckGo when Tavily isn't configured. | +| `extract` | Extract clean content from one or more URLs. Requires the Tavily connector. | +| `usage` | Print cached credit-usage totals (per operation + session total). | + +**Options:** + +| Flag | Type | Applies to | Description | +|------|------|------------|-------------| +| `--max-results` | integer | `search` | Maximum results to return (default: 5) | +| `--depth` | `basic` \| `advanced` | `search`, `extract` | Depth; `advanced` costs more credits (default: `basic`) | +| `--budget` | integer | `search`, `extract` | Credit cap for the session; omit for unlimited | +| `--no-block` | flag | `search`, `extract` | Warn instead of blocking when the budget cap is exceeded | + +**Examples:** + + +```bash Search the web +gaia knowledge search "AMD ROCm latest release" --max-results 5 +``` + +```bash Extract page content +gaia knowledge extract https://example.com/post +``` + +```bash Show credit usage +gaia knowledge usage +``` + + +`search` automatically degrades to DuckDuckGo when the `mcp-tavily` connector isn't +configured; `extract` requires Tavily and raises an actionable error otherwise. Configure +the connector with `gaia connectors configure mcp-tavily --set TAVILY_API_KEY=tvly-...`. + +--- + ### Kill Command Terminate processes running on specific ports. diff --git a/setup.py b/setup.py index bc147d506..20da09f74 100644 --- a/setup.py +++ b/setup.py @@ -127,6 +127,7 @@ "beautifulsoup4", "watchdog>=2.1.0", "pillow>=9.0.0", + "tavily-python>=0.5.0", ], extras_require={ "image": [ diff --git a/src/gaia/cli.py b/src/gaia/cli.py index f95fdda4d..7cad70365 100644 --- a/src/gaia/cli.py +++ b/src/gaia/cli.py @@ -1773,6 +1773,60 @@ def build_parser(): telegram_parser.set_defaults(action="telegram") + # Knowledge command — web research via the Tavily wrapper + knowledge_parser = subparsers.add_parser( + "knowledge", + help="Web research via Tavily (search|extract|usage), with caching and a credit budget", + ) + knowledge_subparsers = knowledge_parser.add_subparsers( + dest="knowledge_action", help="knowledge action to perform" + ) + + k_search = knowledge_subparsers.add_parser("search", help="Run a web search") + k_search.add_argument("query", help="Search query") + k_search.add_argument( + "--max-results", type=int, default=5, help="Max results (default: 5)" + ) + k_search.add_argument( + "--depth", + choices=("basic", "advanced"), + default="basic", + help="Search depth — advanced costs more credits (default: basic)", + ) + k_search.add_argument( + "--budget", + type=int, + default=None, + help="Credit cap for this session; omit for unlimited", + ) + k_search.add_argument( + "--no-block", + action="store_true", + help="Warn instead of blocking when the budget cap is exceeded", + ) + + k_extract = knowledge_subparsers.add_parser( + "extract", help="Extract clean content from one or more URLs (requires Tavily)" + ) + k_extract.add_argument("urls", nargs="+", help="One or more URLs to extract") + k_extract.add_argument( + "--depth", + choices=("basic", "advanced"), + default="basic", + help="Extract depth (default: basic)", + ) + k_extract.add_argument( + "--budget", type=int, default=None, help="Credit cap for this session" + ) + k_extract.add_argument( + "--no-block", + action="store_true", + help="Warn instead of blocking when the budget cap is exceeded", + ) + + knowledge_subparsers.add_parser("usage", help="Show cached credit-usage totals") + knowledge_parser.set_defaults(action="knowledge") + # Add model download command download_parser = subparsers.add_parser( "download", @@ -3812,6 +3866,11 @@ def main(): handle_cache_command(args) return + # Handle Knowledge command (Tavily web research) + if args.action == "knowledge": + handle_knowledge_command(args) + return + # Handle Memory command if args.action == "memory": handle_memory_command(args) @@ -4713,6 +4772,69 @@ def handle_blender_command(args): sys.exit(1) +def _print_knowledge_usage(client): + """Print a one-line credit-usage summary for a Tavily client.""" + usage = client.usage() + cap = usage["cap"] + cap_str = "unlimited" if cap is None else str(cap) + print(f"\n💳 Credits used: {usage['total_credits']} (cap: {cap_str})") + + +def handle_knowledge_command(args): + """Handle `gaia knowledge` — Tavily web research with caching + budget. + + Args: + args: Parsed command-line arguments + """ + action = getattr(args, "knowledge_action", None) + if action is None: + print("❌ Error: No knowledge action specified") + print("Available actions: search, extract, usage") + print("Run 'gaia knowledge --help' for more information") + return + + from gaia.web.tavily import ( + BudgetConfig, + TavilyBudgetExceeded, + TavilyClient, + TavilyConfigError, + ) + + budget = BudgetConfig( + cap=getattr(args, "budget", None), + block=not getattr(args, "no_block", False), + ) + client = TavilyClient(budget=budget) + try: + if action == "search": + result = client.search( + args.query, search_depth=args.depth, max_results=args.max_results + ) + source = result.get("source", "tavily") + print(f"\n=== Results for {args.query!r} (source: {source}) ===") + for i, r in enumerate(result.get("results", []), 1): + print(f"{i}. {r.get('title', '')}") + print(f" {r.get('url', '')}") + content = r.get("content") or r.get("snippet") or "" + if content: + print(f" {content[:200]}") + _print_knowledge_usage(client) + elif action == "extract": + result = client.extract(args.urls, extract_depth=args.depth) + print(json.dumps(result, indent=2)) + _print_knowledge_usage(client) + elif action == "usage": + _print_knowledge_usage(client) + except TavilyBudgetExceeded as e: + print(f"🛑 {e}") + sys.exit(1) + except TavilyConfigError as e: + print(f"❌ {e}") + sys.exit(1) + finally: + client.close() + + def handle_cache_command(args): """Handle the cache management command. diff --git a/src/gaia/connectors/catalog/mcp_servers.py b/src/gaia/connectors/catalog/mcp_servers.py index 0c42b66cf..6b6a5ba32 100644 --- a/src/gaia/connectors/catalog/mcp_servers.py +++ b/src/gaia/connectors/catalog/mcp_servers.py @@ -49,6 +49,30 @@ ), ) +_TAVILY = ConnectorSpec( + id="mcp-tavily", + display_name="Tavily", + icon="🛜", + category="dev-tools", + tier=1, + type="mcp_server", + description="Web search and content extraction for agents through the Tavily API.", + docs_url="https://amd-gaia.ai/docs/connectors/tavily", + mcp_command="npx", + mcp_args=("-y", "tavily-mcp@latest"), + mcp_env_keys=("TAVILY_API_KEY",), + config_schema=( + ConfigField( + key="TAVILY_API_KEY", + label="Tavily API Key", + kind="secret", + placeholder="tvly-…", + help_md="Get an API key from your [Tavily dashboard](https://app.tavily.com/).", + secret=True, + ), + ), +) + _MEMORY = ConnectorSpec( id="mcp-memory", display_name="Memory", @@ -79,6 +103,7 @@ _ALL_SPECS = ( _GITHUB, + _TAVILY, _MEMORY, _GIT, ) diff --git a/src/gaia/web/tavily.py b/src/gaia/web/tavily.py new file mode 100644 index 000000000..4ba331e93 --- /dev/null +++ b/src/gaia/web/tavily.py @@ -0,0 +1,624 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Tavily web-search wrapper with caching, credit accounting, and a fallback. + +Sits between GAIA and the ``tavily-python`` SDK so callers get one front door +with four behaviours layered on top of the raw API: + +- **SQLite cache** (``DatabaseMixin``): a normalized query/params hash maps to a + stored response with a TTL, so repeat queries don't re-spend credits. +- **Credit ledger**: every billable call records its credit cost, read from the + SDK response when present, otherwise estimated from Tavily's published pricing. +- **Budget gate**: warn once usage crosses a soft threshold, block once it + exceeds the hard cap. Blocking is the default; ``block=False`` downgrades the + cap to a warning. +- **DuckDuckGo fallback**: when the ``mcp-tavily`` connector isn't configured, + ``search`` degrades to the keyless DuckDuckGo path instead of failing. + +The API key is read from the connector's keyring entry, never passed around in +plaintext config. See ``gaia.connectors.catalog.mcp_servers._TAVILY``. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import re +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +from gaia.database.mixin import DatabaseMixin +from gaia.logger import get_logger +from gaia.web.client import WebClient + +log = get_logger(__name__) + +# The ``tavily-python`` SDK is an optional dependency: an unconfigured install +# still works via the DuckDuckGo fallback. Guard the import like web/client.py +# does for beautifulsoup4. +try: + from tavily import AsyncTavilyClient as _SdkAsyncTavilyClient + from tavily import TavilyClient as _SdkTavilyClient + + TAVILY_SDK_AVAILABLE = True +except ImportError: # pragma: no cover - exercised only when SDK absent + _SdkTavilyClient = None + _SdkAsyncTavilyClient = None + TAVILY_SDK_AVAILABLE = False + +# Matches ConnectorSpec.id in the catalog and the keyring env key it stores. +_CONNECTOR_ID = "mcp-tavily" +_API_KEY_ENV = "TAVILY_API_KEY" + +_DEFAULT_DB_PATH = Path.home() / ".gaia" / "tavily_cache.db" +_DEFAULT_TTL_SECONDS = 24 * 60 * 60 # results go stale; re-fetch after a day + +# Credit cost per (operation, depth), used when the SDK response carries no +# usage metadata. Mirrors Tavily's published pricing (basic search = 1 credit, +# advanced = 2). Crawl cost is page-dependent — the table value is only the +# pre-call estimate for the budget gate; the ledger records the response value +# when the SDK provides one. +_CREDIT_COST = { + ("search", "basic"): 1, + ("search", "advanced"): 2, + ("extract", "basic"): 1, + ("extract", "advanced"): 2, + ("crawl", "basic"): 1, + ("crawl", "advanced"): 2, +} + + +class TavilyError(Exception): + """Base class for Tavily wrapper errors.""" + + +class TavilyConfigError(TavilyError): + """The Tavily connector / SDK isn't usable for the requested operation.""" + + +class TavilyBudgetExceeded(TavilyError): + """A call was blocked because it would exceed the configured credit cap.""" + + +@dataclass +class BudgetConfig: + """Credit budget for a wrapper session. + + ``cap`` is the hard limit in credits; ``None`` means unlimited (the gate is + a no-op). ``warn_threshold`` is the fraction of the cap at which a warning + is logged. ``block`` decides what happens once the cap is exceeded: raise + ``TavilyBudgetExceeded`` (default) or merely warn and proceed. + """ + + cap: Optional[int] = None + warn_threshold: float = 0.8 + block: bool = True + + +def _normalize_query(query: str) -> str: + """Lowercase, trim, and collapse whitespace so trivial variants share a key.""" + return re.sub(r"\s+", " ", query.strip().lower()) + + +def _load_api_key() -> Optional[str]: + """Return the Tavily API key from the connector keyring, or ``None``. + + ``None`` means the ``mcp-tavily`` connector isn't usable — the caller should + fall back to DuckDuckGo. Imports are deferred so that merely importing this + module doesn't pull in ``keyring``; if ``keyring`` isn't installed at all + (it lives in the ``[ui]`` extras), the connector can't have been configured, + so we treat that as "not configured" rather than crashing. + """ + try: + from gaia.connectors.handler import get_credential_sync + from gaia.connectors.mcp_server import is_mcp_server_configured + except ImportError as e: + log.info("Connector subsystem unavailable (%s); using DuckDuckGo.", e) + return None + + if not is_mcp_server_configured(_CONNECTOR_ID): + return None + cred = get_credential_sync(_CONNECTOR_ID) + return cred["env"][_API_KEY_ENV] + + +async def _load_api_key_async() -> Optional[str]: + """Async counterpart to :func:`_load_api_key`. + + Awaits ``get_credential`` instead of the sync wrapper, so it is safe to call + from inside a running event loop — where ``get_credential_sync`` raises. The + async client uses this so its constructor never blocks the loop. + """ + try: + from gaia.connectors.handler import get_credential + from gaia.connectors.mcp_server import is_mcp_server_configured + except ImportError as e: + log.info("Connector subsystem unavailable (%s); using DuckDuckGo.", e) + return None + + if not is_mcp_server_configured(_CONNECTOR_ID): + return None + cred = await get_credential(_CONNECTOR_ID) + return cred["env"][_API_KEY_ENV] + + +class _TavilyBase(DatabaseMixin): + """Shared cache, ledger, budget, and fallback logic for both clients. + + Subclasses set ``_SDK_CLASS`` and implement the public ``search`` / ``extract`` + / ``crawl`` methods (sync vs. async); everything credit- and cache-related + lives here so the two clients can't drift apart. + """ + + _SDK_CLASS: Any = None + + def __init__( + self, + *, + db_path: Union[str, Path] = _DEFAULT_DB_PATH, + budget: Optional[BudgetConfig] = None, + cache_ttl: int = _DEFAULT_TTL_SECONDS, + sdk_client: Any = None, + api_key: Optional[str] = None, + web_client: Optional[WebClient] = None, + ) -> None: + self.init_db(str(db_path)) + self._ensure_schema() + self._budget = budget or BudgetConfig() + self._cache_ttl = cache_ttl + self._web_client = web_client + self._explicit_api_key = api_key + + # Resolution order: injected client (tests) → explicit key → connector + # keyring. No key at all = unconfigured = DuckDuckGo fallback mode. + if sdk_client is not None: + self._sdk = sdk_client + self._configured = True + self._key_resolved = True + return + + self._sdk = None + self._configured = False + self._key_resolved = False + self._resolve_key_eagerly() + + def _resolve_key_eagerly(self) -> None: + """Resolve the API key during construction. + + The sync client does this safely. ``AsyncTavilyClient`` overrides it to + defer resolution to first use, because synchronous resolution calls + ``get_credential_sync()``, which raises inside a running event loop. + """ + key = ( + self._explicit_api_key + if self._explicit_api_key is not None + else _load_api_key() + ) + self._apply_key(key) + + def _apply_key(self, key: Optional[str]) -> None: + """Wire up the SDK client from a resolved key, or enter fallback mode.""" + if key is None: + self._sdk = None + self._configured = False + elif self._SDK_CLASS is None: + raise TavilyConfigError( + "The Tavily connector is configured but the 'tavily-python' SDK " + "is not installed. Install it with `pip install tavily-python` " + "(or `uv pip install -e .`), then retry." + ) + else: + self._sdk = self._SDK_CLASS(api_key=key) + self._configured = True + self._key_resolved = True + + @property + def configured(self) -> bool: + """True when a real Tavily client is in use (vs. fallback mode).""" + return self._configured + + def close(self) -> None: + """Close the cache DB and any lazily-created web client.""" + self.close_db() + if self._web_client is not None: + self._web_client.close() + + # -- Schema -------------------------------------------------------------- + + def _ensure_schema(self) -> None: + self.execute(""" + CREATE TABLE IF NOT EXISTS tavily_cache ( + cache_key TEXT PRIMARY KEY, + operation TEXT NOT NULL, + response TEXT NOT NULL, + created_at REAL NOT NULL + ); + CREATE TABLE IF NOT EXISTS tavily_ledger ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + operation TEXT NOT NULL, + credits INTEGER NOT NULL, + created_at REAL NOT NULL + ); + """) + + # -- Cache --------------------------------------------------------------- + + @staticmethod + def _cache_key(operation: str, payload: Any, params: Dict[str, Any]) -> str: + """Hash of operation + normalized payload + result-affecting params. + + Params are part of the key because the same text with a different + search depth or result count is a genuinely different request. + """ + norm = { + "op": operation, + "payload": payload, + "params": {k: params[k] for k in sorted(params)}, + } + raw = json.dumps(norm, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(raw.encode("utf-8")).hexdigest() + + def _cache_get(self, key: str) -> Optional[Dict[str, Any]]: + row = self.query( + "SELECT response, created_at FROM tavily_cache WHERE cache_key = :k", + {"k": key}, + one=True, + ) + if row is None: + return None + if time.time() - row["created_at"] > self._cache_ttl: + # Stale: a miss, overwritten only on re-fetch of THIS key. Rows for + # queries never searched again are never pruned, so the cache file + # grows unbounded over time (disk creep only — results stay correct). + # Add a periodic ``DELETE ... WHERE created_at < cutoff`` sweep if it bites. + return None + return json.loads(row["response"]) + + def _record( + self, key: str, operation: str, response: Dict[str, Any], credits: int + ) -> None: + """Atomically cache the response and append the credit ledger entry.""" + now = time.time() + with self.transaction(): + self.delete("tavily_cache", "cache_key = :k", {"k": key}) + self.insert( + "tavily_cache", + { + "cache_key": key, + "operation": operation, + "response": json.dumps(response), + "created_at": now, + }, + ) + self.insert( + "tavily_ledger", + {"operation": operation, "credits": credits, "created_at": now}, + ) + + # -- Credits / budget ---------------------------------------------------- + + def _credits_used(self) -> int: + row = self.query( + "SELECT COALESCE(SUM(credits), 0) AS total FROM tavily_ledger", one=True + ) + return int(row["total"]) if row else 0 + + @staticmethod + def _estimate_credits(operation: str, depth: str) -> int: + """Pre-call cost estimate (no response yet) from the pricing table.""" + return _CREDIT_COST.get( + (operation, depth), _CREDIT_COST.get((operation, "basic"), 1) + ) + + @classmethod + def _actual_credits( + cls, operation: str, depth: str, response: Dict[str, Any] + ) -> int: + """Credits to record: prefer SDK-reported usage, else the estimate.""" + if isinstance(response, dict): + usage = response.get("usage") + if isinstance(usage, dict) and "credits" in usage: + return int(usage["credits"]) + if "credits" in response: + return int(response["credits"]) + return cls._estimate_credits(operation, depth) + + def _check_budget(self, operation: str, depth: str) -> None: + """Warn near the cap, block (or warn) once a call would exceed it.""" + cap = self._budget.cap + if cap is None: + return + used = self._credits_used() + est = self._estimate_credits(operation, depth) + projected = used + est + + if projected > cap: + msg = ( + f"Tavily budget exceeded: {used} credits used + ~{est} for this " + f"{operation} would reach {projected}, but the cap is {cap}." + ) + if self._budget.block: + raise TavilyBudgetExceeded( + msg + " Raise the cap, or pass block=False / --no-block to " + "warn instead of blocking." + ) + log.warning("%s Proceeding (budget is in warn-only mode).", msg) + elif projected >= self._budget.warn_threshold * cap: + log.warning( + "Tavily budget warning: %d/%d credits used (~%d more for this %s), " + "past the %.0f%% threshold.", + used, + cap, + est, + operation, + self._budget.warn_threshold * 100, + ) + + def usage(self) -> Dict[str, Any]: + """Return a credit-usage summary for this cache DB.""" + used = self._credits_used() + rows = self.query( + "SELECT operation, COUNT(*) AS calls, COALESCE(SUM(credits), 0) AS credits " + "FROM tavily_ledger GROUP BY operation" + ) + cap = self._budget.cap + return { + "total_credits": used, + "cap": cap, + "remaining": (cap - used) if cap is not None else None, + "by_operation": { + r["operation"]: {"calls": r["calls"], "credits": r["credits"]} + for r in rows + }, + } + + # -- Fallback helpers ---------------------------------------------------- + + def _get_web_client(self) -> WebClient: + if self._web_client is None: + self._web_client = WebClient() + return self._web_client + + def _shape_ddg(self, query: str, results: List[Dict[str, str]]) -> Dict[str, Any]: + """Render DuckDuckGo results in the same shape as a Tavily response.""" + return { + "query": query, + "results": [ + { + "title": r.get("title", ""), + "url": r.get("url", ""), + "content": r.get("snippet", ""), + } + for r in results + ], + "source": "duckduckgo", + } + + +class TavilyClient(_TavilyBase): + """Synchronous Tavily wrapper. + + Example: + client = TavilyClient(budget=BudgetConfig(cap=100)) + result = client.search("AMD ROCm latest version") + """ + + _SDK_CLASS = _SdkTavilyClient + + def search( + self, + query: str, + *, + search_depth: str = "basic", + max_results: int = 5, + **kwargs: Any, + ) -> Dict[str, Any]: + if not self._configured: + log.info( + "Tavily connector not configured; using DuckDuckGo for query=%r", + query, + ) + results = self._get_web_client().search_duckduckgo(query, max_results) + return self._shape_ddg(query, results) + + params = {"search_depth": search_depth, "max_results": max_results, **kwargs} + key = self._cache_key("search", _normalize_query(query), params) + cached = self._cache_get(key) + if cached is not None: + return cached + + self._check_budget("search", search_depth) + response = self._sdk.search( + query=query, search_depth=search_depth, max_results=max_results, **kwargs + ) + self._record( + key, + "search", + response, + self._actual_credits("search", search_depth, response), + ) + return response + + def extract( + self, + urls: Union[str, Sequence[str]], + *, + extract_depth: str = "basic", + **kwargs: Any, + ) -> Dict[str, Any]: + urls = [urls] if isinstance(urls, str) else list(urls) + if not self._configured: + raise TavilyConfigError( + "extract requires the Tavily connector (the DuckDuckGo fallback " + "only covers search). Configure it with " + "`gaia connectors configure mcp-tavily --set TAVILY_API_KEY=tvly-...`." + ) + + params = {"extract_depth": extract_depth, **kwargs} + key = self._cache_key("extract", sorted(urls), params) + cached = self._cache_get(key) + if cached is not None: + return cached + + self._check_budget("extract", extract_depth) + response = self._sdk.extract(urls=urls, extract_depth=extract_depth, **kwargs) + self._record( + key, + "extract", + response, + self._actual_credits("extract", extract_depth, response), + ) + return response + + def crawl( + self, url: str, *, extract_depth: str = "basic", **kwargs: Any + ) -> Dict[str, Any]: + if not self._configured: + raise TavilyConfigError( + "crawl requires the Tavily connector. Configure it with " + "`gaia connectors configure mcp-tavily --set TAVILY_API_KEY=tvly-...`." + ) + + params = {"extract_depth": extract_depth, **kwargs} + key = self._cache_key("crawl", url, params) + cached = self._cache_get(key) + if cached is not None: + return cached + + self._check_budget("crawl", extract_depth) + response = self._sdk.crawl(url=url, extract_depth=extract_depth, **kwargs) + self._record( + key, + "crawl", + response, + self._actual_credits("crawl", extract_depth, response), + ) + return response + + def __enter__(self) -> "TavilyClient": + return self + + def __exit__(self, *_: object) -> None: + self.close() + + +class AsyncTavilyClient(_TavilyBase): + """Asynchronous Tavily wrapper for concurrent multi-query research. + + Mirrors :class:`TavilyClient` but awaits the SDK; the synchronous + DuckDuckGo fallback is offloaded with ``asyncio.to_thread`` so it doesn't + block the event loop. + """ + + _SDK_CLASS = _SdkAsyncTavilyClient + + def _resolve_key_eagerly(self) -> None: + # Defer: synchronous key resolution (get_credential_sync) raises inside a + # running event loop. Resolve lazily in _ensure_resolved() via await. + self._resolve_lock = asyncio.Lock() + + async def _ensure_resolved(self) -> None: + """Resolve the API key on first use (idempotent, concurrency-safe).""" + if self._key_resolved: + return + async with self._resolve_lock: + if self._key_resolved: + return + key = ( + self._explicit_api_key + if self._explicit_api_key is not None + else await _load_api_key_async() + ) + self._apply_key(key) + + async def search( + self, + query: str, + *, + search_depth: str = "basic", + max_results: int = 5, + **kwargs: Any, + ) -> Dict[str, Any]: + await self._ensure_resolved() + if not self._configured: + log.info( + "Tavily connector not configured; using DuckDuckGo for query=%r", + query, + ) + results = await asyncio.to_thread( + self._get_web_client().search_duckduckgo, query, max_results + ) + return self._shape_ddg(query, results) + + params = {"search_depth": search_depth, "max_results": max_results, **kwargs} + key = self._cache_key("search", _normalize_query(query), params) + cached = self._cache_get(key) + if cached is not None: + return cached + + self._check_budget("search", search_depth) + response = await self._sdk.search( + query=query, search_depth=search_depth, max_results=max_results, **kwargs + ) + self._record( + key, + "search", + response, + self._actual_credits("search", search_depth, response), + ) + return response + + async def extract( + self, + urls: Union[str, Sequence[str]], + *, + extract_depth: str = "basic", + **kwargs: Any, + ) -> Dict[str, Any]: + await self._ensure_resolved() + urls = [urls] if isinstance(urls, str) else list(urls) + if not self._configured: + raise TavilyConfigError( + "extract requires the Tavily connector (the DuckDuckGo fallback " + "only covers search). Configure it with " + "`gaia connectors configure mcp-tavily --set TAVILY_API_KEY=tvly-...`." + ) + + params = {"extract_depth": extract_depth, **kwargs} + key = self._cache_key("extract", sorted(urls), params) + cached = self._cache_get(key) + if cached is not None: + return cached + + self._check_budget("extract", extract_depth) + response = await self._sdk.extract( + urls=urls, extract_depth=extract_depth, **kwargs + ) + self._record( + key, + "extract", + response, + self._actual_credits("extract", extract_depth, response), + ) + return response + + async def aclose(self) -> None: + """Close the async SDK client, then the cache DB and web client. + + The real ``AsyncTavilyClient`` wraps an ``httpx.AsyncClient`` that must + be awaited shut; without this, long-lived async callers leak connections. + """ + if self._sdk is not None and hasattr(self._sdk, "aclose"): + await self._sdk.aclose() + self.close() + + async def __aenter__(self) -> "AsyncTavilyClient": + await self._ensure_resolved() + return self + + async def __aexit__(self, *_: object) -> None: + await self.aclose() diff --git a/tests/unit/connectors/test_catalog_ledger.py b/tests/unit/connectors/test_catalog_ledger.py index ab4af04e6..265adf70e 100644 --- a/tests/unit/connectors/test_catalog_ledger.py +++ b/tests/unit/connectors/test_catalog_ledger.py @@ -4,15 +4,17 @@ Catalog ledger test (#976, updated #1021). Replaces the previous "legacy ⊆ new" parity check. The MCP catalog has been -intentionally reduced from 22 deployed entries down to 3 entries (#1021 -removed mcp-filesystem and mcp-fetch because no built-in agent consumes them -through the connectors framework; custom agents supply their own -mcp_servers.json instead). This test asserts both ends of that ledger: - - * KEPT_IDS — exactly these 3 ids must remain in - ``connectors.catalog.mcp_servers``. For each, field-by-field equivalence - against the legacy ``mcp.py:_CATALOG`` row of the same name is asserted as - a guard against silent drift during the migration. +intentionally reduced from 22 deployed entries down to 3 carried-over entries +(#1021 removed mcp-filesystem and mcp-fetch because no built-in agent consumes +them through the connectors framework; custom agents supply their own +mcp_servers.json instead), plus mcp-tavily added net-new afterwards. This test +asserts both ends of that ledger: + + * KEPT_IDS — exactly these ids must remain in + ``connectors.catalog.mcp_servers`` (the 3 carried over from the legacy + catalog, plus mcp-tavily). For each carried-over id, field-by-field + equivalence against the legacy ``mcp.py:_CATALOG`` row of the same name is + asserted as a guard against silent drift during the migration. * DELETED_IDS — these 19 ids must NOT be present. Regression guard against accidentally re-introducing untested catalog tiles. @@ -33,6 +35,10 @@ "mcp-github", "mcp-memory", "mcp-git", + # Net-new (not part of the original 22): a real tavily-mcp@latest + # server. The same keyring TAVILY_API_KEY is also read by the + # gaia.web.tavily Python wrapper. + "mcp-tavily", } ) diff --git a/tests/unit/test_tavily_wrapper.py b/tests/unit/test_tavily_wrapper.py new file mode 100644 index 000000000..2ef374024 --- /dev/null +++ b/tests/unit/test_tavily_wrapper.py @@ -0,0 +1,323 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Unit tests for the Tavily wrapper (gaia.web.tavily). + +The ``tavily-python`` SDK is mocked end-to-end via an injected fake client, so +these tests never touch the network or require the SDK to be installed. +""" + +import logging + +import pytest + +import gaia.web.tavily as tav +from gaia.web.tavily import ( + AsyncTavilyClient, + BudgetConfig, + TavilyBudgetExceeded, + TavilyClient, + TavilyConfigError, +) + +# --- Test doubles ----------------------------------------------------------- + + +class FakeSyncSDK: + """Stand-in for tavily.TavilyClient with call counters.""" + + def __init__(self, response=None): + self.response = response or {"results": [{"title": "t", "url": "u"}]} + self.search_calls = 0 + self.extract_calls = 0 + + def search(self, **kwargs): + self.search_calls += 1 + return {**self.response, "query": kwargs.get("query")} + + def extract(self, **kwargs): + self.extract_calls += 1 + return {**self.response, "urls": kwargs.get("urls")} + + +class FakeAsyncSDK: + """Stand-in for tavily.AsyncTavilyClient.""" + + def __init__(self, response=None): + self.response = response or {"results": [{"title": "t", "url": "u"}]} + self.search_calls = 0 + self.closed = False + + async def search(self, **kwargs): + self.search_calls += 1 + return {**self.response, "query": kwargs.get("query")} + + async def aclose(self): + self.closed = True + + +class FakeAsyncSDKClass: + """Stand-in for tavily.AsyncTavilyClient as the ``_SDK_CLASS``. + + Unlike ``FakeAsyncSDK`` (injected ready-made), this is instantiated by the + wrapper itself via ``_SDK_CLASS(api_key=...)``, so it must accept the key. + """ + + def __init__(self, api_key=None): + self.api_key = api_key + self.search_calls = 0 + + async def search(self, **kwargs): + self.search_calls += 1 + return {"results": [], "query": kwargs.get("query")} + + async def aclose(self): + pass + + +class FakeWeb: + """Stand-in for WebClient providing the DuckDuckGo fallback.""" + + def __init__(self): + self.calls = 0 + + def search_duckduckgo(self, query, num_results=5): + self.calls += 1 + return [{"title": "DDG", "url": "http://x", "snippet": "snip"}] + + def close(self): + pass + + +class Clock: + """Controllable replacement for the module's ``time`` reference.""" + + def __init__(self, t=1000.0): + self.t = t + + def time(self): + return self.t + + +@pytest.fixture +def fake_sdk(): + return FakeSyncSDK() + + +def make_client(fake_sdk, **kwargs): + """Configured sync client backed by an in-memory DB and the fake SDK.""" + return TavilyClient(db_path=":memory:", sdk_client=fake_sdk, **kwargs) + + +# --- Caching ---------------------------------------------------------------- + + +def test_search_returns_sdk_response(fake_sdk): + client = make_client(fake_sdk) + result = client.search("hello") + assert result["query"] == "hello" + assert fake_sdk.search_calls == 1 + client.close() + + +def test_cache_hit_skips_second_sdk_call(fake_sdk): + client = make_client(fake_sdk) + client.search("AMD ROCm latest version") + # Same query, only case/whitespace differ → normalized to the same key. + client.search("amd rocm LATEST version ") + assert fake_sdk.search_calls == 1 + client.close() + + +def test_cache_key_includes_params(fake_sdk): + client = make_client(fake_sdk) + client.search("q", max_results=5) + client.search("q", max_results=10) # different param → different request + assert fake_sdk.search_calls == 2 + client.close() + + +def test_cache_expires_after_ttl(fake_sdk, monkeypatch): + clock = Clock(1000.0) + monkeypatch.setattr(tav, "time", clock) + client = make_client(fake_sdk, cache_ttl=60) + + client.search("q") + assert fake_sdk.search_calls == 1 + + clock.t = 1000.0 + 61 # advance past the TTL + client.search("q") + assert fake_sdk.search_calls == 2 # stale → re-fetched + client.close() + + +# --- Credit ledger ---------------------------------------------------------- + + +def test_ledger_tracks_credits_by_depth(fake_sdk): + client = make_client(fake_sdk) + client.search("a", search_depth="advanced") # 2 credits + client.search("b", search_depth="basic") # 1 credit + usage = client.usage() + assert usage["total_credits"] == 3 + assert usage["by_operation"]["search"]["calls"] == 2 + client.close() + + +def test_credits_read_from_response_usage(): + sdk = FakeSyncSDK(response={"results": [], "usage": {"credits": 7}}) + client = make_client(sdk) + client.search("q") + assert client.usage()["total_credits"] == 7 + client.close() + + +def test_cached_calls_are_not_remetered(fake_sdk): + client = make_client(fake_sdk) + client.search("q") + client.search("q") # cache hit + assert client.usage()["total_credits"] == 1 # charged once, not twice + client.close() + + +# --- Budget gate ------------------------------------------------------------ + + +def test_budget_warns_near_threshold(fake_sdk, caplog): + client = make_client(fake_sdk, budget=BudgetConfig(cap=10, warn_threshold=0.8)) + client.insert( + "tavily_ledger", {"operation": "search", "credits": 8, "created_at": 0} + ) + with caplog.at_level(logging.WARNING, logger="gaia.web.tavily"): + client.search("q") # projected 9/10 → past 80% + assert "budget warning" in caplog.text.lower() + assert fake_sdk.search_calls == 1 # warned but proceeded + client.close() + + +def test_budget_blocks_over_cap(fake_sdk): + client = make_client(fake_sdk, budget=BudgetConfig(cap=0)) + with pytest.raises(TavilyBudgetExceeded): + client.search("q") + assert fake_sdk.search_calls == 0 # blocked before spending + assert client.usage()["total_credits"] == 0 + client.close() + + +def test_budget_warn_only_mode_does_not_block(fake_sdk, caplog): + client = make_client(fake_sdk, budget=BudgetConfig(cap=0, block=False)) + with caplog.at_level(logging.WARNING, logger="gaia.web.tavily"): + result = client.search("q") # over cap but warn-only + assert result["query"] == "q" + assert fake_sdk.search_calls == 1 + assert "warn-only" in caplog.text.lower() + client.close() + + +def test_unlimited_budget_never_blocks(fake_sdk): + client = make_client(fake_sdk, budget=BudgetConfig(cap=None)) + for i in range(10): + client.search(f"q{i}") + assert fake_sdk.search_calls == 10 + client.close() + + +# --- DuckDuckGo fallback ---------------------------------------------------- + + +def test_unconfigured_search_falls_back_to_ddg(monkeypatch): + monkeypatch.setattr(tav, "_load_api_key", lambda: None) + web = FakeWeb() + client = TavilyClient(db_path=":memory:", web_client=web) + assert client.configured is False + + result = client.search("anything") + assert result["source"] == "duckduckgo" + assert result["results"][0]["url"] == "http://x" + assert web.calls == 1 + assert client.usage()["total_credits"] == 0 # DDG is free → no ledger entry + client.close() + + +def test_unconfigured_extract_raises(monkeypatch): + monkeypatch.setattr(tav, "_load_api_key", lambda: None) + client = TavilyClient(db_path=":memory:", web_client=FakeWeb()) + with pytest.raises(TavilyConfigError): + client.extract("http://example.com") + client.close() + + +def test_configured_but_sdk_missing_raises(monkeypatch): + # Connector configured (key present) but tavily-python not installed. + monkeypatch.setattr(TavilyClient, "_SDK_CLASS", None) + with pytest.raises(TavilyConfigError, match="tavily-python"): + TavilyClient(db_path=":memory:", api_key="tvly-xxx") + + +# --- Async client ----------------------------------------------------------- + + +async def test_async_search_caches(): + sdk = FakeAsyncSDK() + client = AsyncTavilyClient(db_path=":memory:", sdk_client=sdk) + await client.search("q") + await client.search("q") # cache hit + assert sdk.search_calls == 1 + client.close() + + +async def test_async_unconfigured_falls_back_to_ddg(monkeypatch): + async def no_key(): + return None + + monkeypatch.setattr(tav, "_load_api_key_async", no_key) + web = FakeWeb() + client = AsyncTavilyClient(db_path=":memory:", web_client=web) + result = await client.search("anything") + assert result["source"] == "duckduckgo" + assert web.calls == 1 + client.close() + + +async def test_async_configured_construction_in_loop_does_not_raise(monkeypatch): + """Regression: constructing the async client inside a running event loop with + a configured connector must not raise. + + Pre-fix, ``__init__`` resolved the key synchronously via + ``get_credential_sync()``, which raises ``RuntimeError`` inside a running + loop. Resolution is now deferred to the async path on first use. + """ + import gaia.connectors.handler as handler_mod + import gaia.connectors.mcp_server as mcp_mod + + monkeypatch.setattr(mcp_mod, "is_mcp_server_configured", lambda _cid: True) + + async def fake_get_credential(_connector_id, **_kwargs): + return {"env": {"TAVILY_API_KEY": "tvly-test"}} + + monkeypatch.setattr(handler_mod, "get_credential", fake_get_credential) + monkeypatch.setattr(AsyncTavilyClient, "_SDK_CLASS", FakeAsyncSDKClass) + + # Construction must not raise inside the loop... + client = AsyncTavilyClient(db_path=":memory:") + # ...and the key resolves on first use via the async path. + result = await client.search("q") + assert client.configured is True + assert result["query"] == "q" + assert isinstance(client._sdk, FakeAsyncSDKClass) + assert client._sdk.api_key == "tvly-test" + await client.aclose() + + +def test_sync_context_manager_closes(fake_sdk): + with make_client(fake_sdk) as client: + client.search("q") + assert client.db_ready is False # __exit__ closed the cache DB + + +async def test_async_context_manager_awaits_sdk_aclose(): + sdk = FakeAsyncSDK() + async with AsyncTavilyClient(db_path=":memory:", sdk_client=sdk) as client: + await client.search("q") + assert sdk.closed is True # __aexit__ -> aclose() awaited the SDK + assert client.db_ready is False