Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 44 additions & 38 deletions src/gaia/agents/base/memory_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _safe_json_loads(value) -> object:
class MemoryStore:
"""Pure SQLite storage for agent memory. No agent dependencies."""

def __init__(self, db_path: Path = None):
def __init__(self, db_path: Path | None = None):
"""Open/create DB at db_path. Default: ~/.gaia/memory.db

Uses WAL mode. Thread-safe via threading.Lock.
Expand Down Expand Up @@ -503,8 +503,8 @@ def store_turn(

def get_history(
self,
session_id: str = None,
context: str = None,
session_id: str | None = None,
context: str | None = None,
limit: int = 20,
) -> List[Dict]:
"""Retrieve recent conversation turns, ordered oldest-first."""
Expand Down Expand Up @@ -554,7 +554,7 @@ def get_history(
def search_conversations(
self,
query: str,
context: str = None,
context: str | None = None,
limit: int = 10,
) -> List[Dict]:
"""FTS5 keyword search across conversation content.
Expand Down Expand Up @@ -623,7 +623,7 @@ def _fts5_search_conversations_locked(
def get_recent_conversations(
self,
days: int = 7,
context: str = None,
context: str | None = None,
limit: int = 50,
) -> List[Dict]:
"""Get conversations from the last N days (timestamp-based).
Expand Down Expand Up @@ -674,14 +674,14 @@ def store(
self,
category: str,
content: str,
domain: str = None,
domain: str | None = None,
metadata: dict = None,
confidence: float = 0.5,
due_at: str = None,
due_at: str | None = None,
source: str = "tool",
context: str = "global",
sensitive: bool = False,
entity: str = None,
entity: str | None = None,
) -> str:
"""Store a knowledge entry with deduplication.

Expand Down Expand Up @@ -823,7 +823,7 @@ def store(
return knowledge_id

def _find_similar_locked(
self, content: str, category: str, context: str, entity: str = None
self, content: str, category: str, context: str, entity: str | None = None
) -> Optional[str]:
"""Find existing knowledge with >80% word overlap in same category+context+entity.

Expand Down Expand Up @@ -921,13 +921,13 @@ def _update_knowledge_fts_locked(self, knowledge_id: str):
def search(
self,
query: str,
category: str = None,
context: str = None,
entity: str = None,
category: str | None = None,
context: str | None = None,
entity: str | None = None,
include_sensitive: bool = False,
top_k: int = 5,
time_from: str = None,
time_to: str = None,
time_from: str | None = None,
time_to: str | None = None,
) -> List[Dict]:
"""FTS5 search. AND semantics, OR fallback. BM25 ranking.

Expand Down Expand Up @@ -1066,7 +1066,11 @@ def _fts5_search_knowledge_locked(
# ==================================================================

def get_by_category(
self, category: str, context: str = None, domain: str = None, limit: int = 10
self,
category: str,
context: str | None = None,
domain: str | None = None,
limit: int = 10,
) -> List[Dict]:
"""Get active knowledge entries by category, optionally filtered by context and domain."""
conditions = ["category = ?", "superseded_by IS NULL"]
Expand Down Expand Up @@ -1160,7 +1164,7 @@ def get_upcoming(
self,
within_days: int = 7,
include_overdue: bool = True,
context: str = None,
context: str | None = None,
limit: int = 10,
) -> List[Dict]:
"""Get time-sensitive items due within N days (or overdue).
Expand Down Expand Up @@ -1216,16 +1220,16 @@ def get_upcoming(
def update(
self,
knowledge_id: str,
content: str = None,
category: str = None,
domain: str = None,
content: str | None = None,
category: str | None = None,
domain: str | None = None,
metadata: dict = None,
context: str = None,
sensitive: bool = None,
entity: str = None,
due_at: str = None,
reminded_at: str = None,
superseded_by: str = None,
context: str | None = None,
sensitive: bool | None = None,
entity: str | None = None,
due_at: str | None = None,
reminded_at: str | None = None,
superseded_by: str | None = None,
) -> bool:
"""Update an existing knowledge entry. Only provided fields are changed.

Expand Down Expand Up @@ -1389,13 +1393,13 @@ def store_embedding(self, knowledge_id: str, embedding: bytes) -> bool:

def get_items_with_embeddings(
self,
category: str = None,
context: str = None,
entity: str = None,
category: str | None = None,
context: str | None = None,
entity: str | None = None,
include_sensitive: bool = False,
top_k: int = 100,
time_from: str = None,
time_to: str = None,
time_from: str | None = None,
time_to: str | None = None,
) -> List[Dict]:
"""Return active knowledge items that have stored embeddings.

Expand Down Expand Up @@ -1516,7 +1520,7 @@ def backfill_embeddings(
return {"backfilled": backfilled, "total_without": total_without}

def get_items_for_reconciliation(
self, context: str = None, limit: int = 100
self, context: str | None = None, limit: int = 100
) -> List[Dict]:
"""Get active knowledge items with embeddings for pairwise comparison.

Expand Down Expand Up @@ -1620,8 +1624,8 @@ def log_tool_call(
args: dict,
result_summary: str,
success: bool,
error: str = None,
duration_ms: int = None,
error: str | None = None,
duration_ms: int | None = None,
) -> None:
"""Log a tool call to tool_history."""
now = _now_iso()
Expand Down Expand Up @@ -1661,7 +1665,9 @@ def log_tool_call(
self._conn.rollback()
raise

def get_tool_errors(self, tool_name: str = None, limit: int = 10) -> List[Dict]:
def get_tool_errors(
self, tool_name: str | None = None, limit: int = 10
) -> List[Dict]:
"""Get recent failed tool calls, newest first."""
if tool_name is not None:
sql = """
Expand Down Expand Up @@ -1890,10 +1896,10 @@ def get_stats(self) -> Dict:
def get_all_knowledge(
self,
category: Optional[Union[str, List[str]]] = None,
context: str = None,
entity: str = None,
sensitive: bool = None,
search: str = None,
context: str | None = None,
entity: str | None = None,
sensitive: bool | None = None,
search: str | None = None,
sort_by: str = "updated_at",
order: str = "desc",
offset: int = 0,
Expand Down
10 changes: 5 additions & 5 deletions src/gaia/agents/base/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
logger = logging.getLogger(__name__)

# Tool registry to store registered tools
_TOOL_REGISTRY = {}
_TOOL_REGISTRY: dict[str, dict] = {}


def tool(
func: Callable = None,
func: Callable | None = None,
*,
atomic: bool = False,
display_label: str | None = None,
Expand Down Expand Up @@ -109,7 +109,7 @@ def get_tool_display_name(tool_name: str) -> str:
tool = _TOOL_REGISTRY.get(tool_name)
if not tool:
return tool_name
return tool.get("display_name", tool_name)
return tool.get("display_name", tool_name) # type: ignore[no-any-return]


def get_tool_display_label(tool_name: str) -> str:
Expand All @@ -120,8 +120,8 @@ def get_tool_display_label(tool_name: str) -> str:
"""
tool = _TOOL_REGISTRY.get(tool_name)
if not tool:
return None
return tool.get("display_label")
return None # type: ignore[return-value]
return tool.get("display_label") # type: ignore[return-value]


def get_tool_metadata(tool_name: str):
Expand Down
2 changes: 1 addition & 1 deletion src/gaia/agents/chat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def cleanup_old_sessions(

except Exception as e:
logger.error(f"Error during session cleanup: {e}")
return {"error": str(e), "total_deleted": 0, "remaining_sessions": 0}
return {"error": str(e), "total_deleted": 0, "remaining_sessions": 0} # type: ignore[dict-item]

def clear_path_permissions(self):
"""Clear all cached path permissions."""
Expand Down
6 changes: 3 additions & 3 deletions src/gaia/agents/code/orchestration/checklist_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,15 +476,15 @@ def _extract_response_text(self, response: Any) -> str:

# Handle response objects with text attribute
if hasattr(response, "text"):
return response.text
return response.text # type: ignore[no-any-return]

# Handle response objects with content attribute
if hasattr(response, "content"):
return response.content
return response.content # type: ignore[no-any-return]

# Handle dict-like responses
if isinstance(response, dict):
return response.get("text", response.get("content", str(response)))
return response.get("text", response.get("content", str(response))) # type: ignore[return-value]

return str(response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from pathlib import Path
from typing import List
from typing import Any, Dict, List, cast

from ..steps.base import UserContext
from ..workflows.base import WorkflowPhase
Expand Down Expand Up @@ -103,4 +103,4 @@ def get_validation_config(self, phase_name: str) -> dict:
"test_command": "pytest",
},
}
return configs.get(phase_name, {})
return cast(Dict[str, Any], configs.get(phase_name, {}))
8 changes: 4 additions & 4 deletions src/gaia/agents/code/tools/validation_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _validate_requirements(self, req_file: Path, fix: bool) -> Dict[str, Any]:
Returns:
Dictionary with validation results
"""
return self.requirements_validator.validate(req_file, fix)
return self.requirements_validator.validate(req_file, fix) # type: ignore[no-any-return, attr-defined]

def _validate_python_files(
self, py_files: List[Path], _fix: bool
Expand All @@ -84,7 +84,7 @@ def _validate_python_files(
content = py_file.read_text()

# Validate syntax
syntax_result = self.syntax_validator.validate_dict(content)
syntax_result = self.syntax_validator.validate_dict(content) # type: ignore[attr-defined]
if not syntax_result["is_valid"]:
errors.extend(
[f"{py_file}: {err}" for err in syntax_result.get("errors", [])]
Expand Down Expand Up @@ -114,7 +114,7 @@ def _check_antipatterns(self, _file_path: Path, content: str) -> Dict[str, Any]:
Returns:
Dictionary with antipattern check results
"""
return self.antipattern_checker.check_dict(content)
return self.antipattern_checker.check_dict(content) # type: ignore[no-any-return, attr-defined]

def _validate_python_syntax(self, code: str) -> Dict[str, Any]:
"""Validate Python code syntax (delegates to validator).
Expand All @@ -125,7 +125,7 @@ def _validate_python_syntax(self, code: str) -> Dict[str, Any]:
Returns:
Dictionary with validation results
"""
return self.syntax_validator.validate_dict(code)
return self.syntax_validator.validate_dict(code) # type: ignore[no-any-return, attr-defined]

def _validate_javascript_files(
self, js_files: List[Path], _fix: bool
Expand Down
3 changes: 1 addition & 2 deletions src/gaia/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ def stop_server(port: int = 8080) -> None:
check=False,
)

pids = result.stdout.strip().split("\n")
pids = [pid for pid in pids if pid] # Filter empty strings
pids = {pid for pid in result.stdout.strip().split("\n") if pid}

if pids:
for pid in pids:
Expand Down
6 changes: 3 additions & 3 deletions src/gaia/connectors/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _get():
AuthRequiredError.Reason.REAUTH_REQUIRED, provider=provider
)

return blob
return blob # type: ignore[no-any-return]


def peek_connection(
Expand Down Expand Up @@ -271,7 +271,7 @@ def _get():
if raw is None:
return None
try:
return json.loads(raw)
return json.loads(raw) # type: ignore[no-any-return]
except json.JSONDecodeError:
# Corrupt blob — caller treats as "not configured" without
# rewriting state. ``load_connection`` (auth path) still clears
Expand Down Expand Up @@ -342,7 +342,7 @@ def _get():
if raw is None:
return None
try:
return json.loads(raw)
return json.loads(raw) # type: ignore[no-any-return]
except json.JSONDecodeError:
return None

Expand Down
7 changes: 6 additions & 1 deletion src/gaia/database/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def query(
)
"""
self._require_db()
assert self._db is not None
cursor = self._db.execute(sql, params or {})
rows = [dict(row) for row in cursor.fetchall()]
if one:
Expand All @@ -140,13 +141,14 @@ def insert(self, table: str, data: Dict[str, Any]) -> int:
})
"""
self._require_db()
assert self._db is not None
cols = ", ".join(data.keys())
placeholders = ", ".join(f":{k}" for k in data.keys())
sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders})"
cursor = self._db.execute(sql, data)
if not self._in_tx:
self._db.commit()
return cursor.lastrowid
return cursor.lastrowid or 0

def update(
self,
Expand Down Expand Up @@ -176,6 +178,7 @@ def update(
)
"""
self._require_db()
assert self._db is not None
# Prefix data params with __set_ to avoid collision with where params
set_clause = ", ".join(f"{k} = :__set_{k}" for k in data.keys())
merged_params = {f"__set_{k}": v for k, v in data.items()}
Expand All @@ -202,6 +205,7 @@ def delete(self, table: str, where: str, params: Dict[str, Any]) -> int:
count = self.delete("sessions", "expires_at < :now", {"now": now})
"""
self._require_db()
assert self._db is not None
sql = f"DELETE FROM {table} WHERE {where}"
cursor = self._db.execute(sql, params)
if not self._in_tx:
Expand Down Expand Up @@ -261,6 +265,7 @@ def execute(self, sql: str) -> None:
''')
"""
self._require_db()
assert self._db is not None
if self._in_tx:
raise RuntimeError(
"execute() cannot be called inside a transaction() block. "
Expand Down
Loading
Loading