Skip to content
Merged
82 changes: 44 additions & 38 deletions src/gaia/agents/base/memory_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _safe_json_loads(value) -> object:
class MemoryStore:
"""Pure SQLite storage for agent memory. No agent dependencies."""

def __init__(self, db_path: Path = None):
def __init__(self, db_path: Path | None = None):
"""Open/create DB at db_path. Default: ~/.gaia/memory.db

Uses WAL mode. Thread-safe via threading.Lock.
Expand Down Expand Up @@ -503,8 +503,8 @@ def store_turn(

def get_history(
self,
session_id: str = None,
context: str = None,
session_id: str | None = None,
context: str | None = None,
limit: int = 20,
) -> List[Dict]:
"""Retrieve recent conversation turns, ordered oldest-first."""
Expand Down Expand Up @@ -554,7 +554,7 @@ def get_history(
def search_conversations(
self,
query: str,
context: str = None,
context: str | None = None,
limit: int = 10,
) -> List[Dict]:
"""FTS5 keyword search across conversation content.
Expand Down Expand Up @@ -623,7 +623,7 @@ def _fts5_search_conversations_locked(
def get_recent_conversations(
self,
days: int = 7,
context: str = None,
context: str | None = None,
limit: int = 50,
) -> List[Dict]:
"""Get conversations from the last N days (timestamp-based).
Expand Down Expand Up @@ -674,14 +674,14 @@ def store(
self,
category: str,
content: str,
domain: str = None,
domain: str | None = None,
metadata: dict = None,
confidence: float = 0.5,
due_at: str = None,
due_at: str | None = None,
source: str = "tool",
context: str = "global",
sensitive: bool = False,
entity: str = None,
entity: str | None = None,
) -> str:
"""Store a knowledge entry with deduplication.

Expand Down Expand Up @@ -823,7 +823,7 @@ def store(
return knowledge_id

def _find_similar_locked(
self, content: str, category: str, context: str, entity: str = None
self, content: str, category: str, context: str, entity: str | None = None
) -> Optional[str]:
"""Find existing knowledge with >80% word overlap in same category+context+entity.

Expand Down Expand Up @@ -921,13 +921,13 @@ def _update_knowledge_fts_locked(self, knowledge_id: str):
def search(
self,
query: str,
category: str = None,
context: str = None,
entity: str = None,
category: str | None = None,
context: str | None = None,
entity: str | None = None,
include_sensitive: bool = False,
top_k: int = 5,
time_from: str = None,
time_to: str = None,
time_from: str | None = None,
time_to: str | None = None,
) -> List[Dict]:
"""FTS5 search. AND semantics, OR fallback. BM25 ranking.

Expand Down Expand Up @@ -1066,7 +1066,11 @@ def _fts5_search_knowledge_locked(
# ==================================================================

def get_by_category(
self, category: str, context: str = None, domain: str = None, limit: int = 10
self,
category: str,
context: str | None = None,
domain: str | None = None,
limit: int = 10,
) -> List[Dict]:
"""Get active knowledge entries by category, optionally filtered by context and domain."""
conditions = ["category = ?", "superseded_by IS NULL"]
Expand Down Expand Up @@ -1160,7 +1164,7 @@ def get_upcoming(
self,
within_days: int = 7,
include_overdue: bool = True,
context: str = None,
context: str | None = None,
limit: int = 10,
) -> List[Dict]:
"""Get time-sensitive items due within N days (or overdue).
Expand Down Expand Up @@ -1216,16 +1220,16 @@ def get_upcoming(
def update(
self,
knowledge_id: str,
content: str = None,
category: str = None,
domain: str = None,
content: str | None = None,
category: str | None = None,
domain: str | None = None,
metadata: dict = None,
context: str = None,
sensitive: bool = None,
entity: str = None,
due_at: str = None,
reminded_at: str = None,
superseded_by: str = None,
context: str | None = None,
sensitive: bool | None = None,
entity: str | None = None,
due_at: str | None = None,
reminded_at: str | None = None,
superseded_by: str | None = None,
) -> bool:
"""Update an existing knowledge entry. Only provided fields are changed.

Expand Down Expand Up @@ -1389,13 +1393,13 @@ def store_embedding(self, knowledge_id: str, embedding: bytes) -> bool:

def get_items_with_embeddings(
self,
category: str = None,
context: str = None,
entity: str = None,
category: str | None = None,
context: str | None = None,
entity: str | None = None,
include_sensitive: bool = False,
top_k: int = 100,
time_from: str = None,
time_to: str = None,
time_from: str | None = None,
time_to: str | None = None,
) -> List[Dict]:
"""Return active knowledge items that have stored embeddings.

Expand Down Expand Up @@ -1516,7 +1520,7 @@ def backfill_embeddings(
return {"backfilled": backfilled, "total_without": total_without}

def get_items_for_reconciliation(
self, context: str = None, limit: int = 100
self, context: str | None = None, limit: int = 100
) -> List[Dict]:
"""Get active knowledge items with embeddings for pairwise comparison.

Expand Down Expand Up @@ -1620,8 +1624,8 @@ def log_tool_call(
args: dict,
result_summary: str,
success: bool,
error: str = None,
duration_ms: int = None,
error: str | None = None,
duration_ms: int | None = None,
) -> None:
"""Log a tool call to tool_history."""
now = _now_iso()
Expand Down Expand Up @@ -1661,7 +1665,9 @@ def log_tool_call(
self._conn.rollback()
raise

def get_tool_errors(self, tool_name: str = None, limit: int = 10) -> List[Dict]:
def get_tool_errors(
self, tool_name: str | None = None, limit: int = 10
) -> List[Dict]:
"""Get recent failed tool calls, newest first."""
if tool_name is not None:
sql = """
Expand Down Expand Up @@ -1890,10 +1896,10 @@ def get_stats(self) -> Dict:
def get_all_knowledge(
self,
category: Optional[Union[str, List[str]]] = None,
context: str = None,
entity: str = None,
sensitive: bool = None,
search: str = None,
context: str | None = None,
entity: str | None = None,
sensitive: bool | None = None,
search: str | None = None,
sort_by: str = "updated_at",
order: str = "desc",
offset: int = 0,
Expand Down
10 changes: 5 additions & 5 deletions src/gaia/agents/base/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
logger = logging.getLogger(__name__)

# Tool registry to store registered tools
_TOOL_REGISTRY = {}
_TOOL_REGISTRY: dict[str, dict] = {}


def tool(
func: Callable = None,
func: Callable | None = None,
*,
atomic: bool = False,
display_label: str | None = None,
Expand Down Expand Up @@ -109,7 +109,7 @@ def get_tool_display_name(tool_name: str) -> str:
tool = _TOOL_REGISTRY.get(tool_name)
if not tool:
return tool_name
return tool.get("display_name", tool_name)
return tool.get("display_name", tool_name) # type: ignore[no-any-return]


def get_tool_display_label(tool_name: str) -> str:
Expand All @@ -120,8 +120,8 @@ def get_tool_display_label(tool_name: str) -> str:
"""
tool = _TOOL_REGISTRY.get(tool_name)
if not tool:
return None
return tool.get("display_label")
return None # type: ignore[return-value]
return tool.get("display_label") # type: ignore[return-value]


def get_tool_metadata(tool_name: str):
Expand Down
2 changes: 1 addition & 1 deletion src/gaia/agents/chat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def cleanup_old_sessions(

except Exception as e:
logger.error(f"Error during session cleanup: {e}")
return {"error": str(e), "total_deleted": 0, "remaining_sessions": 0}
return {"error": str(e), "total_deleted": 0, "remaining_sessions": 0} # type: ignore[dict-item]

def clear_path_permissions(self):
"""Clear all cached path permissions."""
Expand Down
6 changes: 3 additions & 3 deletions src/gaia/agents/code/orchestration/checklist_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,15 +476,15 @@ def _extract_response_text(self, response: Any) -> str:

# Handle response objects with text attribute
if hasattr(response, "text"):
return response.text
return response.text # type: ignore[no-any-return]

# Handle response objects with content attribute
if hasattr(response, "content"):
return response.content
return response.content # type: ignore[no-any-return]

# Handle dict-like responses
if isinstance(response, dict):
return response.get("text", response.get("content", str(response)))
return response.get("text", response.get("content", str(response))) # type: ignore[return-value]

return str(response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from pathlib import Path
from typing import List
from typing import Any, Dict, List, cast

from ..steps.base import UserContext
from ..workflows.base import WorkflowPhase
Expand Down Expand Up @@ -103,4 +103,4 @@ def get_validation_config(self, phase_name: str) -> dict:
"test_command": "pytest",
},
}
return configs.get(phase_name, {})
return cast(Dict[str, Any], configs.get(phase_name, {}))
8 changes: 4 additions & 4 deletions src/gaia/agents/code/tools/validation_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _validate_requirements(self, req_file: Path, fix: bool) -> Dict[str, Any]:
Returns:
Dictionary with validation results
"""
return self.requirements_validator.validate(req_file, fix)
return self.requirements_validator.validate(req_file, fix) # type: ignore[no-any-return, attr-defined]

def _validate_python_files(
self, py_files: List[Path], _fix: bool
Expand All @@ -84,7 +84,7 @@ def _validate_python_files(
content = py_file.read_text()

# Validate syntax
syntax_result = self.syntax_validator.validate_dict(content)
syntax_result = self.syntax_validator.validate_dict(content) # type: ignore[attr-defined]
if not syntax_result["is_valid"]:
errors.extend(
[f"{py_file}: {err}" for err in syntax_result.get("errors", [])]
Expand Down Expand Up @@ -114,7 +114,7 @@ def _check_antipatterns(self, _file_path: Path, content: str) -> Dict[str, Any]:
Returns:
Dictionary with antipattern check results
"""
return self.antipattern_checker.check_dict(content)
return self.antipattern_checker.check_dict(content) # type: ignore[no-any-return, attr-defined]

def _validate_python_syntax(self, code: str) -> Dict[str, Any]:
"""Validate Python code syntax (delegates to validator).
Expand All @@ -125,7 +125,7 @@ def _validate_python_syntax(self, code: str) -> Dict[str, Any]:
Returns:
Dictionary with validation results
"""
return self.syntax_validator.validate_dict(code)
return self.syntax_validator.validate_dict(code) # type: ignore[no-any-return, attr-defined]

def _validate_javascript_files(
self, js_files: List[Path], _fix: bool
Expand Down
3 changes: 1 addition & 2 deletions src/gaia/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ def stop_server(port: int = 8080) -> None:
check=False,
)

pids = result.stdout.strip().split("\n")
pids = [pid for pid in pids if pid] # Filter empty strings
pids = {pid for pid in result.stdout.strip().split("\n") if pid}

if pids:
for pid in pids:
Expand Down
6 changes: 3 additions & 3 deletions src/gaia/connectors/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _get():
AuthRequiredError.Reason.REAUTH_REQUIRED, provider=provider
)

return blob
return blob # type: ignore[no-any-return]


def peek_connection(
Expand Down Expand Up @@ -271,7 +271,7 @@ def _get():
if raw is None:
return None
try:
return json.loads(raw)
return json.loads(raw) # type: ignore[no-any-return]
except json.JSONDecodeError:
# Corrupt blob — caller treats as "not configured" without
# rewriting state. ``load_connection`` (auth path) still clears
Expand Down Expand Up @@ -342,7 +342,7 @@ def _get():
if raw is None:
return None
try:
return json.loads(raw)
return json.loads(raw) # type: ignore[no-any-return]
except json.JSONDecodeError:
return None

Expand Down
7 changes: 6 additions & 1 deletion src/gaia/database/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def query(
)
"""
self._require_db()
assert self._db is not None
cursor = self._db.execute(sql, params or {})
rows = [dict(row) for row in cursor.fetchall()]
if one:
Expand All @@ -140,13 +141,14 @@ def insert(self, table: str, data: Dict[str, Any]) -> int:
})
"""
self._require_db()
assert self._db is not None
cols = ", ".join(data.keys())
placeholders = ", ".join(f":{k}" for k in data.keys())
sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders})"
cursor = self._db.execute(sql, data)
if not self._in_tx:
self._db.commit()
return cursor.lastrowid
return cursor.lastrowid or 0

def update(
self,
Expand Down Expand Up @@ -176,6 +178,7 @@ def update(
)
"""
self._require_db()
assert self._db is not None
# Prefix data params with __set_ to avoid collision with where params
set_clause = ", ".join(f"{k} = :__set_{k}" for k in data.keys())
merged_params = {f"__set_{k}": v for k, v in data.items()}
Expand All @@ -202,6 +205,7 @@ def delete(self, table: str, where: str, params: Dict[str, Any]) -> int:
count = self.delete("sessions", "expires_at < :now", {"now": now})
"""
self._require_db()
assert self._db is not None
sql = f"DELETE FROM {table} WHERE {where}"
cursor = self._db.execute(sql, params)
if not self._in_tx:
Expand Down Expand Up @@ -261,6 +265,7 @@ def execute(self, sql: str) -> None:
''')
"""
self._require_db()
assert self._db is not None
if self._in_tx:
raise RuntimeError(
"execute() cannot be called inside a transaction() block. "
Expand Down
Loading
Loading