Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions hindsight-api-slim/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3541,9 +3541,18 @@ async def api_graph(
tags_match: str = "all_strict",
document_id: str | None = None,
chunk_id: str | None = None,
include_entity_data: bool = True,
request_context: RequestContext = Depends(get_request_context),
):
"""Get graph data from database, filtered by bank_id and optionally by type."""
"""Get graph data from database, filtered by bank_id and optionally by type.

Pass include_entity_data=false to skip the entity lookup
(unit_entities ⨝ entities) and the in-memory entity-link inference.
Callers that don't surface entity coloring or entity-link edges (e.g.
the cloud data view) can avoid the join + group cost on banks with
large source_memory_ids arrays. Default is true, preserving the
historical response shape.
"""
try:
data = await app.state.memory.get_graph_data(
bank_id,
Expand All @@ -3554,6 +3563,7 @@ async def api_graph(
tags_match=tags_match,
document_id=document_id,
chunk_id=chunk_id,
include_entity_data=include_entity_data,
request_context=request_context,
)
return data
Expand Down Expand Up @@ -4185,17 +4195,29 @@ async def api_list_banks(request_context: RequestContext = Depends(get_request_c
"/v1/default/banks/{bank_id}/stats",
response_model=BankStatsResponse,
summary="Get statistics for memory bank",
description="Get statistics about nodes and links for a specific agent",
description=(
"Get statistics about nodes and links for a specific agent. "
"Pass include_entity_links=false to skip the entity-link "
"aggregation when the caller doesn't need the per-entity slice; "
"this can significantly reduce response time on banks with many "
"memories and many distinct entities. Default is true, "
"preserving the historical response shape."
),
operation_id="get_agent_stats",
tags=["Banks"],
)
async def api_stats(
bank_id: str,
include_entity_links: bool = True,
request_context: RequestContext = Depends(get_request_context),
):
"""Get statistics about memory nodes and links for a memory bank."""
try:
stats = await app.state.memory.get_bank_stats(bank_id, request_context=request_context)
stats = await app.state.memory.get_bank_stats(
bank_id,
include_entity_links=include_entity_links,
request_context=request_context,
)
nodes_by_type = stats["node_counts"]
links_by_type = stats["link_counts"]
links_by_fact_type = stats["link_counts_by_fact_type"]
Expand Down
47 changes: 33 additions & 14 deletions hindsight-api-slim/hindsight_api/engine/bank_stats_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@


class BankStatsCache:
"""Per-process TTL cache keyed on (schema, bank_id).
"""Per-process TTL cache keyed on (schema, bank_id, *key_suffix).

`ttl_seconds <= 0` disables caching: each call passes straight through to
the loader. `max_entries` bounds memory in environments with many banks.

Callers may extend the key with a `key_suffix` so semantically distinct
variants of the same `(schema, bank_id)` (e.g. with and without an
optional expensive aggregation) get separate cache slots. `invalidate`
clears every variant for a `(schema, bank_id)` so writers don't have to
know which suffixes the read path is using.
"""

def __init__(self, *, ttl_seconds: float, max_entries: int) -> None:
self._ttl = float(ttl_seconds)
self._max_entries = int(max_entries) if max_entries and max_entries > 0 else 0
self._entries: OrderedDict[tuple[str, str], tuple[float, dict[str, Any]]] = OrderedDict()
self._in_flight: dict[tuple[str, str], asyncio.Future[dict[str, Any]]] = {}
self._entries: OrderedDict[tuple[Any, ...], tuple[float, dict[str, Any]]] = OrderedDict()
self._in_flight: dict[tuple[Any, ...], asyncio.Future[dict[str, Any]]] = {}
self._lock = asyncio.Lock()

@property
Expand All @@ -39,7 +45,7 @@ def enabled(self) -> bool:
def _now(self) -> float:
return time.monotonic()

def _get_fresh_unlocked(self, key: tuple[str, str]) -> dict[str, Any] | None:
def _get_fresh_unlocked(self, key: tuple[Any, ...]) -> dict[str, Any] | None:
entry = self._entries.get(key)
if entry is None:
return None
Expand All @@ -52,7 +58,7 @@ def _get_fresh_unlocked(self, key: tuple[str, str]) -> dict[str, Any] | None:
self._entries.move_to_end(key)
return value

def _store_unlocked(self, key: tuple[str, str], value: dict[str, Any]) -> None:
def _store_unlocked(self, key: tuple[Any, ...], value: dict[str, Any]) -> None:
if not self.enabled:
return
self._entries[key] = (self._now() + self._ttl, value)
Expand All @@ -66,16 +72,21 @@ async def get_or_load(
schema: str,
bank_id: str,
loader: Callable[[], Awaitable[dict[str, Any]]],
*,
key_suffix: tuple[Any, ...] = (),
) -> dict[str, Any]:
"""Return cached stats for `(schema, bank_id)` or call `loader()`.
"""Return cached stats for `(schema, bank_id, *key_suffix)` or call `loader()`.

Concurrent misses on the same key are coalesced onto a single
in-flight loader.
in-flight loader. `key_suffix` lets a caller carve out separate
slots for variants that compute different results (e.g. a flag that
toggles an optional expensive aggregation) without confusing each
other's responses.
"""
if not self.enabled:
return await loader()

key = (schema, bank_id)
key = (schema, bank_id, *key_suffix)

async with self._lock:
cached = self._get_fresh_unlocked(key)
Expand Down Expand Up @@ -120,13 +131,21 @@ async def get_or_load(
return value

async def invalidate(self, schema: str, bank_id: str) -> None:
"""Drop any cached stats for `(schema, bank_id)`."""
"""Drop every cached stats variant for `(schema, bank_id)`.

Clears all entries whose key starts with `(schema, bank_id)`, so
writers that don't know which `key_suffix` values the read path is
using still wipe the bank cleanly. Detaches in-flight loaders
rather than cancelling them — existing callers may finish with
their pre-invalidation snapshot while post-invalidation callers
reload.
"""
async with self._lock:
key = (schema, bank_id)
self._entries.pop(key, None)
# Detach rather than cancel: existing callers may finish with the
# snapshot they requested, while post-invalidation callers reload.
self._in_flight.pop(key, None)
prefix = (schema, bank_id)
for key in [k for k in self._entries if k[:2] == prefix]:
self._entries.pop(key, None)
for key in [k for k in self._in_flight if k[:2] == prefix]:
self._in_flight.pop(key, None)

async def clear(self) -> None:
async with self._lock:
Expand Down
12 changes: 12 additions & 0 deletions hindsight-api-slim/hindsight_api/engine/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,13 +448,25 @@ async def get_bank_stats(
self,
bank_id: str,
*,
include_entity_links: bool = True,
request_context: "RequestContext",
) -> dict[str, Any]:
"""
Get statistics about memory nodes and links for a bank.

Args:
bank_id: The memory bank ID.
include_entity_links: When True (default), include the
entity-link total in `link_counts["entity"]`. This count
comes from a join + group on `unit_entities ⨝ memory_units`
that dominates the query cost on banks with many memories
and many distinct entities per memory. Callers that don't
surface entity link counts (or that would rather have a
fast response than a capped, approximate total) can pass
False to skip the aggregation entirely. When skipped, the
"entity" key in `link_counts` is omitted — matching the
existing "no entity edges" rendering — so downstream readers
that already tolerate a missing key see no behavior change.
request_context: Request context for authentication.

Returns:
Expand Down
124 changes: 79 additions & 45 deletions hindsight-api-slim/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6732,6 +6732,7 @@ async def get_graph_data(
tags_match: str = "all_strict",
document_id: str | None = None,
chunk_id: str | None = None,
include_entity_data: bool = True,
request_context: "RequestContext",
):
"""
Expand Down Expand Up @@ -6937,9 +6938,18 @@ async def get_graph_data(

# Get entity information — only for visible units
# Fetch entities for visible units AND their source memories
# (so observations can inherit entities from source memories)
# (so observations can inherit entities from source memories).
#
# The unit_entities ⨝ entities join is the dominant cost on banks
# where observations carry large source_memory_ids arrays: the
# IN-list balloons to thousands of UUIDs, and joining through to
# `entities` to resolve canonical names becomes the slow query.
# Callers that don't render entity coloring or entity-link edges
# (e.g. the cloud control plane's data view) can opt out via
# include_entity_data=False and skip both the lookup and the
# in-memory entity-link inference below.
entity_lookup_ids = unit_ids + source_memory_ids
if entity_lookup_ids:
if include_entity_data and entity_lookup_ids:
unit_entities = await conn.fetch(
f"""
SELECT ue.unit_id, e.canonical_name
Expand Down Expand Up @@ -7019,9 +7029,10 @@ async def get_graph_data(
# Bounds total edges to ~N * cap per entity instead of N² for hot entities.
max_neighbors_per_unit = 10
entity_to_units_visible: dict[str, list] = {}
for unit_id in unit_ids:
for entity_name in entity_map.get(unit_id, []):
entity_to_units_visible.setdefault(entity_name, []).append(unit_id)
if include_entity_data:
for unit_id in unit_ids:
for entity_name in entity_map.get(unit_id, []):
entity_to_units_visible.setdefault(entity_name, []).append(unit_id)

# Semantic links: pair observations that share at least one source memory
source_to_obs_for_semantic: dict = {}
Expand All @@ -7033,27 +7044,31 @@ async def get_graph_data(
observation_inferred_links = []
seen_inferred: set[tuple] = set()

for entity_name, ent_unit_ids in entity_to_units_visible.items():
n = len(ent_unit_ids)
for i, unit_a in enumerate(ent_unit_ids):
# Sliding window: link unit_a to its next max_neighbors_per_unit
# in the list. Each pair is also "incoming" for the later unit,
# so every unit ends up with up to ~2*max_neighbors_per_unit edges
# for this entity (its successors + its predecessors via their pairs).
for j in range(i + 1, min(i + 1 + max_neighbors_per_unit, n)):
unit_b = ent_unit_ids[j]
pair = (min(str(unit_a), str(unit_b)), max(str(unit_a), str(unit_b)), "entity", entity_name)
if pair not in seen_inferred:
seen_inferred.add(pair)
observation_inferred_links.append(
{
"from_unit_id": unit_a,
"to_unit_id": unit_b,
"link_type": "entity",
"weight": 1.0,
"entity_name": entity_name,
}
)
# Entity-link inference is gated on include_entity_data — without the
# unit_entities lookup above, entity_to_units_visible is empty and this
# loop is a no-op, but explicit gating documents the intent.
if include_entity_data:
for entity_name, ent_unit_ids in entity_to_units_visible.items():
n = len(ent_unit_ids)
for i, unit_a in enumerate(ent_unit_ids):
# Sliding window: link unit_a to its next max_neighbors_per_unit
# in the list. Each pair is also "incoming" for the later unit,
# so every unit ends up with up to ~2*max_neighbors_per_unit edges
# for this entity (its successors + its predecessors via their pairs).
for j in range(i + 1, min(i + 1 + max_neighbors_per_unit, n)):
unit_b = ent_unit_ids[j]
pair = (min(str(unit_a), str(unit_b)), max(str(unit_a), str(unit_b)), "entity", entity_name)
if pair not in seen_inferred:
seen_inferred.add(pair)
observation_inferred_links.append(
{
"from_unit_id": unit_a,
"to_unit_id": unit_b,
"link_type": "entity",
"weight": 1.0,
"entity_name": entity_name,
}
)

for src_id, obs_ids in source_to_obs_for_semantic.items():
for i, obs_a in enumerate(obs_ids):
Expand Down Expand Up @@ -9616,14 +9631,17 @@ async def get_bank_stats(
self,
bank_id: str,
*,
include_entity_links: bool = True,
request_context: "RequestContext",
) -> dict[str, Any]:
"""Get statistics about memory nodes and links for a bank.

Results are served from a short-TTL per-process cache so a polling
client cannot drive the link/unit aggregations multiple times per
second; concurrent misses on the same bank are coalesced onto a
single in-flight loader.
single in-flight loader. The cache slot is keyed on
`include_entity_links` so the two variants don't return each
other's payload.
"""
await self._authenticate_tenant(request_context)
if self._operation_validator:
Expand All @@ -9636,10 +9654,16 @@ async def get_bank_stats(
return await self._bank_stats_cache.get_or_load(
schema,
bank_id,
lambda: self._compute_bank_stats(bank_id),
lambda: self._compute_bank_stats(bank_id, include_entity_links=include_entity_links),
key_suffix=(include_entity_links,),
)

async def _compute_bank_stats(self, bank_id: str) -> dict[str, Any]:
async def _compute_bank_stats(
self,
bank_id: str,
*,
include_entity_links: bool = True,
) -> dict[str, Any]:
backend = await self._get_backend()

async with acquire_with_retry(backend) as conn:
Expand Down Expand Up @@ -9685,23 +9709,33 @@ async def _compute_bank_stats(self, bank_id: str) -> dict[str, Any]:
# slice doubled the join cost and only fed link_counts_by_fact_type
# / link_breakdown, which the UI ignores and the CLI renders into
# sections that degrade gracefully when empty.
max_links_per_entity = 10
entity_total_row = await conn.fetchrow(
f"""
WITH per_entity AS (
SELECT ue.entity_id, COUNT(*) AS n
FROM {fq_table("unit_entities")} ue
JOIN {fq_table("memory_units")} mu ON mu.id = ue.unit_id
WHERE mu.bank_id = $1
GROUP BY ue.entity_id
#
# On banks with many memories and many distinct entities per memory,
# this CTE dominates the cost of /stats. Callers that don't need
# the entity slice can opt out via include_entity_links=False and
# save the join + group entirely. The "entity" key is omitted from
# link_counts on that path, matching the existing "0 → key absent"
# convention below.
if include_entity_links:
max_links_per_entity = 10
entity_total_row = await conn.fetchrow(
f"""
WITH per_entity AS (
SELECT ue.entity_id, COUNT(*) AS n
FROM {fq_table("unit_entities")} ue
JOIN {fq_table("memory_units")} mu ON mu.id = ue.unit_id
WHERE mu.bank_id = $1
GROUP BY ue.entity_id
)
SELECT COALESCE(SUM(LEAST(n - 1, $2)), 0)::bigint AS count
FROM per_entity
""",
bank_id,
max_links_per_entity,
)
SELECT COALESCE(SUM(LEAST(n - 1, $2)), 0)::bigint AS count
FROM per_entity
""",
bank_id,
max_links_per_entity,
)
entity_link_total = int(entity_total_row["count"] or 0) if entity_total_row else 0
entity_link_total = int(entity_total_row["count"] or 0) if entity_total_row else 0
else:
entity_link_total = 0

link_counts: dict[str, int] = {row["link_type"]: row["count"] for row in non_entity_link_rows}
if entity_link_total > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,14 @@ async def extract_facts_from_text(
total_usage = total_usage + chunk_usage

if failed_chunks:
failed_summary = ", ".join(f"chunk {idx}: {type(err).__name__}" for idx, err in failed_chunks[:5])
# Include the exception message — not just the type — so operators
# can distinguish a structured-JSON parse failure, a rate limit, and
# a network 5xx, all of which can surface as the same exception
# types. The error_message we propagate to the async_operations row
# is the only inspection surface a worker-side failure leaves
# behind, and a bare "chunk 0: RuntimeError" is not actionable.
failed_summary = ", ".join(f"chunk {idx}: {type(err).__name__}: {err}" for idx, err in failed_chunks[:5])
quota_errors = [err for _, err in failed_chunks if isinstance(err, ProviderRateLimitResetError)]
if quota_errors and len(quota_errors) == len(failed_chunks):
retry_at = max(err.retry_at for err in quota_errors)
Expand Down
Loading