diff --git a/hindsight-api-slim/hindsight_api/engine/consolidation/consolidator.py b/hindsight-api-slim/hindsight_api/engine/consolidation/consolidator.py index 14e0c3f70..f66806eae 100644 --- a/hindsight-api-slim/hindsight_api/engine/consolidation/consolidator.py +++ b/hindsight-api-slim/hindsight_api/engine/consolidation/consolidator.py @@ -32,6 +32,7 @@ from ...config import get_config from ...worker.stage import set_stage +from ..db import DatabaseBackend from ..db_utils import acquire_with_retry from ..llm_trace import ( record_created_memory_ids, @@ -138,10 +139,13 @@ class _DedupOutcome: best_id: str | None merged_text: str should_merge: bool + # The twin's text at probe time. Guards the fold against a concurrent survivor + # rewrite during the connection-free LLM window (set on the two non-None returns). + best_text: str = "" async def _dedup_adjudicate( - conn: "Connection", + pool: DatabaseBackend, memory_engine: "MemoryEngine", bank_id: str, config: Any, @@ -161,6 +165,9 @@ async def _dedup_adjudicate( (used by the UPDATE path, where the anchor row already exists and would self-match at 1.0). ``anchor_emb_str`` reuses an already-computed embedding (the UPDATE path just embedded it); pass None to embed ``anchor_text`` here (the CREATE path). + + The embedder and the LLM both run with NO connection held; only the semantic+BM25 probe + briefly borrows a short-lived connection. """ from ..search.retrieval import retrieve_semantic_bm25_combined @@ -171,9 +178,10 @@ async def _dedup_adjudicate( return _DedupOutcome(best_id=None, merged_text="", should_merge=False) anchor_emb_str = str(embs[0]) tags_match = "all_strict" if tags else "any" - grouped = await retrieve_semantic_bm25_combined( - conn, anchor_emb_str, anchor_text, bank_id, ["observation"], _DEDUP_TOP_K, tags=tags, tags_match=tags_match - ) + async with acquire_with_retry(pool) as conn: + grouped = await retrieve_semantic_bm25_combined( + conn, anchor_emb_str, anchor_text, bank_id, ["observation"], _DEDUP_TOP_K, tags=tags, tags_match=tags_match + ) results = grouped.get("observation", ([], []))[0] best_id: str | None = None best_text = "" @@ -195,12 +203,14 @@ async def _dedup_adjudicate( scope="consolidation_dedup", ) if decision.action != "merge": - return _DedupOutcome(best_id=best_id, merged_text="", should_merge=False) - return _DedupOutcome(best_id=best_id, merged_text=decision.text.strip() or best_text, should_merge=True) + return _DedupOutcome(best_id=best_id, merged_text="", should_merge=False, best_text=best_text) + return _DedupOutcome( + best_id=best_id, merged_text=decision.text.strip() or best_text, should_merge=True, best_text=best_text + ) async def _dedup_reconcile_create( - conn: "Connection", + pool: DatabaseBackend, memory_engine: "MemoryEngine", bank_id: str, config: Any, @@ -216,7 +226,7 @@ async def _dedup_reconcile_create( no near twin or the LLM keeps them distinct. """ outcome = await _dedup_adjudicate( - conn, memory_engine, bank_id, config, dedup_llm_config, create_text, None, tags, exclude_id=None + pool, memory_engine, bank_id, config, dedup_llm_config, create_text, None, tags, exclude_id=None ) if not outcome.should_merge or outcome.best_id is None: return None @@ -224,24 +234,41 @@ async def _dedup_reconcile_create( # Fold the new source facts into the twin and persist the merged text. We keep the twin's # existing embedding: the merged text is >= threshold similar, so the stored vector stays # representative and we avoid a re-embed + a dialect-specific vector UPDATE. - await conn.execute( - f""" - UPDATE {fq_table("memory_units")} - SET text = $1, - source_memory_ids = (SELECT array_agg(DISTINCT e) FROM unnest(source_memory_ids || $2::uuid[]) e), - proof_count = (SELECT count(DISTINCT e) FROM unnest(source_memory_ids || $2::uuid[]) e), - updated_at = now() - WHERE id = $3::uuid - """, - outcome.merged_text, - create_source_ids, - uuid.UUID(outcome.best_id), - ) + async with acquire_with_retry(pool) as conn: + async with conn.transaction(): + # Re-check liveness inside the fold transaction; CREATE performed the slow embed/LLM + # work off-connection, so sources may have been deleted since the decision was made. + live_source_ids = await _filter_live_source_memories(conn, bank_id, create_source_ids) + if not live_source_ids: + return None + folded = await conn.fetchval( + f""" + UPDATE {fq_table("memory_units")} + SET text = $1, + source_memory_ids = (SELECT array_agg(DISTINCT e) FROM unnest(source_memory_ids || $2::uuid[]) e), + proof_count = (SELECT count(DISTINCT e) FROM unnest(source_memory_ids || $2::uuid[]) e), + updated_at = now() + WHERE id = $3::uuid AND text = $4 + RETURNING id + """, + outcome.merged_text, + live_source_ids, + uuid.UUID(outcome.best_id), + outcome.best_text, + ) + if folded is None: + # The twin vanished during the connection-free LLM window. Don't skip the CREATE: + # returning None lets the caller insert the observation so nothing is lost. + logger.debug( + "[CONSOLIDATION] dedup-merge target %s vanished before fold; proceeding with CREATE", + outcome.best_id[:8], + ) + return None return outcome.best_id async def _dedup_reconcile_update( - conn: "Connection", + pool: DatabaseBackend, memory_engine: "MemoryEngine", bank_id: str, config: Any, @@ -262,7 +289,7 @@ async def _dedup_reconcile_update( CREATE path the row already exists, so reconciliation is a fold-and-delete, not a skip. """ outcome = await _dedup_adjudicate( - conn, + pool, memory_engine, bank_id, config, @@ -275,29 +302,70 @@ async def _dedup_reconcile_update( if not outcome.should_merge or outcome.best_id is None: return - # Fold the updated observation's sources into the twin (keeping the twin's embedding, as in + # Fold the updated observation's live sources into the twin (keeping the twin's embedding, as in # the create path) then delete the now-redundant updated row. The all_strict/any tag match # guarantees twin and updated share scope, so dropping the updated row's tags loses no # visibility. Temporal fields follow the surviving twin (minimal scope; matches create). - await conn.execute( - f""" - UPDATE {fq_table("memory_units")} t - SET text = $1, - source_memory_ids = ( - SELECT array_agg(DISTINCT e) FROM unnest(t.source_memory_ids || u.source_memory_ids) e - ), - proof_count = ( - SELECT count(DISTINCT e) FROM unnest(t.source_memory_ids || u.source_memory_ids) e - ), - updated_at = now() - FROM {fq_table("memory_units")} u - WHERE t.id = $2::uuid AND u.id = $3::uuid - """, - outcome.merged_text, - uuid.UUID(outcome.best_id), - uuid.UUID(updated_id), - ) - await _execute_delete_action(conn, bank_id, updated_id) + # The fold + delete share one short transaction so the twin gains the sources exactly as + # the redundant row is removed. + async with acquire_with_retry(pool) as conn: + async with conn.transaction(): + # Snapshot the updated row's sources with a PLAIN read (no FOR UPDATE). Lock order + # must be sources-before-observation: _filter_live_source_memories below takes + # FOR SHARE on the SOURCE rows first, then the fold UPDATE locks the observation + # rows -- the same order as _dedup_reconcile_create and the normal write paths + # (_create_observation_directly / _execute_update_action). Locking the observation + # here (FOR UPDATE) would invert that against the invalidation path (source FOR + # UPDATE in memory_engine.py, then the observation sweep DELETE in fact_storage.py) + # and deadlock. The unlocked snapshot is safe: same-bank consolidation is + # bank-serialized across workers (engine/db/ops_postgresql.py) and overlapping write + # scopes serialize in-process (scope_locks), so no concurrent writer appends sources + # to u between this read and the DELETE below. + updated_row = await conn.fetchrow( + f""" + SELECT source_memory_ids + FROM {fq_table("memory_units")} + WHERE id = $1::uuid AND text = $2 + """, + uuid.UUID(updated_id), + updated_text, + ) + if updated_row is None: + return + live_u_sources = await _filter_live_source_memories( + conn, + bank_id, + list(updated_row["source_memory_ids"] or []), + ) + if not live_u_sources: + return + folded = await conn.fetchval( + f""" + UPDATE {fq_table("memory_units")} t + SET text = $1, + source_memory_ids = ( + SELECT array_agg(DISTINCT e) FROM unnest(t.source_memory_ids || $6::uuid[]) e + ), + proof_count = ( + SELECT count(DISTINCT e) FROM unnest(t.source_memory_ids || $6::uuid[]) e + ), + updated_at = now() + FROM {fq_table("memory_units")} u + WHERE t.id = $2::uuid AND u.id = $3::uuid AND t.text = $4 AND u.text = $5 + RETURNING t.id + """, + outcome.merged_text, + uuid.UUID(outcome.best_id), + uuid.UUID(updated_id), + outcome.best_text, + updated_text, + live_u_sources, + ) + if folded is None: + # Twin or updated row vanished during the LLM window — keep the updated row + # as a distinct observation instead of deleting it unfolded. + return + await _execute_delete_action(conn, bank_id, updated_id) logger.info( "[CONSOLIDATION] dedup-merged updated observation %s into %s (cosine>=%.2f)", updated_id[:8], @@ -437,6 +505,27 @@ async def _filter_live_source_memories( return [mid for mid in source_memory_ids if mid in live] +async def _any_live_source_memory( + conn: "Connection", + bank_id: str, + source_memory_ids: list[uuid.UUID], +) -> bool: + """Cheap, non-locking existence check used as a preflight before embedding. + + Lets the create/update executors skip the (slow) embedder when every source + memory is already gone, restoring the pre-refactor short-circuit. The + authoritative, FOR SHARE liveness check still runs inside the write txn. + """ + if not source_memory_ids: + return False + found = await conn.fetchval( + f"SELECT 1 FROM {fq_table('memory_units')} WHERE id = ANY($1::uuid[]) AND bank_id = $2 LIMIT 1", + source_memory_ids, + bank_id, + ) + return found is not None + + class _CreateAction(BaseModel): text: str source_fact_ids: list[str] # memory UUIDs from the NEW FACTS list @@ -778,7 +867,7 @@ async def _run_consolidation_job( logger.debug(f"Consolidation disabled for bank {bank_id}") return {"status": "disabled", "bank_id": bank_id} - pool = memory_engine._backend + pool = await memory_engine._get_backend() # Get bank profile async with acquire_with_retry(pool) as conn: @@ -991,53 +1080,18 @@ async def _process_one_llm_batch(llm_batch_local: list[dict[str, Any]], batch_nu while pending: sub_batch = pending.pop(0) - async with acquire_with_retry(pool) as conn: - obs_tags_list = _resolve_obs_tags_list(sub_batch[0]) if sub_batch else None - - sub_deleted: int = 0 - sub_llm_failed = False - if obs_tags_list: - sub_results: list[dict[str, Any]] = [] - for obs_tags in obs_tags_list: - pass_results, pass_deleted, pass_failed = await _process_memory_batch( - conn=conn, - memory_engine=memory_engine, - llm_config=llm_config, - bank_id=bank_id, - memories=sub_batch, - request_context=request_context, - perf=batch_perf, - config=config, - obs_tags_override=obs_tags, - ) - sub_deleted += pass_deleted - sub_llm_failed = sub_llm_failed or pass_failed - if not sub_results: - sub_results = pass_results - else: - for i, (existing, new) in enumerate(zip(sub_results, pass_results)): - if existing.get("action") == "skipped" and new.get("action") != "skipped": - sub_results[i] = new - elif existing.get("action") != "skipped" and new.get("action") != "skipped": - existing_created = existing.get( - "created", 1 if existing.get("action") == "created" else 0 - ) - existing_updated = existing.get( - "updated", 1 if existing.get("action") == "updated" else 0 - ) - new_created = new.get("created", 1 if new.get("action") == "created" else 0) - new_updated = new.get("updated", 1 if new.get("action") == "updated" else 0) - total = existing_created + existing_updated + new_created + new_updated - sub_results[i] = { - "action": "multiple", - "created": existing_created + new_created, - "updated": existing_updated + new_updated, - "merged": 0, - "total_actions": total, - } - else: - sub_results, sub_deleted, sub_llm_failed = await _process_memory_batch( - conn=conn, + # No connection is held across the batch: recall, the main LLM call, the + # per-action embeds, and dedup all run connection-free; each helper acquires a + # short-lived connection only around its own SQL. + obs_tags_list = _resolve_obs_tags_list(sub_batch[0]) if sub_batch else None + + sub_deleted: int = 0 + sub_llm_failed = False + if obs_tags_list: + sub_results: list[dict[str, Any]] = [] + for obs_tags in obs_tags_list: + pass_results, pass_deleted, pass_failed = await _process_memory_batch( + pool=pool, memory_engine=memory_engine, llm_config=llm_config, bank_id=bank_id, @@ -1045,7 +1099,44 @@ async def _process_one_llm_batch(llm_batch_local: list[dict[str, Any]], batch_nu request_context=request_context, perf=batch_perf, config=config, + obs_tags_override=obs_tags, ) + sub_deleted += pass_deleted + sub_llm_failed = sub_llm_failed or pass_failed + if not sub_results: + sub_results = pass_results + else: + for i, (existing, new) in enumerate(zip(sub_results, pass_results)): + if existing.get("action") == "skipped" and new.get("action") != "skipped": + sub_results[i] = new + elif existing.get("action") != "skipped" and new.get("action") != "skipped": + existing_created = existing.get( + "created", 1 if existing.get("action") == "created" else 0 + ) + existing_updated = existing.get( + "updated", 1 if existing.get("action") == "updated" else 0 + ) + new_created = new.get("created", 1 if new.get("action") == "created" else 0) + new_updated = new.get("updated", 1 if new.get("action") == "updated" else 0) + total = existing_created + existing_updated + new_created + new_updated + sub_results[i] = { + "action": "multiple", + "created": existing_created + new_created, + "updated": existing_updated + new_updated, + "merged": 0, + "total_actions": total, + } + else: + sub_results, sub_deleted, sub_llm_failed = await _process_memory_batch( + pool=pool, + memory_engine=memory_engine, + llm_config=llm_config, + bank_id=bank_id, + memories=sub_batch, + request_context=request_context, + perf=batch_perf, + config=config, + ) all_deleted += sub_deleted @@ -1370,7 +1461,7 @@ async def _trigger_mental_model_refreshes( Returns: Number of mental models scheduled for refresh """ - pool = memory_engine._backend + pool = await memory_engine._get_backend() # Find mental models with refresh_after_consolidation=true that are actually stale. # The tag filter on the SELECT enforces the security boundary (never look outside the @@ -1443,7 +1534,7 @@ async def _trigger_mental_model_refreshes( async def _process_memory_batch( - conn: "Connection", + pool: DatabaseBackend, memory_engine: "MemoryEngine", llm_config: Any, bank_id: str, @@ -1530,7 +1621,10 @@ async def _process_memory_batch( if max_obs >= 0 and fact_tags: # max_obs == 0 means "no new observations": there are no slots regardless # of the current count, so skip the count query for that case. - current_count = await _count_observations_for_scope(conn, bank_id, fact_tags) if max_obs > 0 else 0 + current_count = 0 + if max_obs > 0: + async with acquire_with_retry(pool) as count_conn: + current_count = await _count_observations_for_scope(count_conn, bank_id, fact_tags) remaining_observation_slots = max(max_obs - current_count, 0) if remaining_observation_slots == 0: logger.info( @@ -1575,17 +1669,20 @@ async def _process_memory_batch( else None ) - # Execute deletes first to free observation slots before creates consume them + # Execute deletes first to free observation slots before creates consume them. Each delete + # is a single fast statement, so the whole loop shares one short-lived connection. deleted_count = 0 - for delete in llm_result.deletes: - # Security: the observation must be present in the unioned recall - if not any(str(obs.id) == delete.observation_id for obs in union_observations): - logger.debug( - f"Batch consolidation: rejected delete — observation {delete.observation_id} not in unioned recall" - ) - continue - await _execute_delete_action(conn=conn, bank_id=bank_id, observation_id=delete.observation_id) - deleted_count += 1 + if llm_result.deletes: + async with acquire_with_retry(pool) as conn: + for delete in llm_result.deletes: + # Security: the observation must be present in the unioned recall + if not any(str(obs.id) == delete.observation_id for obs in union_observations): + logger.debug( + f"Batch consolidation: rejected delete — observation {delete.observation_id} not in unioned recall" + ) + continue + await _execute_delete_action(conn, bank_id, delete.observation_id) + deleted_count += 1 for update in llm_result.updates: source_mems = [mem_by_id[fid] for fid in update.source_fact_ids if fid in mem_by_id] @@ -1600,7 +1697,7 @@ async def _process_memory_batch( continue agg = _aggregate_source_fields(source_mems, tags=fact_tags) updated_emb_str = await _execute_update_action( - conn=conn, + pool=pool, memory_engine=memory_engine, bank_id=bank_id, source_memory_ids=[m["id"] for m in source_mems], @@ -1620,7 +1717,7 @@ async def _process_memory_batch( # source). updated_emb_str is None when the update was skipped — nothing to reconcile. if dedup_enabled and updated_emb_str is not None: await _dedup_reconcile_update( - conn, + pool, memory_engine, bank_id, config, @@ -1668,7 +1765,7 @@ async def _process_memory_batch( # near-identical observation (LLM-adjudicated, 1-by-1) instead of inserting a dup. if dedup_enabled: merged_into = await _dedup_reconcile_create( - conn, memory_engine, bank_id, config, dedup_llm_config, create.text, create_source_ids, agg.tags + pool, memory_engine, bank_id, config, dedup_llm_config, create.text, create_source_ids, agg.tags ) if merged_into is not None: logger.info( @@ -1680,8 +1777,8 @@ async def _process_memory_batch( per_memory_created.add(str(m["id"])) continue - await _execute_create_action( - conn=conn, + action = await _execute_create_action( + pool=pool, memory_engine=memory_engine, bank_id=bank_id, source_memory_ids=create_source_ids, @@ -1693,8 +1790,9 @@ async def _process_memory_batch( mentioned_at=agg.mentioned_at, perf=perf, ) - for m in source_mems: - per_memory_created.add(str(m["id"])) + if action == "created": + for m in source_mems: + per_memory_created.add(str(m["id"])) # Build per-memory result dicts for the stats tracker in the outer loop results: list[dict[str, Any]] = [] @@ -1783,7 +1881,7 @@ async def _append_observation_history( async def _execute_update_action( - conn: "Connection", + pool: DatabaseBackend, memory_engine: "MemoryEngine", bank_id: str, source_memory_ids: list[uuid.UUID], @@ -1802,41 +1900,33 @@ async def _execute_update_action( Extends source_memory_ids with all contributing memories, updates temporal fields (LEAST for occurred_start, GREATEST for occurred_end / mentioned_at), and merges tags. + The embedding is computed off-connection (a slow embedder must never pin a pooled + connection); the liveness check + UPDATE + history + observation_sources sync then run + in one short transaction so they commit atomically. + Returns the observation's freshly-computed embedding (pgvector literal) so the caller can run UPDATE-path dedup without re-embedding, or None when the update was skipped. """ model = next((m for m in observations if str(m.id) == observation_id), None) if not model: logger.debug(f"Update skipped: observation {observation_id} not found in recall results") - return - - live_source_memory_ids = await _filter_live_source_memories(conn, bank_id, source_memory_ids) - if not live_source_memory_ids: - logger.debug( - f"Update skipped: all {len(source_memory_ids)} source memories for observation " - f"{observation_id} were deleted concurrently" - ) - return - source_memory_ids = live_source_memory_ids + return None from ...config import get_config - history_entry = _ObservationHistorySnapshot( - previous_text=model.text, - previous_tags=list(model.tags or []), - previous_occurred_start=model.occurred_start, - previous_occurred_end=model.occurred_end, - previous_mentioned_at=model.mentioned_at, - new_source_memory_ids=[str(mid) for mid in source_memory_ids], - ) - - source_ids = list(model.source_fact_ids or []) + source_memory_ids - - # SECURITY: Merge source fact's tags into existing observation tags so all contributors can see it - existing_tags = set(model.tags or []) - source_tags = set(source_fact_tags or []) - merged_tags = list(existing_tags | source_tags) + # Preflight (non-locking, separate short-lived conn): if every source memory is already + # gone, skip BEFORE the slow embed — restores the pre-refactor short-circuit so a no-op + # update doesn't embed and a failing embedder doesn't raise where it used to skip. + async with acquire_with_retry(pool) as conn: + if not await _any_live_source_memory(conn, bank_id, source_memory_ids): + logger.debug( + f"Update skipped: all {len(source_memory_ids)} source memories for observation " + f"{observation_id} were deleted before embedding" + ) + return None + # Embed off-connection: the new text is known up front and does not depend on + # any DB state, so the (slow) embedder runs before we touch the pool. t0 = time.time() embeddings = await embedding_utils.generate_embeddings_batch(memory_engine.embeddings, [new_text]) embedding_str = str(embeddings[0]) if embeddings else None @@ -1845,60 +1935,89 @@ async def _execute_update_action( config = get_config() - t0 = time.time() - await conn.execute( - f""" - UPDATE {fq_table("memory_units")} - SET text = $1, - embedding = $2::vector, - source_memory_ids = $3, - proof_count = $4, - tags = $9, - updated_at = now(), - occurred_start = LEAST(occurred_start, COALESCE($6, occurred_start)), - occurred_end = GREATEST(occurred_end, COALESCE($7, occurred_end)), - mentioned_at = GREATEST(mentioned_at, COALESCE($8, mentioned_at)) - WHERE id = $5 - """, - new_text, - embedding_str, - source_ids, - len(source_ids), - uuid.UUID(observation_id), - source_occurred_start, - source_occurred_end, - source_mentioned_at, - merged_tags, - ) + async with acquire_with_retry(pool) as conn: + async with conn.transaction(): + # FOR SHARE liveness + the write share one tiny transaction so a concurrent + # delete cannot remove a source row between the check and the UPDATE. + live_source_memory_ids = await _filter_live_source_memories(conn, bank_id, source_memory_ids) + if not live_source_memory_ids: + logger.debug( + f"Update skipped: all {len(source_memory_ids)} source memories for observation " + f"{observation_id} were deleted concurrently" + ) + return None + source_memory_ids = live_source_memory_ids + + history_entry = _ObservationHistorySnapshot( + previous_text=model.text, + previous_tags=list(model.tags or []), + previous_occurred_start=model.occurred_start, + previous_occurred_end=model.occurred_end, + previous_mentioned_at=model.mentioned_at, + new_source_memory_ids=[str(mid) for mid in source_memory_ids], + ) - # Record the pre-update snapshot in the dedicated observation_history table - # (one row per change), then trim to the configured cap. History lived in a - # single unbounded JSONB column before; an often-reinforced observation grew - # it until it crossed Postgres's 256MB jsonb limit and got stuck. - if config.enable_observation_history: - await _append_observation_history( - conn, bank_id, observation_id, history_entry, config.observation_history_max_entries - ) + source_ids = list(model.source_fact_ids or []) + source_memory_ids - # Sync observation_sources junction table (Oracle only — PG uses native array ops). - if memory_engine._backend.ops.uses_observation_sources_table: - obs_uuid = uuid.UUID(observation_id) - await conn.execute( - f"DELETE FROM {fq_table('observation_sources')} WHERE observation_id = $1", - obs_uuid, - ) - if source_ids: - await conn.executemany( + # SECURITY: Merge source fact's tags into existing observation tags so all contributors can see it + existing_tags = set(model.tags or []) + source_tags = set(source_fact_tags or []) + merged_tags = list(existing_tags | source_tags) + + t0 = time.time() + await conn.execute( f""" - INSERT INTO {fq_table("observation_sources")} (observation_id, source_id) - VALUES ($1, $2) - ON CONFLICT (observation_id, source_id) DO NOTHING + UPDATE {fq_table("memory_units")} + SET text = $1, + embedding = $2::vector, + source_memory_ids = $3, + proof_count = $4, + tags = $9, + updated_at = now(), + occurred_start = LEAST(occurred_start, COALESCE($6, occurred_start)), + occurred_end = GREATEST(occurred_end, COALESCE($7, occurred_end)), + mentioned_at = GREATEST(mentioned_at, COALESCE($8, mentioned_at)) + WHERE id = $5 """, - [(obs_uuid, sid) for sid in dict.fromkeys(source_ids)], + new_text, + embedding_str, + source_ids, + len(source_ids), + uuid.UUID(observation_id), + source_occurred_start, + source_occurred_end, + source_mentioned_at, + merged_tags, ) - if perf: - perf.record_timing("db_write", time.time() - t0) + # Record the pre-update snapshot in the dedicated observation_history table + # (one row per change), then trim to the configured cap. History lived in a + # single unbounded JSONB column before; an often-reinforced observation grew + # it until it crossed Postgres's 256MB jsonb limit and got stuck. + if config.enable_observation_history: + await _append_observation_history( + conn, bank_id, observation_id, history_entry, config.observation_history_max_entries + ) + + # Sync observation_sources junction table (Oracle only — PG uses native array ops). + if memory_engine._backend.ops.uses_observation_sources_table: + obs_uuid = uuid.UUID(observation_id) + await conn.execute( + f"DELETE FROM {fq_table('observation_sources')} WHERE observation_id = $1", + obs_uuid, + ) + if source_ids: + await conn.executemany( + f""" + INSERT INTO {fq_table("observation_sources")} (observation_id, source_id) + VALUES ($1, $2) + ON CONFLICT (observation_id, source_id) DO NOTHING + """, + [(obs_uuid, sid) for sid in dict.fromkeys(source_ids)], + ) + + if perf: + perf.record_timing("db_write", time.time() - t0) # Map the updated observation onto the consolidation trace as a produced memory. record_created_memory_ids([observation_id]) @@ -1907,7 +2026,7 @@ async def _execute_update_action( async def _execute_create_action( - conn: "Connection", + pool: DatabaseBackend, memory_engine: "MemoryEngine", bank_id: str, source_memory_ids: list[uuid.UUID], @@ -1918,15 +2037,15 @@ async def _execute_create_action( occurred_end: datetime | None = None, mentioned_at: datetime | None = None, perf: ConsolidationPerfLog | None = None, -) -> None: +) -> str: """ Create a new observation from one or more source memories. Tags are inherited from the source facts (determined algorithmically, not by LLM) - to maintain visibility scope. + to maintain visibility scope. Returns the write action ("created" or "skipped"). """ created = await _create_observation_directly( - conn=conn, + pool=pool, memory_engine=memory_engine, bank_id=bank_id, source_memory_ids=source_memory_ids, @@ -1943,6 +2062,7 @@ async def _execute_create_action( if new_id: record_created_memory_ids([new_id]) logger.debug(f"Created observation from {len(source_memory_ids)} source memories") + return created["action"] async def _execute_delete_action( @@ -2283,7 +2403,7 @@ def _fact_line(m: dict[str, Any]) -> str: async def _create_observation_directly( - conn: "Connection", + pool: DatabaseBackend, memory_engine: "MemoryEngine", bank_id: str, source_memory_ids: list[uuid.UUID], @@ -2295,33 +2415,38 @@ async def _create_observation_directly( mentioned_at: datetime | None = None, perf: ConsolidationPerfLog | None = None, ) -> dict[str, Any]: - """Create an observation from one or more source memories with pre-processed text.""" - live_source_memory_ids = await _filter_live_source_memories(conn, bank_id, source_memory_ids) - if not live_source_memory_ids: - logger.debug(f"Create skipped: all {len(source_memory_ids)} source memories were deleted concurrently") - return {"action": "skipped", "reason": "sources_deleted"} - source_memory_ids = live_source_memory_ids - - # Generate embedding for the observation (convert to string for pgvector) + """Create an observation from one or more source memories with pre-processed text. + + The embedding is computed off-connection (a slow embedder must never pin a pooled + connection); the liveness check + INSERT + observation_sources insert then run in one + short transaction so they commit atomically. + """ + # Preflight (non-locking, separate short-lived conn): if every source memory is already + # gone, skip BEFORE the slow embed — restores the pre-refactor short-circuit so a no-op + # create doesn't embed and a failing embedder doesn't raise where it used to skip. + async with acquire_with_retry(pool) as conn: + if not await _any_live_source_memory(conn, bank_id, source_memory_ids): + logger.debug(f"Create skipped: all {len(source_memory_ids)} source memories were deleted before embedding") + return {"action": "skipped", "reason": "sources_deleted"} + + # Generate embedding for the observation (convert to string for pgvector) BEFORE + # acquiring a connection so the embedder never holds a pooled connection. t0 = time.time() embeddings = await embedding_utils.generate_embeddings_batch(memory_engine.embeddings, [observation_text]) embedding_str = str(embeddings[0]) if embeddings else None if perf: perf.record_timing("embedding", time.time() - t0) - # Create the observation as a memory_unit now = datetime.now(timezone.utc) obs_event_date = event_date or now obs_occurred_start = occurred_start obs_occurred_end = occurred_end obs_mentioned_at = mentioned_at or now obs_tags = tags or [] - - t0 = time.time() observation_id = uuid.uuid4() - # Query varies based on text search backend config = get_config() + # Query varies based on text search backend if config.text_search_extension == "vchord": # VectorChord: manually tokenize and insert search_vector query = f""" @@ -2352,34 +2477,44 @@ async def _create_observation_directly( RETURNING id """ - row = await conn.fetchrow( - query, - observation_id, - bank_id, - observation_text, - embedding_str, - source_memory_ids, - obs_tags, - obs_event_date, - obs_occurred_start, - obs_occurred_end, - obs_mentioned_at, - ) + async with acquire_with_retry(pool) as conn: + async with conn.transaction(): + # FOR SHARE liveness + INSERT share one tiny transaction so a concurrent + # delete cannot orphan the new observation between the check and the insert. + live_source_memory_ids = await _filter_live_source_memories(conn, bank_id, source_memory_ids) + if not live_source_memory_ids: + logger.debug(f"Create skipped: all {len(source_memory_ids)} source memories were deleted concurrently") + return {"action": "skipped", "reason": "sources_deleted"} + source_memory_ids = live_source_memory_ids - # Populate observation_sources junction table (Oracle only — PG uses native array ops). - if memory_engine._backend.ops.uses_observation_sources_table and source_memory_ids: - await conn.executemany( - f""" - INSERT INTO {fq_table("observation_sources")} (observation_id, source_id) - VALUES ($1, $2) - ON CONFLICT (observation_id, source_id) DO NOTHING - """, - [(observation_id, sid) for sid in dict.fromkeys(source_memory_ids)], - ) + t0 = time.time() + row = await conn.fetchrow( + query, + observation_id, + bank_id, + observation_text, + embedding_str, + source_memory_ids, + obs_tags, + obs_event_date, + obs_occurred_start, + obs_occurred_end, + obs_mentioned_at, + ) - if perf: - perf.record_timing("db_write", time.time() - t0) + # Populate observation_sources junction table (Oracle only — PG uses native array ops). + if memory_engine._backend.ops.uses_observation_sources_table and source_memory_ids: + await conn.executemany( + f""" + INSERT INTO {fq_table("observation_sources")} (observation_id, source_id) + VALUES ($1, $2) + ON CONFLICT (observation_id, source_id) DO NOTHING + """, + [(observation_id, sid) for sid in dict.fromkeys(source_memory_ids)], + ) - logger.debug(f"Created observation {observation_id} from {len(source_memory_ids)} memories (tags: {obs_tags})") + if perf: + perf.record_timing("db_write", time.time() - t0) + logger.debug(f"Created observation {observation_id} from {len(source_memory_ids)} memories (tags: {obs_tags})") return {"action": "created", "observation_id": str(row["id"]), "tags": obs_tags} diff --git a/hindsight-api-slim/hindsight_api/engine/entity_resolver.py b/hindsight-api-slim/hindsight_api/engine/entity_resolver.py index 6c866e973..980d39a1c 100644 --- a/hindsight-api-slim/hindsight_api/engine/entity_resolver.py +++ b/hindsight-api-slim/hindsight_api/engine/entity_resolver.py @@ -8,6 +8,7 @@ import asyncio import json import logging +import uuid from collections import defaultdict from dataclasses import dataclass, field from datetime import UTC, datetime @@ -230,7 +231,7 @@ async def resolve_entities_batch( unit_event_date, conn=None, entity_labels: list | None = None, - ) -> list[str]: + ) -> list[uuid.UUID]: """ Resolve multiple entities in batch (MUCH faster than sequential). @@ -271,7 +272,7 @@ async def _resolve_entities_batch_impl( unit_event_date, taxonomy_lookup: set[str] | None = None, labels_cfg=None, - ) -> list[str]: + ) -> list[uuid.UUID]: if self.entity_lookup == "trigram": # Route to backend-specific fuzzy strategy. # Non-PG backends (Oracle) use UTL_MATCH instead of pg_trgm. @@ -311,7 +312,7 @@ async def _resolve_entities_batch_full( unit_event_date, taxonomy_lookup: set[str] | None = None, labels_cfg=None, - ) -> list[str]: + ) -> list[uuid.UUID]: """Original strategy: load all bank entities then match in Python.""" # Query ALL candidates for this bank all_entities = await conn.fetch( @@ -395,7 +396,7 @@ async def _resolve_entities_batch_trigram( unit_event_date, taxonomy_lookup: set[str] | None = None, labels_cfg=None, - ) -> list[str]: + ) -> list[uuid.UUID]: """ Trigram strategy: fetch only similar candidates per entity name using pg_trgm. @@ -499,7 +500,7 @@ async def _resolve_entities_batch_oracle_fuzzy( unit_event_date: datetime | None, taxonomy_lookup: set[str] | None = None, labels_cfg=None, - ) -> list[str]: + ) -> list[uuid.UUID]: """ Oracle strategy: fetch similar candidates using UTL_MATCH.JARO_WINKLER_SIMILARITY. @@ -607,7 +608,7 @@ async def _resolve_from_candidates( cooccurrence_map: dict[str, set[str]], taxonomy_lookup: set[str] | None = None, labels_cfg=None, - ) -> list[str]: + ) -> list[uuid.UUID]: """Shared scoring + upsert logic used by both lookup strategies.""" # Resolve each entity using pre-fetched candidates diff --git a/hindsight-api-slim/hindsight_api/engine/memory_engine.py b/hindsight-api-slim/hindsight_api/engine/memory_engine.py index 848236ae7..4ef8cc6ae 100644 --- a/hindsight-api-slim/hindsight_api/engine/memory_engine.py +++ b/hindsight-api-slim/hindsight_api/engine/memory_engine.py @@ -45,7 +45,7 @@ from ..worker.stage import set_stage from .audit import AuditLogger, audit_context from .bank_stats_cache import BankStatsCache -from .db import DatabaseBackend, create_database_backend +from .db import DatabaseBackend, ResultRow, create_database_backend from .db_budget import budgeted_operation from .llm_interface import ProviderRateLimitResetError from .llm_trace import ( @@ -837,6 +837,41 @@ def _overlay_bank_config_disposition_mission( return ResolvedDispositionMission(disposition=resolved_disposition, mission=resolved_mission) +@dataclass +class _MemoryEditPlan: + """Inputs for the edit path of update_memory_unit, carried from the read/resolve + phase to the short write transaction so the embedding is computed off-connection.""" + + new_text: str + new_context: str | None + new_fact: str + new_occ_start: datetime | None + new_occ_end: datetime | None + new_event_date: datetime | None + # Entity ids resolved for the unit when ``entities`` is being changed; None when the + # edit leaves the unit's entity set untouched. + resolved_for_unit: list[uuid.UUID] | None + entity_date: datetime | None + mentioned_at: datetime | None + # Canonical entity names the embedding was built from, used to detect a concurrent + # entity-only edit when the row is re-locked in the write transaction. + names: list[str] + # Phase-1 snapshot Record of the row's editable columns (text/context/fact_type/event_date/ + # occurred_start/occurred_end/mentioned_at). Re-locked and compared in the write transaction + # to abort if a concurrent edit changed any of them while the embedder ran off-connection. + live_row: ResultRow + embedding: str | None = None + + +@dataclass +class _MemoryRevertPlan: + """Inputs for the revert path of update_memory_unit (see _MemoryEditPlan).""" + + arch_row: ResultRow | None # row snapshot (text/occurred_*/mentioned_at/entity_ids), or None + names: list[str] + embedding: str | None = None + + class MemoryEngine(MemoryEngineInterface): """ Advanced memory system using temporal and semantic linking with PostgreSQL. @@ -6398,10 +6433,22 @@ def _parse_edit_date(value: str | None) -> datetime | None: need_consolidation = False need_graph = False - found = False + entities_maybe_committed = False # resolve_entities_only may autocommit entities off-txn (Phase 1) + edit_relinked = False # the Phase-2 edit branch ran its writes (set in-txn) + phase2_committed = False # the Phase-2 write transaction committed without raising + + # -- Phase 1: read current state + compute embeddings OFF any write + # transaction. A slow embedder must never pin a pooled connection, so all + # embed work happens here, between two short-lived connections. Entity + # resolution (idempotent find-or-create) also runs here; the canonical + # names it yields feed the embedding. The authoritative writes happen in + # the Phase-2 transaction, which re-locks the row and applies the + # precomputed embedding + resolved entity set atomically. + edit_plan: _MemoryEditPlan | None = None + revert_plan: _MemoryRevertPlan | None = None - async with acquire_with_retry(backend) as conn: - async with conn.transaction(): + try: + async with acquire_with_retry(backend) as conn: live = await conn.fetchrow( f"SELECT text, context, fact_type, event_date, occurred_start, occurred_end, mentioned_at " f"FROM {mu} WHERE id = $1 AND bank_id = $2", @@ -6418,7 +6465,6 @@ def _parse_edit_date(value: str | None) -> datetime | None: record = live or archived if record is None: return None - found = True current_fact_type = record["fact_type"] if current_fact_type not in ("experience", "world"): raise ValueError( @@ -6426,15 +6472,6 @@ def _parse_edit_date(value: str | None) -> datetime | None: "curated. Observations are derived and regenerate from their sources." ) - collist = await self._memory_unit_columns(conn) - # The archive is cold storage, never a recall surface, so the schema gives it - # no `embedding` column at all (dropped in d4f6a8c2e1b3). The move in/out is - # therefore over every memory_units column EXCEPT embedding; on revert the - # embedding is recomputed from the unit's text/dates/entities below. This makes - # a model switch (which re-dimensions memory_units) structurally unable to trip - # a vector-dimension mismatch on the INSERT … SELECT round-trip (#2209). - arch_cols = ", ".join(c for c in (s.strip() for s in collist.split(",")) if c != '"embedding"') - # --- Edit fields (live rows only): text / context / dates / fact_type / entities --- doing_edit = any( v is not None for v in (text, context, occurred_start, occurred_end, new_fact_type) @@ -6453,13 +6490,19 @@ def _parse_edit_date(value: str | None) -> datetime | None: # tracks the occurred start when it's set. new_event_date = new_occ_start or live["event_date"] - # Rebuild the unit's entity set FIRST, so the re-embed below picks - # up the corrected canonical names. Reuses retain's resolver - # (find-or-create + cooccurrence) rather than touching entities - # directly. Orphaned entities + stale cooccurrence are swept by - # the graph-maintenance run this edit submits. + # Resolve the corrected entity set (find-or-create + cooccurrence) so + # the re-embed picks up the canonical names. resolve_entities_only is + # idempotent and designed to run outside the write txn; the Phase-2 + # relink writes exactly this resolved set, keeping the stored + # embedding consistent with the linked entities. + resolved_for_unit: list[uuid.UUID] | None = None + entity_date = None if new_entities is not None: entity_date = new_occ_start or live["mentioned_at"] + # resolve_entities_only autocommits new entities OUTSIDE the Phase-2 txn. Set + # this before the call so the failure-path cleanup also covers the resolver + # itself raising after autocommitting some entities. + entities_maybe_committed = True _resolved_ids, _e2u, unit_to_entity_ids = await resolve_entities_only( self.entity_resolver, conn, @@ -6471,163 +6514,394 @@ def _parse_edit_date(value: str | None) -> datetime | None: [[{"text": name, "type": "CONCEPT"} for name in new_entities]], entity_labels=entity_labels, ) - await conn.execute(f"DELETE FROM {ue} WHERE unit_id = $1", str(memory_uuid)) resolved_for_unit = unit_to_entity_ids.get(str(memory_uuid), []) - if resolved_for_unit: - await self.entity_resolver.link_units_to_entities_batch( - [(str(memory_uuid), eid, entity_date) for eid in resolved_for_unit], - conn=conn, + name_rows = ( + await conn.fetch( + f"SELECT canonical_name FROM {ent} WHERE id = ANY($1::uuid[]) AND bank_id = $2 ORDER BY id", + resolved_for_unit, + bank_id, ) - - ent_rows = await conn.fetch( - f"SELECT e.canonical_name FROM {ue} ue JOIN {ent} e ON ue.entity_id = e.id " - f"WHERE ue.unit_id = $1", - str(memory_uuid), - ) - new_emb = await self._reembed_memory_text( - text=new_text, - occurred_start=new_occ_start, - occurred_end=new_occ_end, + if resolved_for_unit + else [] + ) + else: + name_rows = await conn.fetch( + f"SELECT e.canonical_name FROM {ue} ue JOIN {ent} e ON ue.entity_id = e.id " + f"WHERE ue.unit_id = $1 ORDER BY e.id", + str(memory_uuid), + ) + edit_plan = _MemoryEditPlan( + new_text=new_text, + new_context=new_context, + new_fact=new_fact, + new_occ_start=new_occ_start, + new_occ_end=new_occ_end, + new_event_date=new_event_date, + resolved_for_unit=resolved_for_unit, + entity_date=entity_date, mentioned_at=live["mentioned_at"], - entities=[r["canonical_name"] for r in ent_rows], - ) - await enqueue_relink_victims(conn, bank_id, [memory_id], ops=backend.ops) - await conn.execute( - f""" - UPDATE {mu} - SET text = $3, context = $4, fact_type = $5, occurred_start = $6, - occurred_end = $7, event_date = $8, embedding = $9::vector, - consolidated_at = NULL, consolidation_failed_at = NULL, - edited_at = now(), updated_at = now() - WHERE id = $1 AND bank_id = $2 - """, - str(memory_uuid), - bank_id, - new_text, - new_context, - new_fact, - new_occ_start, - new_occ_end, - new_event_date, - new_emb, - ) - await conn.execute(f"DELETE FROM {ml} WHERE from_unit_id = $1 OR to_unit_id = $1", str(memory_uuid)) - await self._delete_stale_observations_for_memories(conn, bank_id, [memory_id]) - need_consolidation = True - need_graph = True - - # --- Invalidate: move live → archive --- - if state == "invalidated" and live: - entity_ids = [ - r["entity_id"] - for r in await conn.fetch(f"SELECT entity_id FROM {ue} WHERE unit_id = $1", str(memory_uuid)) - ] - # Capture relink victims BEFORE the row (and its links) disappear. - await enqueue_relink_victims(conn, bank_id, [memory_id], ops=backend.ops) - await conn.execute( - f"INSERT INTO {arch} ({arch_cols}, invalidation_reason, invalidated_at, entity_ids) " - f"SELECT {arch_cols}, $2, now(), $3::uuid[] FROM {mu} WHERE id = $1 AND bank_id = $4", - str(memory_uuid), - reason, - entity_ids, - bank_id, - ) - # Cascade prunes unit_entities + memory_links; sweep runs after - # the delete so it also catches a racing observation insert. - await conn.execute(f"DELETE FROM {mu} WHERE id = $1 AND bank_id = $2", str(memory_uuid), bank_id) - await self._delete_stale_observations_for_memories(conn, bank_id, [memory_id]) - need_consolidation = True - need_graph = True - elif state == "invalidated" and archived and reason is not None: - # Already archived — just update the recorded reason. - await conn.execute( - f"UPDATE {arch} SET invalidation_reason = $3 WHERE id = $1 AND bank_id = $2", - str(memory_uuid), - bank_id, - reason, + names=[r["canonical_name"] for r in name_rows], + live_row=live, ) - # --- Revert: move archive → live --- + # --- Revert prep (archived rows): gather the archive snapshot the + # re-embed needs. The move copies the archive row verbatim minus the + # embedding, so its text/dates ARE the reverted values. --- elif state == "valid" and archived: arch_row = await conn.fetchrow( - f"SELECT entity_ids FROM {arch} WHERE id = $1 AND bank_id = $2", str(memory_uuid), bank_id - ) - # The archive has no embedding column (see arch_cols above), so the live - # row's embedding defaults to NULL on the way back and is recomputed below - # once entities are restored. - await conn.execute( - f"INSERT INTO {mu} ({arch_cols}) SELECT {arch_cols} FROM {arch} WHERE id = $1 AND bank_id = $2", + f"SELECT text, occurred_start, occurred_end, mentioned_at, entity_ids " + f"FROM {arch} WHERE id = $1 AND bank_id = $2", str(memory_uuid), bank_id, ) - # Re-consolidate from scratch; links are rebuilt by graph maintenance. - await conn.execute( - f"UPDATE {mu} SET consolidated_at = NULL, consolidation_failed_at = NULL, updated_at = now() " - f"WHERE id = $1 AND bank_id = $2", + rev_entity_ids = list(arch_row["entity_ids"]) if (arch_row and arch_row["entity_ids"]) else [] + rev_name_rows = ( + await conn.fetch( + f"SELECT canonical_name FROM {ent} WHERE id = ANY($1::uuid[]) AND bank_id = $2 ORDER BY id", + rev_entity_ids, + bank_id, + ) + if rev_entity_ids + else [] + ) + revert_plan = _MemoryRevertPlan( + arch_row=arch_row, + names=[r["canonical_name"] for r in rev_name_rows], + ) + + # -- Embed OFF any connection -- + if edit_plan is not None: + edit_plan.embedding = await self._reembed_memory_text( + text=edit_plan.new_text, + occurred_start=edit_plan.new_occ_start, + occurred_end=edit_plan.new_occ_end, + mentioned_at=edit_plan.mentioned_at, + entities=edit_plan.names, + ) + if revert_plan is not None and revert_plan.arch_row is not None: + ar = revert_plan.arch_row + revert_plan.embedding = await self._reembed_memory_text( + text=ar["text"], + occurred_start=ar["occurred_start"], + occurred_end=ar["occurred_end"], + mentioned_at=ar["mentioned_at"], + entities=revert_plan.names, + ) + + # -- Phase 2: short write transaction -- all visible mutations atomic -- + async with acquire_with_retry(backend) as conn: + async with conn.transaction(): + # Re-lock the target under the write txn. Moving the embed out + # widened the stale-snapshot window, so we revalidate existence + # here and skip cleanly if the row was concurrently moved/deleted. + live2 = await conn.fetchrow( + f"SELECT text, context, fact_type, event_date, occurred_start, occurred_end, mentioned_at " + f"FROM {mu} WHERE id = $1 AND bank_id = $2 FOR UPDATE", str(memory_uuid), bank_id, ) - # Restore entity associations for entities that still exist (some may - # have been pruned as orphans after the original move). - if arch_row and arch_row["entity_ids"]: + archived2 = None + if not live2: + archived2 = await conn.fetchrow( + f"SELECT text, occurred_start, occurred_end, mentioned_at, entity_ids " + f"FROM {arch} WHERE id = $1 AND bank_id = $2 FOR UPDATE", + str(memory_uuid), + bank_id, + ) + + # Column list for the archive <-> live moves (only invalidate/revert need it). + arch_cols = None + if state in ("invalidated", "valid") and (live2 or archived2): + collist = await self._memory_unit_columns(conn) + arch_cols = ", ".join(c for c in (s.strip() for s in collist.split(",")) if c != '"embedding"') + + # --- Apply edit (live rows only) --- + if edit_plan is not None and live2: + # Abort if a concurrent edit changed any embedding-/resolution-input column while + # the embedder ran off-connection: applying the precomputed embedding + resolved + # entities would clobber the concurrent writer with stale-derived data. Roll back + # the txn for a cheap retry. Background paths never touch these columns + # (consolidation writes only consolidated_at; graph maintenance only entities), + # so only a competing edit trips this. + snap = edit_plan.live_row + if ( + live2["text"], + live2["context"], + live2["fact_type"], + live2["event_date"], + live2["occurred_start"], + live2["occurred_end"], + live2["mentioned_at"], + ) != ( + snap["text"], + snap["context"], + snap["fact_type"], + snap["event_date"], + snap["occurred_start"], + snap["occurred_end"], + snap["mentioned_at"], + ): + raise RuntimeError( + f"update_memory_unit: memory {memory_id} was modified concurrently " + f"between read and write; retry the edit" + ) + edit_embedding = edit_plan.embedding + if edit_plan.resolved_for_unit is not None: + # Entities are being changed: rebuild unit_entities to the resolved set. + # Phase 1 resolved (and autocommitted) these entities OFF any txn, so until + # we link them they are committed orphans a concurrent graph-maintenance + # prune can delete. link_units_to_entities_batch is a plain FK INSERT, so a + # prune between a non-locking check and the insert would FK-fail. Lock the + # resolved rows FOR UPDATE (the prune's DELETE then blocks until we commit, + # by which point the link exists and the row is no longer an orphan) and link + # EXACTLY the locked ids. ORDER BY id keeps a stable lock order. + await conn.execute(f"DELETE FROM {ue} WHERE unit_id = $1", str(memory_uuid)) + link_ids = list(edit_plan.resolved_for_unit) + if link_ids: + locked = await conn.fetch( + f"SELECT id, canonical_name FROM {ent} " + f"WHERE id = ANY($1::uuid[]) AND bank_id = $2 ORDER BY id FOR UPDATE", + link_ids, + bank_id, + ) + if {r["id"] for r in locked} != set(link_ids): + # Some resolved entities were pruned before we locked them. Re-resolve + # under the lock (rows recreated here live in this txn, unprunable), + # re-lock the fresh id set, and re-embed from the locked names. + assert new_entities is not None # set whenever resolved_for_unit is not None + _ri, _e2u, u2e = await resolve_entities_only( + self.entity_resolver, + conn, + bank_id, + [str(memory_uuid)], + [edit_plan.new_text], + edit_plan.new_context or "", + [edit_plan.entity_date], + [[{"text": name, "type": "CONCEPT"} for name in new_entities]], + entity_labels=entity_labels, + ) + link_ids = u2e.get(str(memory_uuid), []) + locked = ( + await conn.fetch( + f"SELECT id, canonical_name FROM {ent} " + f"WHERE id = ANY($1::uuid[]) AND bank_id = $2 ORDER BY id FOR UPDATE", + link_ids, + bank_id, + ) + if link_ids + else [] + ) + if {r["id"] for r in locked} != set(link_ids): + # A found-existing entity was pruned again in the sub-statement + # gap before this second lock. Abort rather than commit a partial + # set: the txn rolls back and the finally-block sweep reclaims + # orphans; the caller can retry. + raise RuntimeError( + f"update_memory_unit: entity set for {memory_id} changed under " + f"concurrent graph maintenance; retry the edit" + ) + edit_embedding = await self._reembed_memory_text( + text=edit_plan.new_text, + occurred_start=edit_plan.new_occ_start, + occurred_end=edit_plan.new_occ_end, + mentioned_at=edit_plan.mentioned_at, + entities=[r["canonical_name"] for r in locked], + ) + locked_ids = [r["id"] for r in locked] + if locked_ids: + await self.entity_resolver.link_units_to_entities_batch( + [(str(memory_uuid), eid, edit_plan.entity_date) for eid in locked_ids], + conn=conn, + ) + else: + # Entities are NOT changing here, so the precomputed embedding used the + # unit's Phase-1 entity names. Re-read them under the lock: a concurrent + # entity-only edit between the phases could have changed the set, which + # would leave the embedding naming stale entities. Only on that (rare) + # mismatch do we re-embed in-txn, keeping the stored embedding consistent + # with the committed unit_entities. + locked_rows = await conn.fetch( + f"SELECT e.canonical_name FROM {ue} ue JOIN {ent} e ON ue.entity_id = e.id " + f"WHERE ue.unit_id = $1 ORDER BY e.id", + str(memory_uuid), + ) + locked_names = [r["canonical_name"] for r in locked_rows] + if locked_names != edit_plan.names: + edit_embedding = await self._reembed_memory_text( + text=edit_plan.new_text, + occurred_start=edit_plan.new_occ_start, + occurred_end=edit_plan.new_occ_end, + mentioned_at=edit_plan.mentioned_at, + entities=locked_names, + ) + await enqueue_relink_victims(conn, bank_id, [memory_id], ops=backend.ops) await conn.execute( - f"INSERT INTO {ue} (unit_id, entity_id) " - f"SELECT $1, eid FROM unnest($2::uuid[]) AS eid " - f"WHERE EXISTS (SELECT 1 FROM {ent} e WHERE e.id = eid AND e.bank_id = $3) " - f"ON CONFLICT DO NOTHING", + f""" + UPDATE {mu} + SET text = $3, context = $4, fact_type = $5, occurred_start = $6, + occurred_end = $7, event_date = $8, embedding = $9::vector, + consolidated_at = NULL, consolidation_failed_at = NULL, + edited_at = now(), updated_at = now() + WHERE id = $1 AND bank_id = $2 + """, str(memory_uuid), - arch_row["entity_ids"], bank_id, + edit_plan.new_text, + edit_plan.new_context, + edit_plan.new_fact, + edit_plan.new_occ_start, + edit_plan.new_occ_end, + edit_plan.new_event_date, + edit_embedding, ) - # Recompute the embedding (the archive doesn't keep one) so the reverted - # unit is searchable again, using the now-current model's dimension and the - # restored entity set — mirroring how an edit re-embeds. - reverted = await conn.fetchrow( - f"SELECT text, occurred_start, occurred_end, mentioned_at FROM {mu} " - f"WHERE id = $1 AND bank_id = $2", - str(memory_uuid), - bank_id, - ) - if reverted: - ent_rows = await conn.fetch( - f"SELECT e.canonical_name FROM {ue} ue JOIN {ent} e ON ue.entity_id = e.id " - f"WHERE ue.unit_id = $1", + await conn.execute( + f"DELETE FROM {ml} WHERE from_unit_id = $1 OR to_unit_id = $1", str(memory_uuid) + ) + await self._delete_stale_observations_for_memories(conn, bank_id, [memory_id]) + need_consolidation = True + need_graph = True + edit_relinked = True + + # --- Invalidate: move live -> archive --- + if state == "invalidated" and live2: + entity_ids = [ + r["entity_id"] + for r in await conn.fetch( + f"SELECT entity_id FROM {ue} WHERE unit_id = $1", str(memory_uuid) + ) + ] + # Capture relink victims BEFORE the row (and its links) disappear. + await enqueue_relink_victims(conn, bank_id, [memory_id], ops=backend.ops) + await conn.execute( + f"INSERT INTO {arch} ({arch_cols}, invalidation_reason, invalidated_at, entity_ids) " + f"SELECT {arch_cols}, $2, now(), $3::uuid[] FROM {mu} WHERE id = $1 AND bank_id = $4", str(memory_uuid), + reason, + entity_ids, + bank_id, ) - new_emb = await self._reembed_memory_text( - text=reverted["text"], - occurred_start=reverted["occurred_start"], - occurred_end=reverted["occurred_end"], - mentioned_at=reverted["mentioned_at"], - entities=[r["canonical_name"] for r in ent_rows], + # Cascade prunes unit_entities + memory_links; sweep runs after + # the delete so it also catches a racing observation insert. + await conn.execute( + f"DELETE FROM {mu} WHERE id = $1 AND bank_id = $2", str(memory_uuid), bank_id ) - if new_emb is not None: + await self._delete_stale_observations_for_memories(conn, bank_id, [memory_id]) + need_consolidation = True + need_graph = True + elif state == "invalidated" and archived2 and reason is not None: + # Already archived — just update the recorded reason. + await conn.execute( + f"UPDATE {arch} SET invalidation_reason = $3 WHERE id = $1 AND bank_id = $2", + str(memory_uuid), + bank_id, + reason, + ) + + # --- Revert: move archive -> live --- + elif state == "valid" and archived2 and revert_plan is not None: + # The archive has no embedding column (see arch_cols above), so the + # live row's embedding defaults to NULL on the way back and is set + # below from the embedding computed off-connection in Phase 1. + await conn.execute( + f"INSERT INTO {mu} ({arch_cols}) SELECT {arch_cols} FROM {arch} WHERE id = $1 AND bank_id = $2", + str(memory_uuid), + bank_id, + ) + # Re-consolidate from scratch; links are rebuilt by graph maintenance. + await conn.execute( + f"UPDATE {mu} SET consolidated_at = NULL, consolidation_failed_at = NULL, updated_at = now() " + f"WHERE id = $1 AND bank_id = $2", + str(memory_uuid), + bank_id, + ) + # Restore entity associations for entities that still exist, from the LOCKED + # archive row (some may have been pruned as orphans after the original move). + locked_entity_ids = list(archived2["entity_ids"]) if archived2["entity_ids"] else [] + if locked_entity_ids: + await conn.execute( + f"INSERT INTO {ue} (unit_id, entity_id) " + f"SELECT $1, eid FROM unnest($2::uuid[]) AS eid " + f"WHERE EXISTS (SELECT 1 FROM {ent} e WHERE e.id = eid AND e.bank_id = $3) " + f"ON CONFLICT DO NOTHING", + str(memory_uuid), + locked_entity_ids, + bank_id, + ) + # The Phase-1 embedding was computed from the Phase-1 archive snapshot. If the + # archive row was rewritten (text/dates) or its surviving entity set differs while + # the embedder ran off-connection, recompute under the lock so the stored vector + # matches the restored (locked) row. + revert_embedding = revert_plan.embedding + linked = await conn.fetch( + f"SELECT e.canonical_name FROM {ue} ue JOIN {ent} e ON ue.entity_id = e.id " + f"WHERE ue.unit_id = $1 ORDER BY e.id", + str(memory_uuid), + ) + linked_names = [r["canonical_name"] for r in linked] + snap = revert_plan.arch_row + if ( + snap is None + or ( + archived2["text"], + archived2["occurred_start"], + archived2["occurred_end"], + archived2["mentioned_at"], + ) + != (snap["text"], snap["occurred_start"], snap["occurred_end"], snap["mentioned_at"]) + or linked_names != revert_plan.names + ): + revert_embedding = await self._reembed_memory_text( + text=archived2["text"], + occurred_start=archived2["occurred_start"], + occurred_end=archived2["occurred_end"], + mentioned_at=archived2["mentioned_at"], + entities=linked_names, + ) + if revert_embedding is not None: await conn.execute( f"UPDATE {mu} SET embedding = $3::vector WHERE id = $1 AND bank_id = $2", str(memory_uuid), bank_id, - new_emb, + revert_embedding, ) - await conn.execute(f"DELETE FROM {arch} WHERE id = $1 AND bank_id = $2", str(memory_uuid), bank_id) - need_consolidation = True - need_graph = True - - if not found: - return None - - if need_consolidation: - config = await self._config_resolver.resolve_full_config(bank_id, request_context) - if config.enable_auto_consolidation: + await conn.execute( + f"DELETE FROM {arch} WHERE id = $1 AND bank_id = $2", str(memory_uuid), bank_id + ) + need_consolidation = True + need_graph = True + phase2_committed = True # reached only if the Phase-2 txn committed + finally: + # An entity-changing edit resolved (and may have autocommitted) entities in Phase 1, but the + # edit did not durably apply: resolve/embed/Phase 2 raised, the row was concurrently + # invalidated (live2 None), or a combined edit+invalidate rolled back. Those entities are now + # orphaned; submit_async_graph_maintenance short-circuits on an empty queue and the edit's own + # relink-victim enqueue never ran, so enqueue this unit explicitly to force the bank-wide + # orphan-entity prune to run. + force_graph_submit = False + if entities_maybe_committed and not (edit_relinked and phase2_committed): try: - await self.submit_async_consolidation(bank_id=bank_id, request_context=request_context) + async with acquire_with_retry(backend) as cleanup_conn: + await backend.ops.enqueue_graph_maintenance( + cleanup_conn, fq_table("graph_maintenance_queue"), bank_id, [memory_uuid] + ) + force_graph_submit = True except Exception as e: - logger.warning(f"Failed to submit consolidation after curating memory in bank {bank_id}: {e}") - if need_graph: - try: - await self.submit_async_graph_maintenance(bank_id=bank_id, request_context=request_context) - except Exception as e: - logger.warning(f"Failed to submit graph maintenance after curating memory in bank {bank_id}: {e}") - + logger.warning( + f"Failed to enqueue orphan-entity cleanup after a failed edit in bank {bank_id}: {e}" + ) + # Normal post-commit follow-ups run only when the Phase-2 txn committed. + if phase2_committed and need_consolidation: + config = await self._config_resolver.resolve_full_config(bank_id, request_context) + if config.enable_auto_consolidation: + try: + await self.submit_async_consolidation(bank_id=bank_id, request_context=request_context) + except Exception as e: + logger.warning(f"Failed to submit consolidation after curating memory in bank {bank_id}: {e}") + if (phase2_committed and need_graph) or force_graph_submit: + try: + await self.submit_async_graph_maintenance(bank_id=bank_id, request_context=request_context) + except Exception as e: + logger.warning(f"Failed to submit graph maintenance after curating memory in bank {bank_id}: {e}") return await self.get_memory_unit(bank_id=bank_id, memory_id=memory_id, request_context=request_context) async def run_consolidation( @@ -10674,6 +10948,16 @@ async def update_mental_model( await self._validate_operation(self._operation_validator.validate_bank_write(ctx)) backend = await self._get_backend() + # Compute the new embedding BEFORE acquiring a pooled connection: a slow + # embedder must never pin a DB connection. The embedding text depends only + # on the incoming name/content, never on DB state, so it can be done here. + new_embedding_str: str | None = None + if content is not None: + embedding_text = f"{name or ''} {content}" + embedding = await embedding_utils.generate_embeddings_batch(self.embeddings, [embedding_text]) + if embedding: + new_embedding_str = str(embedding[0]) + async with acquire_with_retry(backend) as conn: # If content is changing, fetch current content + reflect_response to record history previous_content: str | None = None @@ -10724,12 +11008,10 @@ async def update_mental_model( if based_on is not None: slim_reflect_response = {"based_on": based_on} record_mm_history = True - # Also update embedding (convert to string for asyncpg vector type) - embedding_text = f"{name or ''} {content}" - embedding = await embedding_utils.generate_embeddings_batch(self.embeddings, [embedding_text]) - if embedding: + # Apply the embedding computed above (off-connection). + if new_embedding_str is not None: updates.append(f"embedding = ${param_idx}") - params.append(str(embedding[0])) + params.append(new_embedding_str) param_idx += 1 if reflect_response is not None: diff --git a/hindsight-api-slim/hindsight_api/engine/retain/fact_storage.py b/hindsight-api-slim/hindsight_api/engine/retain/fact_storage.py index d9c93a0f5..070850ccd 100644 --- a/hindsight-api-slim/hindsight_api/engine/retain/fact_storage.py +++ b/hindsight-api-slim/hindsight_api/engine/retain/fact_storage.py @@ -163,34 +163,21 @@ async def ensure_bank_exists(conn, bank_id: str, ops=None) -> None: await create_bank_vector_indexes(conn, bank_id, str(internal_id), ops=ops) -async def delete_stale_observations_for_memories( +async def _snapshot_stale_observations( conn, bank_id: str, fact_ids: "list[str | uuid.UUID]", ops=None, -) -> int: - """Delete observations whose source memories are about to be removed. - - Mirrors the cleanup performed by ``MemoryEngine.delete_document`` so that - every code path that removes ``memory_units`` also removes the - observations derived from them. Without this, ingesting a fresh version - of a document via the retain pipeline (which does a full-replace - ``DELETE FROM documents`` cascade) used to leave orphan observations - pointing at memory IDs that no longer existed. - - For each observation referencing any of ``fact_ids``: - 1. Delete the observation row (its text is stale once even one source - memory disappears). - 2. Reset ``consolidated_at = NULL`` on the surviving source memories so - they get re-consolidated under fresh observations on the next run. - - Must be called within an active transaction, before the source memories - are deleted. +) -> "tuple[list, list]": + """Find observations derived from ``fact_ids`` and the surviving co-sources to reset. - Returns the number of observations deleted. + Returns ``(obs_ids, remaining_source_ids)``. Read-only: performs no writes, so it + reads the affected observation rows by id array / junction (which survive source + deletion) and may run either before or after the source memories are removed. + Returns ``([], [])`` when ``fact_ids`` is empty or no observation references them. """ if not fact_ids: - return 0 + return [], [] fact_uuids = [uuid.UUID(str(fid)) if not isinstance(fid, uuid.UUID) else fid for fid in fact_ids] @@ -226,7 +213,7 @@ async def delete_stale_observations_for_memories( ) if not affected_obs: - return 0 + return [], [] deleted_set = {str(uid) for uid in fact_uuids} obs_ids = [obs["id"] for obs in affected_obs] @@ -239,6 +226,22 @@ async def delete_stale_observations_for_memories( remaining_source_ids.append(src_id) seen_remaining.add(src_str) + return obs_ids, remaining_source_ids + + +async def _apply_stale_observation_deletion( + conn, + bank_id: str, + obs_ids: list, + remaining_source_ids: list, +) -> int: + """Delete the snapshotted observation rows and reset surviving co-sources. + + Returns the number of observations deleted (0 when ``obs_ids`` is empty). + """ + if not obs_ids: + return 0 + await conn.execute( f"DELETE FROM {fq_table('memory_units')} WHERE id = ANY($1::uuid[])", obs_ids, @@ -262,6 +265,40 @@ async def delete_stale_observations_for_memories( return len(obs_ids) +async def delete_stale_observations_for_memories( + conn, + bank_id: str, + fact_ids: "list[str | uuid.UUID]", + ops=None, +) -> int: + """Delete observations whose source memories are about to be removed. + + Mirrors the cleanup performed by ``MemoryEngine.delete_document`` so that + every code path that removes ``memory_units`` also removes the + observations derived from them. Without this, ingesting a fresh version + of a document via the retain pipeline (which does a full-replace + ``DELETE FROM documents`` cascade) used to leave orphan observations + pointing at memory IDs that no longer existed. + + For each observation referencing any of ``fact_ids``: + 1. Delete the observation row (its text is stale once even one source + memory disappears). + 2. Reset ``consolidated_at = NULL`` on the surviving source memories so + they get re-consolidated under fresh observations on the next run. + + Must be called within an active transaction. Order-independent w.r.t. the source + delete: it snapshots the affected observations (by id array / junction, which + survive source deletion) before writing, so it is safe to invoke either before or + after the source memories are removed. + + Returns the number of observations deleted. + """ + obs_ids, remaining_source_ids = await _snapshot_stale_observations(conn, bank_id, fact_ids, ops=ops) + if not obs_ids: + return 0 + return await _apply_stale_observation_deletion(conn, bank_id, obs_ids, remaining_source_ids) + + async def handle_document_tracking( conn, bank_id: str, @@ -300,11 +337,13 @@ async def handle_document_tracking( # Delete old document first (cascades to units and links). # Only delete on the first batch to avoid deleting data we just inserted. - # Before the cascade, fan out to delete observations derived from the - # outgoing memory_units — otherwise the FK ON DELETE CASCADE removes the - # source memory_units but leaves observation rows pointing at IDs that - # no longer exist (consolidated_at on co-source memories also stays - # frozen). Same cleanup the explicit ``delete_document`` API performs. + # Snapshot the observations derived from the outgoing memory_units BEFORE the + # delete, then remove those observations (and reset co-source consolidated_at) + # AFTER the sources are gone. Otherwise the FK ON DELETE CASCADE removes the + # source memory_units but leaves observation rows pointing at IDs that no longer + # exist (and co-source consolidated_at stays frozen). Same cleanup the explicit + # ``delete_document`` API performs; sources are deleted first to keep the + # SOURCE -> OBSERVATION lock order. preserved_created_at = None if is_first_batch: existing_unit_rows = await conn.fetch( @@ -315,32 +354,41 @@ async def handle_document_tracking( document_id, ) existing_unit_ids = [row["id"] for row in existing_unit_rows] + obs_ids: list = [] + remaining_source_ids: list = [] if existing_unit_ids: - invalidated = await delete_stale_observations_for_memories(conn, bank_id, existing_unit_ids, ops=ops) - if invalidated: - logger.info( - f"[RETAIN] Document {document_id} re-ingested: invalidated " - f"{invalidated} observation(s) derived from {len(existing_unit_ids)} outgoing memory_units" - ) + # Snapshot affected observations BEFORE deleting sources (the PG array / + # Oracle junction are read here while still intact); apply the deletion + # AFTER the source delete so the lock order is SOURCE -> OBSERVATION, + # matching consolidation/invalidation and avoiding a cross-order deadlock. + obs_ids, remaining_source_ids = await _snapshot_stale_observations( + conn, bank_id, existing_unit_ids, ops=ops + ) # Capture link-recompute victims BEFORE the cascade. Same staleness # applies on upsert as on explicit delete: surviving units in OTHER - # documents that linked to these doomed units are about to lose - # those links. ``ops`` may be None for older callers that haven't - # been wired up — skip enqueue in that case rather than crash. + # documents that linked to these doomed units are about to lose those + # links. ``ops`` may be None for older callers — skip enqueue in that case. if ops is not None: from ..graph_maintenance import enqueue_relink_victims await enqueue_relink_victims(conn, bank_id, [str(uid) for uid in existing_unit_ids], ops=ops) - # Explicitly delete memory_units by document_id BEFORE deleting the - # document row. The CASCADE from documents→chunks→memory_units only - # catches units that have a non-NULL chunk_id FK. Units with chunk_id=NULL - # (e.g. from partial writes or edge cases) would survive the cascade. - # This explicit delete ensures complete cleanup. + # Delete source memory_units FIRST. The CASCADE from documents->chunks-> + # memory_units only catches units with a non-NULL chunk_id FK; units with + # chunk_id=NULL would survive. Deleting sources before the derived + # observations also fixes the lock order (SOURCE -> OBSERVATION). await conn.execute( f"DELETE FROM {fq_table('memory_units')} WHERE document_id = $1 AND bank_id = $2", document_id, bank_id, ) + # Then delete the affected observations + reset surviving co-sources. + if obs_ids: + invalidated = await _apply_stale_observation_deletion(conn, bank_id, obs_ids, remaining_source_ids) + if invalidated: + logger.info( + f"[RETAIN] Document {document_id} re-ingested: invalidated " + f"{invalidated} observation(s) derived from {len(existing_unit_ids)} outgoing memory_units" + ) # Capture created_at before deletion so re-ingestion preserves it. preserved_created_at = await conn.fetchval( f"DELETE FROM {fq_table('documents')} WHERE id = $1 AND bank_id = $2 RETURNING created_at", diff --git a/hindsight-api-slim/hindsight_api/engine/retain/link_utils.py b/hindsight-api-slim/hindsight_api/engine/retain/link_utils.py index ee88410dd..2c8dbb26e 100644 --- a/hindsight-api-slim/hindsight_api/engine/retain/link_utils.py +++ b/hindsight-api-slim/hindsight_api/engine/retain/link_utils.py @@ -4,6 +4,7 @@ import logging import time +import uuid from datetime import UTC, datetime, timedelta from ..._vector_index import ann_search_tuning_settings, configured_vector_extension @@ -300,7 +301,7 @@ async def resolve_entities_only( llm_entities: list[list[dict]], log_buffer: list[str] = None, entity_labels: list | None = None, -) -> tuple[list[str], list[tuple], dict[str, list[str]]]: +) -> tuple[list[uuid.UUID], list[tuple], dict[str, list[uuid.UUID]]]: """ Phase 1 of entity processing: resolve entity names to canonical IDs. @@ -350,7 +351,7 @@ async def resolve_entities_only( ) # Build unit_to_entity_ids mapping - unit_to_entity_ids: dict[str, list[str]] = {} + unit_to_entity_ids: dict[str, list[uuid.UUID]] = {} for idx, (unit_id, _local_idx, _fact_date) in enumerate(entity_to_unit): if unit_id not in unit_to_entity_ids: unit_to_entity_ids[unit_id] = [] diff --git a/hindsight-api-slim/tests/test_consolidation_dedup.py b/hindsight-api-slim/tests/test_consolidation_dedup.py index ac42678c8..d92e66e38 100644 --- a/hindsight-api-slim/tests/test_consolidation_dedup.py +++ b/hindsight-api-slim/tests/test_consolidation_dedup.py @@ -7,8 +7,9 @@ import types import uuid +from contextlib import asynccontextmanager from dataclasses import dataclass -from unittest.mock import AsyncMock, patch +from unittest.mock import DEFAULT, AsyncMock, patch from hindsight_api.engine.consolidation.consolidator import ( _dedup_active, @@ -76,12 +77,105 @@ def _obs(text: str, sim: float, oid: str = _TWIN_ID) -> RetrievalResult: return RetrievalResult(id=oid, text=text, fact_type="observation", similarity=sim) +class _DedupConn: + """Backend-shaped conn for dedup-fold tests. Enforces that the live-source filter and + the fold UPDATE run inside the fold transaction on an acquired connection, and that the + fold UPDATE is RETURNING-gated.""" + + def __init__(self): + self.active = 0 # >0 while a connection is acquired (set by _DedupBackend.acquire) + self._in_txn = False + self.fetchval_result = uuid.UUID(_TWIN_ID) # survivor id the fold "returns" + self.fetchrow_result = None # update-path source snapshot + self.live_rows = None # override liveness rows; None -> echo all source ids as live + # Modeled row text so the fold/snapshot text guards actually bite. None -> "match any" + # (keeps every pre-existing test, which never sets these, behaving as before). + self.current_twin_text = None # survivor/twin row's current text (create + update folds) + self.current_updated_text = None # updated row's current text (update snapshot + fold) + self.fetchval = AsyncMock(side_effect=self._fetchval) + self.fetch = AsyncMock(side_effect=self._fetch) + self.fetchrow = AsyncMock(side_effect=self._fetchrow) + self.execute = AsyncMock() + + @asynccontextmanager + async def transaction(self): + assert self.active > 0, "fold transaction opened without an acquired connection" + self._in_txn = True + try: + yield + finally: + self._in_txn = False + + async def _fetchval(self, query, *args): + assert self._in_txn, "fold UPDATE must run inside the fold transaction" + assert "RETURNING" in query, "fold UPDATE must be RETURNING-gated" + # Assert the text-guard CLAUSE is present (not just that the arg is passed) so deleting the SQL + # guard fails even if the param is left behind, then model the guarded row text so a stale-text + # fold matches no row. ``args`` excludes the bound ``query``. + if "u.text" in query: # update-path fold + assert "t.text = $4" in query and "u.text = $5" in query, "update fold must keep both text guards" + if self.current_twin_text is not None and args[3] != self.current_twin_text: + return None + if self.current_updated_text is not None and args[4] != self.current_updated_text: + return None + else: # create-path fold + assert "AND text = $4" in query, "create fold must keep the twin text guard (AND text = $4)" + if self.current_twin_text is not None and args[3] != self.current_twin_text: + return None + return self.fetchval_result + + async def _fetch(self, query, source_ids, bank_id): + assert self._in_txn, "live-source filter must run inside the fold transaction" + assert "FOR SHARE" in query, "live-source filter must hold FOR SHARE on the source rows" + if self.live_rows is not None: + return self.live_rows + return [{"id": s} for s in source_ids] + + async def _fetchrow(self, query, *args): + # Assert the text-guard CLAUSE is present (so deleting it fails even if the arg stays), then + # model the updated row's text so a row rewritten during the LLM window snapshots as gone. + # ``args`` excludes the bound ``query``. + assert "AND text = $2" in query, "update snapshot must keep the updated-text guard (AND text = $2)" + if self.current_updated_text is not None and args[1] != self.current_updated_text: + return None + return self.fetchrow_result + + +class _DedupBackend: + """Backend-shaped stand-in matching acquire_with_retry's ``_wraps_backend`` path.""" + + _wraps_backend = True + + def __init__(self, conn): + self._conn = conn + + @asynccontextmanager + async def acquire(self): + self._conn.active += 1 + try: + yield self._conn + finally: + self._conn.active -= 1 + + +def _make_dedup_llm(conn): + """An LLM stub that asserts no pooled connection is held when it is called.""" + llm = types.SimpleNamespace(call=AsyncMock()) + + def _assert_released(*a, **k): + assert conn.active == 0, "no pooled connection may be held during the dedup LLM call" + return DEFAULT # fall through to the llm.call.return_value the test sets + + llm.call.side_effect = _assert_released + return llm + + def _ctx(threshold: float = 0.97): """Return (kwargs, conn_mock, llm_mock) for a _dedup_reconcile_create call.""" - conn = AsyncMock() - llm = types.SimpleNamespace(call=AsyncMock()) + conn = _DedupConn() + llm = _make_dedup_llm(conn) kwargs = dict( - conn=conn, + pool=_DedupBackend(conn), memory_engine=types.SimpleNamespace(embeddings=object()), bank_id="bank1", config=types.SimpleNamespace(consolidation_dedup_threshold=threshold), @@ -113,28 +207,34 @@ async def test_dedup_no_twin_above_threshold_returns_none() -> None: result = await _dedup_reconcile_create(**kwargs) assert result is None llm.call.assert_not_called() # below threshold → no LLM call - conn.execute.assert_not_called() # no merge + conn.fetchval.assert_not_called() # no merge async def test_dedup_llm_keep_does_not_merge() -> None: kwargs, conn, llm = _ctx() llm.call.return_value = _DedupDecision(action="keep", reason="different language") - with _patch_embed(), _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]): + with ( + _patch_embed(), + _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]), + ): result = await _dedup_reconcile_create(**kwargs) assert result is None llm.call.assert_awaited_once() - conn.execute.assert_not_called() # kept distinct → no merge + conn.fetchval.assert_not_called() # kept distinct → no merge async def test_dedup_llm_merge_folds_into_twin() -> None: kwargs, conn, llm = _ctx() kwargs["create_source_ids"] = [uuid.uuid4(), uuid.uuid4()] llm.call.return_value = _DedupDecision(action="merge", text="Uzbek content on YouTube is very rich.") - with _patch_embed(), _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.99)]): + with ( + _patch_embed(), + _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.99)]), + ): result = await _dedup_reconcile_create(**kwargs) assert result == _TWIN_ID # merged into the twin; caller skips the CREATE - conn.execute.assert_awaited_once() - args = conn.execute.await_args.args + conn.fetchval.assert_awaited_once() + args = conn.fetchval.await_args.args assert args[1] == "Uzbek content on YouTube is very rich." # merged text persisted assert args[2] == kwargs["create_source_ids"] # new source facts folded in assert args[3] == uuid.UUID(_TWIN_ID) # onto the twin row @@ -144,7 +244,10 @@ async def test_dedup_picks_highest_above_threshold_skips_below() -> None: # Only the >=threshold candidate is considered; a 0.95 result is ignored at threshold 0.97. kwargs, conn, llm = _ctx(threshold=0.97) llm.call.return_value = _DedupDecision(action="keep") - with _patch_embed(), _patch_probe([_obs("near but distinct", 0.95), _obs("the real twin", 0.98)]): + with ( + _patch_embed(), + _patch_probe([_obs("near but distinct", 0.95), _obs("the real twin", 0.98)]), + ): await _dedup_reconcile_create(**kwargs) # the twin passed to the LLM is the >=0.97 one, not the 0.95 sent = llm.call.await_args.kwargs["messages"][0]["content"] @@ -163,10 +266,11 @@ async def test_dedup_picks_highest_above_threshold_skips_below() -> None: def _update_ctx(threshold: float = 0.97): """Return (kwargs, conn_mock, llm_mock) for a _dedup_reconcile_update call.""" - conn = AsyncMock() - llm = types.SimpleNamespace(call=AsyncMock()) + conn = _DedupConn() + conn.fetchrow_result = {"source_memory_ids": [uuid.uuid4(), uuid.uuid4()]} + llm = _make_dedup_llm(conn) kwargs = dict( - conn=conn, + pool=_DedupBackend(conn), memory_engine=types.SimpleNamespace(embeddings=object()), bank_id="bank1", config=types.SimpleNamespace(consolidation_dedup_threshold=threshold), @@ -185,13 +289,15 @@ async def test_dedup_update_merge_folds_into_twin_and_deletes_updated() -> None: with _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]): await _dedup_reconcile_update(**kwargs) llm.call.assert_awaited_once() - # Two writes: fold-into-twin UPDATE, then DELETE of the updated row. - assert conn.execute.await_count == 2 - fold_args = conn.execute.await_args_list[0].args + # The fold UPDATE uses fetchval (RETURNING t.id); then the updated row is DELETEd. + conn.fetchval.assert_awaited_once() + fold_args = conn.fetchval.await_args.args assert fold_args[1] == "Uzbek YouTube content is very rich and growing." # merged text on the twin assert fold_args[2] == uuid.UUID(_TWIN_ID) # survivor = the twin assert fold_args[3] == uuid.UUID(_UPDATED_ID) # folded-from = the updated row - delete_args = conn.execute.await_args_list[1].args + assert fold_args[6] == conn.fetchrow_result["source_memory_ids"] + conn.execute.assert_awaited_once() + delete_args = conn.execute.await_args.args assert delete_args[1] == uuid.UUID(_UPDATED_ID) # the updated row is deleted @@ -201,7 +307,8 @@ async def test_dedup_update_keep_does_not_merge() -> None: with _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]): await _dedup_reconcile_update(**kwargs) llm.call.assert_awaited_once() - conn.execute.assert_not_called() # kept distinct → neither fold nor delete + conn.fetchval.assert_not_called() # kept distinct → no fold + conn.execute.assert_not_called() # → no delete async def test_dedup_update_excludes_self() -> None: @@ -211,6 +318,7 @@ async def test_dedup_update_excludes_self() -> None: with _patch_probe([_obs("its own current text", 1.0, oid=_UPDATED_ID)]): await _dedup_reconcile_update(**kwargs) llm.call.assert_not_called() + conn.fetchval.assert_not_called() conn.execute.assert_not_called() @@ -219,6 +327,89 @@ async def test_dedup_update_no_twin_above_threshold() -> None: with _patch_probe([_obs("loosely related", 0.8)]): await _dedup_reconcile_update(**kwargs) llm.call.assert_not_called() + conn.fetchval.assert_not_called() + conn.execute.assert_not_called() + + +async def test_dedup_create_twin_vanished_returns_none_so_caller_creates() -> None: + # If the twin is deleted during the (connection-free) LLM window, the fold UPDATE matches + # no row (fetchval -> None); the helper must return None so the caller still CREATEs. + kwargs, conn, llm = _ctx() + conn.fetchval_result = None + llm.call.return_value = _DedupDecision(action="merge", text="merged text") + with ( + _patch_embed(), + _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.99)]), + ): + result = await _dedup_reconcile_create(**kwargs) + assert result is None # twin gone → don't drop the CREATE + conn.fetchval.assert_awaited_once() + + +async def test_dedup_create_fold_uses_only_live_new_sources() -> None: + kwargs, conn, llm = _ctx() + live_source_id = uuid.uuid4() + deleted_source_id = uuid.uuid4() + kwargs["create_source_ids"] = [deleted_source_id, live_source_id] + conn.live_rows = [{"id": live_source_id}] + llm.call.return_value = _DedupDecision(action="merge", text="merged text") + with ( + _patch_embed(), + _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.99)]), + ): + result = await _dedup_reconcile_create(**kwargs) + assert result == _TWIN_ID + conn.fetchval.assert_awaited_once() + assert conn.fetchval.await_args.args[2] == [live_source_id] + + +async def test_dedup_create_all_new_sources_deleted_returns_none() -> None: + kwargs, conn, llm = _ctx() + kwargs["create_source_ids"] = [uuid.uuid4(), uuid.uuid4()] + conn.live_rows = [] + llm.call.return_value = _DedupDecision(action="merge", text="merged text") + with ( + _patch_embed(), + _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.99)]), + ): + result = await _dedup_reconcile_create(**kwargs) + assert result is None + conn.fetchval.assert_not_called() + + +async def test_dedup_update_twin_vanished_does_not_delete_updated() -> None: + # If the fold matches no row (twin vanished mid-window), the updated row must NOT be deleted. + kwargs, conn, llm = _update_ctx() + conn.fetchval_result = None + llm.call.return_value = _DedupDecision(action="merge", text="merged text") + with _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]): + await _dedup_reconcile_update(**kwargs) + conn.fetchval.assert_awaited_once() # fold attempted + conn.execute.assert_not_called() # but no delete, since the fold touched nothing + + +async def test_dedup_update_fold_uses_only_live_updated_sources() -> None: + kwargs, conn, llm = _update_ctx() + live_source_id = uuid.uuid4() + deleted_source_id = uuid.uuid4() + conn.fetchrow_result = {"source_memory_ids": [deleted_source_id, live_source_id]} + conn.live_rows = [{"id": live_source_id}] + llm.call.return_value = _DedupDecision(action="merge", text="merged text") + with _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]): + await _dedup_reconcile_update(**kwargs) + conn.fetchval.assert_awaited_once() + assert conn.fetchval.await_args.args[6] == [live_source_id] + conn.execute.assert_awaited_once() + + +async def test_dedup_update_all_updated_sources_deleted_skips_fold_and_delete() -> None: + kwargs, conn, llm = _update_ctx() + conn.fetchrow_result = {"source_memory_ids": [uuid.uuid4(), uuid.uuid4()]} + conn.live_rows = [] + llm.call.return_value = _DedupDecision(action="merge", text="merged text") + with _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]): + await _dedup_reconcile_update(**kwargs) + conn.fetchval.assert_not_called() conn.execute.assert_not_called() @@ -257,3 +448,185 @@ def test_dedup_active_skipped_on_oracle() -> None: def test_dedup_active_none_config() -> None: assert _dedup_active(None) is False + + +async def test_process_batch_creates_when_dedup_target_vanished() -> None: + # Caller contract: when _dedup_reconcile_create returns None (twin vanished mid-window), + # _process_memory_batch must still CREATE the observation instead of dropping it. + from hindsight_api.engine.consolidation import consolidator as C + + mem_id = str(uuid.uuid4()) + memories = [{"id": mem_id, "text": "Uzbek YouTube content is very rich.", "tags": []}] + create = C._CreateAction(text="Uzbek YouTube content is very rich.", source_fact_ids=[mem_id]) + llm_result = C._BatchLLMResult(creates=[create]) + + memory_engine = types.SimpleNamespace( + _consolidation_llm_config=types.SimpleNamespace(with_config=lambda *a, **k: object()) + ) + + with ( + patch.object( + C, + "_find_related_observations", + new=AsyncMock(return_value=types.SimpleNamespace(results=[], source_facts={})), + ), + patch.object(C, "_consolidate_batch_with_llm", new=AsyncMock(return_value=llm_result)), + patch.object(C, "_effective_scope_limit", return_value=-1), + patch.object(C, "_dedup_active", return_value=True), + patch.object(C, "_dedup_reconcile_create", new=AsyncMock(return_value=None)), + patch.object(C, "_execute_create_action", new=AsyncMock(return_value="created")) as create_action, + ): + result = await C._process_memory_batch( + pool=object(), + memory_engine=memory_engine, + llm_config=object(), + bank_id="bank1", + memories=memories, + request_context=object(), + config=object(), + ) + + create_action.assert_awaited_once() + assert create_action.await_args.kwargs["text"] == "Uzbek YouTube content is very rich." + assert create_action.await_args.kwargs["source_memory_ids"] == [mem_id] + assert result == ([{"action": "created"}], 0, False) + + +async def test_process_batch_reports_skipped_when_create_skipped() -> None: + # _execute_create_action returns "skipped" (all sources deleted in the write txn) -> + # _process_memory_batch must NOT mark the memory created; it falls through to skipped. + from hindsight_api.engine.consolidation import consolidator as C + + mem_id = str(uuid.uuid4()) + memories = [{"id": mem_id, "text": "Uzbek YouTube content is very rich.", "tags": []}] + create = C._CreateAction(text="Uzbek YouTube content is very rich.", source_fact_ids=[mem_id]) + llm_result = C._BatchLLMResult(creates=[create]) + memory_engine = types.SimpleNamespace( + _consolidation_llm_config=types.SimpleNamespace(with_config=lambda *a, **k: object()) + ) + with ( + patch.object( + C, + "_find_related_observations", + new=AsyncMock(return_value=types.SimpleNamespace(results=[], source_facts={})), + ), + patch.object(C, "_consolidate_batch_with_llm", new=AsyncMock(return_value=llm_result)), + patch.object(C, "_effective_scope_limit", return_value=-1), + patch.object(C, "_dedup_active", return_value=True), + patch.object(C, "_dedup_reconcile_create", new=AsyncMock(return_value=None)), + patch.object(C, "_execute_create_action", new=AsyncMock(return_value="skipped")), + ): + result = await C._process_memory_batch( + pool=object(), + memory_engine=memory_engine, + llm_config=object(), + bank_id="bank1", + memories=memories, + request_context=object(), + config=object(), + ) + assert result == ([{"action": "skipped", "reason": "no_durable_knowledge"}], 0, False) + + +async def test_process_batch_reports_created_when_create_created() -> None: + # _execute_create_action returns "created" -> the memory is marked created. + from hindsight_api.engine.consolidation import consolidator as C + + mem_id = str(uuid.uuid4()) + memories = [{"id": mem_id, "text": "Uzbek YouTube content is very rich.", "tags": []}] + create = C._CreateAction(text="Uzbek YouTube content is very rich.", source_fact_ids=[mem_id]) + llm_result = C._BatchLLMResult(creates=[create]) + memory_engine = types.SimpleNamespace( + _consolidation_llm_config=types.SimpleNamespace(with_config=lambda *a, **k: object()) + ) + with ( + patch.object( + C, + "_find_related_observations", + new=AsyncMock(return_value=types.SimpleNamespace(results=[], source_facts={})), + ), + patch.object(C, "_consolidate_batch_with_llm", new=AsyncMock(return_value=llm_result)), + patch.object(C, "_effective_scope_limit", return_value=-1), + patch.object(C, "_dedup_active", return_value=True), + patch.object(C, "_dedup_reconcile_create", new=AsyncMock(return_value=None)), + patch.object(C, "_execute_create_action", new=AsyncMock(return_value="created")), + ): + result = await C._process_memory_batch( + pool=object(), + memory_engine=memory_engine, + llm_config=object(), + bank_id="bank1", + memories=memories, + request_context=object(), + config=object(), + ) + assert result == ([{"action": "created"}], 0, False) + + +async def test_dedup_create_fold_guards_on_twin_probe_text() -> None: + # The CREATE fold guards on the twin's probe-time text (param $4) so a concurrent rewrite + # of the survivor during the connection-free LLM window can't be clobbered by stale text. + kwargs, conn, llm = _ctx() + kwargs["create_source_ids"] = [uuid.uuid4(), uuid.uuid4()] + llm.call.return_value = _DedupDecision(action="merge", text="Uzbek content on YouTube is very rich.") + with ( + _patch_embed(), + _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.99)]), + ): + result = await _dedup_reconcile_create(**kwargs) + assert result == _TWIN_ID + conn.fetchval.assert_awaited_once() + args = conn.fetchval.await_args.args + assert args[4] == "Uzbek content on YouTube is described as very rich." # twin probe-text guard + + +async def test_dedup_update_fold_guards_on_both_texts() -> None: + # The UPDATE fold guards BOTH rows whose text fed the merge: the survivor twin ($4) and the + # just-updated row ($5), so a concurrent rewrite of either aborts the fold-and-delete. + kwargs, conn, llm = _update_ctx() + llm.call.return_value = _DedupDecision(action="merge", text="Uzbek YouTube content is very rich and growing.") + with _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]): + await _dedup_reconcile_update(**kwargs) + conn.fetchval.assert_awaited_once() + fold_args = conn.fetchval.await_args.args + assert fold_args[4] == "Uzbek content on YouTube is described as very rich." # survivor probe-text guard + assert fold_args[5] == kwargs["updated_text"] # updated-row text guard + assert "FOR UPDATE" not in conn.fetchrow.await_args.args[0] # sources-first lock order + + +# ── fold/snapshot text guards actually bite (modeled twin/updated row text) ──── +# +# The fake above now models the survivor/updated row text, so a guard mismatch skips. This proves +# the WHERE ... text guards are load-bearing (not merely passed as params): deleting one flips a +# skip into a clobber/drop and fails the matching test below. + + +async def test_create_fold_skipped_when_twin_text_changed() -> None: + # The CREATE fold is text-guarded (WHERE id = $3 AND text = $4). If the twin's text was rewritten + # during the connection-free LLM window, the guard matches no row, so the helper returns None and + # the caller still CREATEs (no silent drop). Deleting `AND text = $4` makes this fail. + kwargs, conn, llm = _ctx() + llm.call.return_value = _DedupDecision(action="merge", text="merged text") + conn.current_twin_text = "the twin was rewritten during the LLM window" + with ( + _patch_embed(), + _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.99)]), + ): + result = await _dedup_reconcile_create(**kwargs) + assert result is None # guard rejected the fold -> caller CREATEs rather than merging onto a stale twin + conn.fetchval.assert_awaited_once() # the RETURNING-gated fold actually ran and matched no row + + +async def test_update_fold_skipped_when_updated_text_changed() -> None: + # The UPDATE path snapshots the updated row with a text guard (WHERE id = $1 AND text = $2). If the + # updated observation's text was rewritten during the connection-free LLM window, the snapshot + # matches no row and the reconciler bails BEFORE folding/deleting, so a concurrently-changed row is + # never merged away. Deleting the snapshot's `AND text = $2` makes this fail. + kwargs, conn, llm = _update_ctx() + llm.call.return_value = _DedupDecision(action="merge", text="merged text") + conn.current_updated_text = "the updated row was rewritten during the LLM window" + with _patch_probe([_obs("Uzbek content on YouTube is described as very rich.", 0.98)]): + await _dedup_reconcile_update(**kwargs) + conn.fetchrow.assert_awaited_once() # the text-guarded snapshot SELECT ran + conn.fetchval.assert_not_called() # snapshot matched no row -> no fold + conn.execute.assert_not_called() # -> the updated row is not deleted diff --git a/hindsight-api-slim/tests/test_consolidation_embedding_validation.py b/hindsight-api-slim/tests/test_consolidation_embedding_validation.py index 0b7c13e80..b979dd02f 100644 --- a/hindsight-api-slim/tests/test_consolidation_embedding_validation.py +++ b/hindsight-api-slim/tests/test_consolidation_embedding_validation.py @@ -1,4 +1,17 @@ +"""Deterministic unit tests for the pre-embed guards in the consolidation executors. + +Both findings are exercised against the production ``_wraps_backend`` acquisition path +(not a raw-pool sentinel): + +* zero-length embeddings are rejected BEFORE any write reaches the connection; +* when every source memory is already gone, the create/update executors short-circuit + BEFORE running the (slow) embedder, so a no-op consolidation never embeds and a + failing embedder never raises where it used to skip cleanly. +""" + +import types import uuid +from contextlib import asynccontextmanager import pytest @@ -17,25 +30,149 @@ class _FakeMemoryEngine: embeddings = _ZeroLengthEmbeddings() -class _FailingConn: +class _WriteForbiddenConn: + """Backend connection allowing only the pre-embed liveness probe. + + The preflight (``_any_live_source_memory``) runs ``fetchval``; the write path + (``transaction``/``fetchrow``/``execute``/``executemany``) must never be reached, + because the zero-length embedding is rejected first. + """ + + async def fetchval(self, *args, **kwargs): + return 1 # a live source exists -> proceed to embedding + + async def fetch(self, *args, **kwargs): + return [{"id": 1}] + + def transaction(self): + raise AssertionError("write transaction entered before the zero-length embedding was rejected") + async def fetchrow(self, *args, **kwargs): - raise AssertionError("zero-length embedding should be rejected before database insert") + raise AssertionError("INSERT reached before the zero-length embedding was rejected") + async def execute(self, *args, **kwargs): + raise AssertionError("write reached before the zero-length embedding was rejected") -@pytest.mark.asyncio -async def test_create_observation_rejects_zero_length_embedding_before_insert(monkeypatch): - source_id = uuid.uuid4() + async def executemany(self, *args, **kwargs): + raise AssertionError("write reached before the zero-length embedding was rejected") + + +class _NoLiveConn: + """Backend connection whose liveness probe reports no live source. + + Correct code short-circuits at the preflight (``fetchval`` -> None) and never embeds + or writes; every write method fails hard as a backstop. + """ + + async def fetchval(self, *args, **kwargs): + return None # no live source -> skip before embedding + + def transaction(self): + raise AssertionError("write transaction entered after all sources were dead") + + async def fetchrow(self, *args, **kwargs): + raise AssertionError("write reached after all sources were dead") + + async def execute(self, *args, **kwargs): + raise AssertionError("write reached after all sources were dead") + + +class _Backend: + """Backend-shaped stand-in matching acquire_with_retry's ``_wraps_backend`` path.""" + + _wraps_backend = True - async def fake_filter_live_source_memories(conn, bank_id, source_memory_ids): - return source_memory_ids + def __init__(self, conn): + self._conn = conn - monkeypatch.setattr(consolidator, "_filter_live_source_memories", fake_filter_live_source_memories) + @asynccontextmanager + async def acquire(self): + yield self._conn + +def _forbid_embedder(monkeypatch): + """Make any call to the embedder fail hard, proving it is never reached.""" + + async def _embedder_must_not_run(*args, **kwargs): + raise AssertionError("embedder must not run when all source memories are dead") + + monkeypatch.setattr( + "hindsight_api.engine.retain.embedding_utils.generate_embeddings_batch", + _embedder_must_not_run, + ) + + +@pytest.mark.asyncio +async def test_create_observation_rejects_zero_length_embedding_before_insert(): + # Preflight passes (fetchval -> live); the real embedder then yields a zero-length vector, + # which must be rejected before any write. _WriteForbiddenConn fails hard if a write runs. with pytest.raises(RuntimeError, match="embedding 0 has dimension 0; expected 384"): await consolidator._create_observation_directly( - conn=_FailingConn(), + pool=_Backend(_WriteForbiddenConn()), memory_engine=_FakeMemoryEngine(), bank_id="test-bank", - source_memory_ids=[source_id], + source_memory_ids=[uuid.uuid4()], observation_text="Consolidated observation text.", ) + + +@pytest.mark.asyncio +async def test_create_observation_skips_before_embedding_when_all_sources_dead(monkeypatch): + # All sources gone -> the create must short-circuit at the preflight, BEFORE the embedder + # (patched to explode if reached). No write may run. + _forbid_embedder(monkeypatch) + result = await consolidator._create_observation_directly( + pool=_Backend(_NoLiveConn()), + memory_engine=_FakeMemoryEngine(), + bank_id="test-bank", + source_memory_ids=[uuid.uuid4()], + observation_text="Consolidated observation text.", + ) + assert result == {"action": "skipped", "reason": "sources_deleted"} + + +@pytest.mark.asyncio +async def test_update_action_skips_before_embedding_when_all_sources_dead(monkeypatch): + # Same preflight contract on the UPDATE path: all sources gone -> skip (return None) BEFORE + # the embedder runs. The existing real-DB update test asserts only post-state, so it would + # pass even with the embed-first ordering this finding fixed; this pins the embed-skip directly. + _forbid_embedder(monkeypatch) + obs_id = str(uuid.uuid4()) + result = await consolidator._execute_update_action( + pool=_Backend(_NoLiveConn()), + memory_engine=_FakeMemoryEngine(), + bank_id="test-bank", + source_memory_ids=[uuid.uuid4(), uuid.uuid4()], + observation_id=obs_id, + new_text="This update must not land.", + observations=[types.SimpleNamespace(id=obs_id)], + ) + assert result is None + + +@pytest.mark.asyncio +async def test_update_action_rejects_zero_length_embedding_before_write(): + # Preflight passes (a live source exists -> fetchval=1); the real embedder then yields a + # zero-length vector, which must be rejected before any write. _WriteForbiddenConn fails + # hard if the write transaction is reached. + obs_id = str(uuid.uuid4()) + with pytest.raises(RuntimeError, match="embedding 0 has dimension 0; expected 384"): + await consolidator._execute_update_action( + pool=_Backend(_WriteForbiddenConn()), + memory_engine=_FakeMemoryEngine(), + bank_id="test-bank", + source_memory_ids=[uuid.uuid4()], + observation_id=obs_id, + new_text="Consolidated observation text.", + observations=[ + types.SimpleNamespace( + id=obs_id, + text="prior observation text", + tags=[], + occurred_start=None, + occurred_end=None, + mentioned_at=None, + source_fact_ids=[], + ) + ], + ) diff --git a/hindsight-api-slim/tests/test_document_transfer.py b/hindsight-api-slim/tests/test_document_transfer.py index abda8a190..5cf8527d3 100644 --- a/hindsight-api-slim/tests/test_document_transfer.py +++ b/hindsight-api-slim/tests/test_document_transfer.py @@ -579,11 +579,16 @@ async def test_export_import_observations(memory, request_context): source_ids = [uuid.UUID(str(i["id"])) for i in units["items"][:2]] assert len(source_ids) == 2 - # Create a real observation over those source facts. + # Create a real observation over those source facts. The helper now self-acquires a + # short-lived connection (embed runs off-connection), so we pass the backend, not a conn. backend = await memory._get_backend() - async with acquire_with_retry(backend) as conn: - async with conn.transaction(): - await _create_observation_directly(conn, memory, src, source_ids, "Alice and Bob are colleagues.") + await _create_observation_directly( + pool=backend, + memory_engine=memory, + bank_id=src, + source_memory_ids=source_ids, + observation_text="Alice and Bob are colleagues.", + ) # Export WITHOUT observations -> none in the archive (the bank may also # contain auto-consolidation observations; the flag is what gates them). diff --git a/hindsight-api-slim/tests/test_memory_curation.py b/hindsight-api-slim/tests/test_memory_curation.py index 605cf54da..79c87b353 100644 --- a/hindsight-api-slim/tests/test_memory_curation.py +++ b/hindsight-api-slim/tests/test_memory_curation.py @@ -7,13 +7,15 @@ """ import uuid +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, patch import pytest from hindsight_api import RequestContext +from hindsight_api.engine.db_utils import acquire_with_retry from hindsight_api.engine.memory_engine import MemoryEngine -from hindsight_api.engine.retain import embedding_processing +from hindsight_api.engine.retain import embedding_processing, link_utils # --------------------------------------------------------------------------- # Helpers @@ -219,6 +221,129 @@ async def test_revert_moves_back_and_restores_entities(self, memory: MemoryEngin await memory.delete_bank(bank_id, request_context=request_context) + @pytest.mark.asyncio + async def test_revert_reembeds_from_survivors_when_archived_entity_pruned_midwindow( + self, memory: MemoryEngine, request_context: RequestContext + ): + bank_id = f"test-curation-rev-prune-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + backend = await memory._get_backend() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "Alice met Bob in Paris.") + e_alice = await _insert_entity(conn, bank_id, "Alice") + e_bob = await _insert_entity(conn, bank_id, "Bob") + await _link_entity(conn, m1, e_alice) + await _link_entity(conn, m1, e_bob) + + with ( + patch.object(memory, "submit_async_consolidation", new=AsyncMock()), + patch.object(memory, "submit_async_graph_maintenance", new=AsyncMock()), + ): + await memory.update_memory_unit(bank_id, str(m1), state="invalidated", request_context=request_context) + + calls: list[list[str]] = [] + deleted = {"done": False} + orig = memory._reembed_memory_text + + async def _spy(*, text, occurred_start, occurred_end, mentioned_at, entities): + calls.append(list(entities)) + if not deleted["done"]: + deleted["done"] = True + async with acquire_with_retry(backend) as c: + await c.execute( + "DELETE FROM entities WHERE bank_id = $1 AND canonical_name = $2", bank_id, "Bob" + ) + return await orig( + text=text, + occurred_start=occurred_start, + occurred_end=occurred_end, + mentioned_at=mentioned_at, + entities=entities, + ) + + with patch.object(memory, "_reembed_memory_text", new=_spy): + result = await memory.update_memory_unit( + bank_id, str(m1), state="valid", request_context=request_context + ) + + assert result["state"] == "valid" + assert len(calls) == 2, "a survivor mismatch re-embeds under the lock" + assert sorted(calls[1]) == ["Alice"], "revert re-embeds from the restored (survivor) entity set" + async with pool.acquire() as conn: + names = await conn.fetch( + "SELECT e.canonical_name FROM unit_entities ue " + "JOIN entities e ON e.id = ue.entity_id WHERE ue.unit_id = $1", + m1, + ) + assert {r["canonical_name"] for r in names} == {"Alice"}, "only the surviving entity is restored" + emb = await conn.fetchval("SELECT embedding FROM memory_units WHERE id = $1", m1) + assert emb is not None + + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_revert_reembeds_when_archive_text_rewritten_midwindow( + self, memory: MemoryEngine, request_context: RequestContext + ): + # #3 regression: the archive row's TEXT is rewritten during the connection-free embed + # window, so the Phase-1 embedding (old text) is stale. The revert must re-embed from the + # LOCKED archive row, not store the precomputed (now-stale) vector against the new text. + bank_id = f"test-curation-rev-txtrace-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + backend = await memory._get_backend() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "Alice met Bob in Paris.") + + with ( + patch.object(memory, "submit_async_consolidation", new=AsyncMock()), + patch.object(memory, "submit_async_graph_maintenance", new=AsyncMock()), + ): + await memory.update_memory_unit(bank_id, str(m1), state="invalidated", request_context=request_context) + + calls: list[str] = [] + raced = {"done": False} + orig = memory._reembed_memory_text + + async def _spy(*, text, occurred_start, occurred_end, mentioned_at, entities): + calls.append(text) + if not raced["done"]: + raced["done"] = True + # Rewrite the archived row's text on a SEPARATE connection during the + # connection-free embed window (the between-phases race #3 guards against). + async with acquire_with_retry(backend) as c: + await c.execute( + "UPDATE invalidated_memory_units SET text = $2 WHERE id = $1 AND bank_id = $3", + m1, + "Rewritten archived text.", + bank_id, + ) + return await orig( + text=text, + occurred_start=occurred_start, + occurred_end=occurred_end, + mentioned_at=mentioned_at, + entities=entities, + ) + + with patch.object(memory, "_reembed_memory_text", new=_spy): + result = await memory.update_memory_unit( + bank_id, str(m1), state="valid", request_context=request_context + ) + + assert result["state"] == "valid" + assert len(calls) == 2, "a stale archive snapshot (text rewritten mid-window) re-embeds under the lock" + assert calls[1] == "Rewritten archived text.", "the in-txn re-embed uses the locked (current) archive text" + async with pool.acquire() as conn: + row = await conn.fetchrow("SELECT text, embedding FROM memory_units WHERE id = $1", m1) + assert row["text"] == "Rewritten archived text.", "revert restores the locked archive row" + assert row["embedding"] is not None + + await memory.delete_bank(bank_id, request_context=request_context) + @pytest.mark.asyncio async def test_invalidate_idempotent_updates_reason(self, memory: MemoryEngine, request_context: RequestContext): bank_id = f"test-curation-idem-{uuid.uuid4().hex[:8]}" @@ -262,6 +387,7 @@ async def test_edit_changes_text_and_rederives(self, memory: MemoryEngine, reque async with pool.acquire() as conn: m1 = await _insert_memory(conn, memory, bank_id, "The assistant visited Paris in 2023.") obs_id = await _insert_observation(conn, bank_id, "The assistant went to Paris.", [m1]) + orig_emb = await conn.fetchval("SELECT embedding FROM memory_units WHERE id = $1", m1) with ( patch.object(memory, "submit_async_consolidation", new=AsyncMock()), @@ -279,9 +405,13 @@ async def test_edit_changes_text_and_rederives(self, memory: MemoryEngine, reque assert result["state"] == "valid" async with pool.acquire() as conn: assert await _in_live(conn, m1), "edited row stays live" - row = dict(await conn.fetchrow("SELECT text, consolidated_at FROM memory_units WHERE id = $1", m1)) + row = dict( + await conn.fetchrow("SELECT text, consolidated_at, embedding FROM memory_units WHERE id = $1", m1) + ) assert row["text"] == "The user visited Paris in 2023." assert row["consolidated_at"] is None, "edited memory re-consolidates" + assert row["embedding"] is not None, "edit re-embeds (phase split must not drop the embedding)" + assert row["embedding"] != orig_emb, "edit stores a freshly recomputed vector, not the stale one" assert str(obs_id) not in await _obs_ids(conn, bank_id), "stale observation re-derived" await memory.delete_bank(bank_id, request_context=request_context) @@ -363,6 +493,128 @@ async def test_edit_replaces_entities(self, memory: MemoryEngine, request_contex await memory.delete_bank(bank_id, request_context=request_context) + @pytest.mark.asyncio + async def test_entity_edit_reresolves_when_resolved_entity_pruned_midwindow( + self, memory: MemoryEngine, request_context: RequestContext + ): + bank_id = f"test-curation-edit-prune-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + backend = await memory._get_backend() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "Alice met Bob in Paris.") + + calls: list[list[str]] = [] + deleted = {"done": False} + orig_resolve = link_utils.resolve_entities_only + orig_embed = memory._reembed_memory_text + + async def _resolve_spy(*args, **kwargs): + # Resolve normally (find-or-create autocommits entities OFF the write txn), then prune + # one resolved entity BEFORE update_memory_unit reads back canonical names. This is the + # real resolve->name-fetch race: edit_plan.names is captured short, so a name-set match + # check would commit a partial set. ID-coverage must detect the missing id and re-resolve. + result = await orig_resolve(*args, **kwargs) + if not deleted["done"]: + deleted["done"] = True + async with acquire_with_retry(backend) as c: + await c.execute("DELETE FROM entities WHERE bank_id = $1 AND canonical_name = $2", bank_id, "Bob") + return result + + async def _embed_spy(*, text, occurred_start, occurred_end, mentioned_at, entities): + calls.append(list(entities)) + return await orig_embed( + text=text, + occurred_start=occurred_start, + occurred_end=occurred_end, + mentioned_at=mentioned_at, + entities=entities, + ) + + with ( + patch.object(memory, "submit_async_consolidation", new=AsyncMock()), + patch.object(memory, "submit_async_graph_maintenance", new=AsyncMock()), + patch.object(link_utils, "resolve_entities_only", new=_resolve_spy), + patch.object(memory, "_reembed_memory_text", new=_embed_spy), + ): + result = await memory.update_memory_unit( + bank_id, str(m1), entities=["Alice", "Bob"], request_context=request_context + ) + + assert result is not None + assert set(result["entities"]) == {"Alice", "Bob"}, "edit must not commit a partial entity set" + assert calls[0] == ["Alice"], ( + "resolved entity pruned before edit_plan.names was captured (the resolve->name-fetch race)" + ) + assert len(calls) == 2, "a prune mismatch re-resolves and re-embeds under the lock" + assert set(calls[1]) == {"Alice", "Bob"}, "the recovered re-embed uses the full re-resolved entity set" + async with pool.acquire() as conn: + names = await conn.fetch( + "SELECT e.canonical_name FROM unit_entities ue " + "JOIN entities e ON e.id = ue.entity_id WHERE ue.unit_id = $1", + m1, + ) + assert {r["canonical_name"] for r in names} == {"Alice", "Bob"}, "pruned entity re-created and linked" + emb = await conn.fetchval("SELECT embedding FROM memory_units WHERE id = $1", m1) + assert emb is not None + + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_edit_aborts_on_concurrent_field_change(self, memory: MemoryEngine, request_context: RequestContext): + # A concurrent edit that commits during the off-connection embed must NOT be silently + # clobbered by the precomputed (stale) edit. The Phase-2 re-lock detects the changed + # column and aborts (rollback), so the concurrent writer's text survives. + bank_id = f"test-curation-edit-race-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + backend = await memory._get_backend() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "Original text.") + + orig_embed = memory._reembed_memory_text + raced = {"done": False} + + async def _embed_spy(*, text, occurred_start, occurred_end, mentioned_at, entities): + # Inside the connection-free embed window, commit a concurrent text edit on a SEPARATE + # backend connection (the real between-phases race the abort guards against). + if not raced["done"]: + raced["done"] = True + async with acquire_with_retry(backend) as c: + await c.execute( + "UPDATE memory_units SET text = $2, updated_at = now() WHERE id = $1", + m1, + "Concurrently edited text.", + ) + return await orig_embed( + text=text, + occurred_start=occurred_start, + occurred_end=occurred_end, + mentioned_at=mentioned_at, + entities=entities, + ) + + with ( + patch.object(memory, "submit_async_consolidation", new=AsyncMock()), + patch.object(memory, "submit_async_graph_maintenance", new=AsyncMock()), + patch.object(memory, "_reembed_memory_text", new=_embed_spy), + ): + # A context-only edit still re-embeds (so the spy fires); the abort must fire before it + # can overwrite the racing text edit with the Phase-1 snapshot. + with pytest.raises(RuntimeError, match="modified concurrently"): + await memory.update_memory_unit( + bank_id, str(m1), context="late annotation", request_context=request_context + ) + + async with pool.acquire() as conn: + row = await conn.fetchrow("SELECT text, context FROM memory_units WHERE id = $1", m1) + assert row["text"] == "Concurrently edited text.", "concurrent edit preserved (no lost update)" + assert row["context"] != "late annotation", "aborted edit did not apply" + + await memory.delete_bank(bank_id, request_context=request_context) + @pytest.mark.asyncio async def test_edit_empty_entities_detaches_all(self, memory: MemoryEngine, request_context: RequestContext): bank_id = f"test-curation-editent0-{uuid.uuid4().hex[:8]}" @@ -405,6 +657,253 @@ async def test_cannot_edit_invalidated_memory(self, memory: MemoryEngine, reques await memory.delete_bank(bank_id, request_context=request_context) + @pytest.mark.asyncio + async def test_entity_edit_embed_failure_reclaims_orphan_entities( + self, memory: MemoryEngine, request_context: RequestContext + ): + bank_id = f"test-curation-edit-embedfail-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "A standalone fact.") + + try: + consolidation_mock = AsyncMock() + # Leave submit_async_graph_maintenance REAL: under SyncTaskBackend the sweep runs inline, + # so we can assert the leaked entities are actually pruned (not just that a mock was awaited). + with ( + patch.object(memory, "submit_async_consolidation", new=consolidation_mock), + patch.object(memory, "_reembed_memory_text", new=AsyncMock(side_effect=RuntimeError("embedder down"))), + ): + with pytest.raises(RuntimeError, match="embedder down"): + await memory.update_memory_unit( + bank_id, str(m1), entities=["Alice", "Bob"], request_context=request_context + ) + + # resolve_entities_only autocommitted Alice/Bob in Phase 1, but the edit never linked them + # (the re-embed raised first). The failure path must enqueue this unit so the bank-wide + # orphan-entity prune runs and reclaims them; otherwise they leak. + async with pool.acquire() as conn: + orphan_count = await conn.fetchval("SELECT count(*) FROM entities WHERE bank_id = $1", bank_id) + assert orphan_count == 0, "orphan entities from the failed edit were reclaimed by graph maintenance" + consolidation_mock.assert_not_awaited() # a failed edit must not trigger consolidation + finally: + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_entity_edit_resolve_partial_commit_reclaims_orphan_entities( + self, memory: MemoryEngine, request_context: RequestContext + ): + bank_id = f"test-curation-edit-resolvefail-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + backend = await memory._get_backend() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "A standalone fact.") + + # Simulate resolve_entities_only autocommitting an entity on the Phase-1 (autocommit) connection + # and THEN raising. entities_maybe_committed is set BEFORE the resolve call, so the failure path + # must still enqueue the unit and let the inline sweep reclaim the committed orphan. If the flag + # were set after the call, this entity would leak. + async def _partially_commit_then_raise(*args, **kwargs): + phase1_conn = args[1] + await phase1_conn.execute( + "INSERT INTO entities (id, bank_id, canonical_name) VALUES ($1, $2, $3)", + uuid.uuid4(), + bank_id, + "Ghost", + ) + # Prove the insert autocommitted on the Phase-1 (autocommit) connection: a SEPARATE + # backend connection must see it. If Phase 1 were wrapped in a transaction, this would + # be 0 and the later orphan_count == 0 would prove nothing (a rollback, not the sweep). + async with acquire_with_retry(backend) as other: + seen = await other.fetchval( + "SELECT count(*) FROM entities WHERE bank_id = $1 AND canonical_name = $2", + bank_id, + "Ghost", + ) + assert seen == 1, "Ghost must be autocommitted (visible cross-connection) before the resolver raises" + raise RuntimeError("resolver down") + + try: + consolidation_mock = AsyncMock() + # submit_async_graph_maintenance stays REAL so the sweep runs inline under SyncTaskBackend. + # resolve_entities_only is patched at its module path because update_memory_unit imports it + # at call time (`from .retain.link_utils import resolve_entities_only`). + with ( + patch.object(memory, "submit_async_consolidation", new=consolidation_mock), + patch( + "hindsight_api.engine.retain.link_utils.resolve_entities_only", + new=_partially_commit_then_raise, + ), + ): + with pytest.raises(RuntimeError, match="resolver down"): + await memory.update_memory_unit( + bank_id, str(m1), entities=["Alice", "Bob"], request_context=request_context + ) + + async with pool.acquire() as conn: + orphan_count = await conn.fetchval("SELECT count(*) FROM entities WHERE bank_id = $1", bank_id) + assert orphan_count == 0, "an entity committed before the resolver raised was reclaimed by the sweep" + consolidation_mock.assert_not_awaited() # a failed edit must not trigger consolidation + finally: + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_context_edit_reembeds_when_unit_entities_change_midwindow( + self, memory: MemoryEngine, request_context: RequestContext + ): + # A context-only edit (resolved_for_unit is None) re-embeds from the unit's CURRENT entity + # names. If a concurrent writer changes unit_entities while the embedder runs off-connection, + # the Phase-2 lock re-reads the names; because they differ from the Phase-1 snapshot, the + # non-entity-edit branch re-embeds in-txn so the stored vector never names a stale entity set. + # unit_entities is NOT an abort-guarded column, so this re-embeds rather than aborting. + bank_id = f"test-curation-ctx-reembed-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + backend = await memory._get_backend() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "A fact about gardening.") + e1 = await _insert_entity(conn, bank_id, "Alpha") + e2 = await _insert_entity(conn, bank_id, "Beta") + await _link_entity(conn, m1, e1) + + calls: list[list[str]] = [] + orig_embed = memory._reembed_memory_text + linked = {"done": False} + + async def _embed_spy(*, text, occurred_start, occurred_end, mentioned_at, entities): + calls.append(list(entities)) + # On the first (Phase-1) embed, link a second entity on a SEPARATE connection so the + # Phase-2 re-lock observes a changed unit_entities set. + if not linked["done"]: + linked["done"] = True + async with acquire_with_retry(backend) as c: + await _link_entity(c, m1, e2) + return await orig_embed( + text=text, + occurred_start=occurred_start, + occurred_end=occurred_end, + mentioned_at=mentioned_at, + entities=entities, + ) + + with ( + patch.object(memory, "submit_async_consolidation", new=AsyncMock()), + patch.object(memory, "submit_async_graph_maintenance", new=AsyncMock()), + patch.object(memory, "_reembed_memory_text", new=_embed_spy), + ): + result = await memory.update_memory_unit( + bank_id, str(m1), context="late note", request_context=request_context + ) + + assert result is not None + assert len(calls) == 2, "a mid-window unit_entities change triggers an in-txn re-embed" + assert calls[0] == ["Alpha"], "Phase-1 embed used the unit's entity set at read time" + assert set(calls[1]) == {"Alpha", "Beta"}, "Phase-2 re-embed uses the concurrently-updated entity set" + async with pool.acquire() as conn: + names = await conn.fetch( + "SELECT e.canonical_name FROM unit_entities ue " + "JOIN entities e ON e.id = ue.entity_id WHERE ue.unit_id = $1", + m1, + ) + row = await conn.fetchrow("SELECT context, embedding FROM memory_units WHERE id = $1", m1) + assert {r["canonical_name"] for r in names} == {"Alpha", "Beta"}, "both entities remain linked" + assert row["context"] == "late note", "the context edit applied" + assert row["embedding"] is not None + + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_edit_then_invalidate_archives_edited_text( + self, memory: MemoryEngine, request_context: RequestContext + ): + # A single call that BOTH edits text and invalidates applies the edit first (Phase-2 UPDATE), + # then moves the freshly-edited row to the archive -- so the archived text is the corrected one. + bank_id = f"test-curation-edit-invalidate-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "Original.") + + with ( + patch.object(memory, "submit_async_consolidation", new=AsyncMock()), + patch.object(memory, "submit_async_graph_maintenance", new=AsyncMock()), + ): + await memory.update_memory_unit( + bank_id, str(m1), text="Corrected.", state="invalidated", request_context=request_context + ) + + async with pool.acquire() as conn: + assert not await _in_live(conn, m1), "the row was moved out of memory_units" + arch = await _archive_row(conn, m1) + assert arch is not None, "the row landed in the archive" + assert arch["text"] == "Corrected.", "the edit applied before the archive move" + + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_entity_edit_aborts_and_reclaims_orphans_on_concurrent_change( + self, memory: MemoryEngine, request_context: RequestContext + ): + # An entity-changing edit autocommits its resolved entities in Phase 1 (off-txn). If a + # concurrent edit changes an abort-guarded column while the embedder runs, the Phase-2 re-lock + # aborts (rollback) BEFORE the entities are linked, leaving them committed orphans. The + # finally-block enqueue + forced graph maintenance must reclaim them. + bank_id = f"test-curation-edit-abort-reclaim-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + pool = await memory._get_pool() + backend = await memory._get_backend() + async with pool.acquire() as conn: + m1 = await _insert_memory(conn, memory, bank_id, "A standalone fact.") + + orig_embed = memory._reembed_memory_text + raced = {"done": False} + + async def _embed_spy(*, text, occurred_start, occurred_end, mentioned_at, entities): + # Commit a concurrent text edit during the off-connection embed so the Phase-2 re-lock + # detects the changed column and aborts. + if not raced["done"]: + raced["done"] = True + async with acquire_with_retry(backend) as c: + await c.execute( + "UPDATE memory_units SET text = $2, updated_at = now() WHERE id = $1", + m1, + "Concurrently edited text.", + ) + return await orig_embed( + text=text, + occurred_start=occurred_start, + occurred_end=occurred_end, + mentioned_at=mentioned_at, + entities=entities, + ) + + consolidation_mock = AsyncMock() + # Leave submit_async_graph_maintenance REAL so the inline SyncTaskBackend sweep reclaims orphans. + with ( + patch.object(memory, "submit_async_consolidation", new=consolidation_mock), + patch.object(memory, "_reembed_memory_text", new=_embed_spy), + ): + with pytest.raises(RuntimeError, match="modified concurrently"): + await memory.update_memory_unit( + bank_id, str(m1), entities=["Alice", "Bob"], request_context=request_context + ) + + async with pool.acquire() as conn: + orphan_count = await conn.fetchval("SELECT count(*) FROM entities WHERE bank_id = $1", bank_id) + row = await conn.fetchrow("SELECT text FROM memory_units WHERE id = $1", m1) + assert orphan_count == 0, "entities autocommitted in Phase 1 were reclaimed after the Phase-2 abort" + assert row["text"] == "Concurrently edited text.", "the concurrent edit survived (no lost update)" + consolidation_mock.assert_not_awaited() + + await memory.delete_bank(bank_id, request_context=request_context) + # --------------------------------------------------------------------------- # Guards / listing / recall @@ -544,3 +1043,64 @@ def _hit(res) -> bool: assert not _hit(after), "invalidated fact must be excluded from recall" await memory.delete_bank(bank_id, request_context=request_context) + + +# --------------------------------------------------------------------------- +# Update mental model (embed-before-acquire) +# --------------------------------------------------------------------------- + + +class TestUpdateMentalModel: + @pytest.mark.asyncio + async def test_update_mental_model_embeds_before_acquire( + self, memory: MemoryEngine, request_context: RequestContext + ): + # The new embedding is computed BEFORE a pooled connection is acquired, so a slow embedder + # never pins a DB connection. (_authenticate_tenant does not touch the pool.) + bank_id = f"test-mm-embed-order-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + order: list[str] = [] + real_acquire = acquire_with_retry + + async def _embed_spy(_embeddings, _texts): + order.append("embed") + # 384 dims to satisfy the mental_models.embedding vector(384) cast; the value is + # irrelevant (the random id matches no row), only the embed-before-acquire order matters. + return [[0.0] * 384] + + @asynccontextmanager + async def _acquire_spy(*args, **kwargs): + order.append("acquire") + async with real_acquire(*args, **kwargs) as conn: + yield conn + + with ( + patch("hindsight_api.engine.retain.embedding_utils.generate_embeddings_batch", new=_embed_spy), + patch("hindsight_api.engine.memory_engine.acquire_with_retry", new=_acquire_spy), + ): + await memory.update_mental_model( + bank_id, str(uuid.uuid4()), content="new content", request_context=request_context + ) + + assert "embed" in order, "a content update must compute an embedding" + assert "acquire" in order, "the write path must acquire a connection" + assert order.index("embed") < order.index("acquire"), "embed must happen before acquiring a connection" + + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_update_mental_model_skips_embed_when_content_none( + self, memory: MemoryEngine, request_context: RequestContext + ): + # No content change -> no embedding (the embed is gated on `content is not None`). + bank_id = f"test-mm-no-embed-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + embed_mock = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + with patch("hindsight_api.engine.retain.embedding_utils.generate_embeddings_batch", new=embed_mock): + await memory.update_mental_model(bank_id, str(uuid.uuid4()), tags=["x"], request_context=request_context) + + embed_mock.assert_not_awaited() + + await memory.delete_bank(bank_id, request_context=request_context) diff --git a/hindsight-api-slim/tests/test_observation_invalidation.py b/hindsight-api-slim/tests/test_observation_invalidation.py index f3f4c6b9f..677a355cb 100644 --- a/hindsight-api-slim/tests/test_observation_invalidation.py +++ b/hindsight-api-slim/tests/test_observation_invalidation.py @@ -15,6 +15,7 @@ import pytest from hindsight_api import RequestContext +from hindsight_api.engine.db_utils import acquire_with_retry from hindsight_api.engine.memory_engine import MemoryEngine # --------------------------------------------------------------------------- @@ -820,27 +821,28 @@ async def test_create_observation_filters_deleted_source_memories( bank_id = f"test-race-create-filter-{uuid.uuid4().hex[:8]}" await _ensure_bank(memory, bank_id, request_context) - pool = await memory._get_pool() - async with pool.acquire() as conn: + backend = await memory._get_backend() + async with acquire_with_retry(backend) as conn: live = await _insert_memory(conn, bank_id, "Alice loves hiking.") - dead = uuid.uuid4() # never existed — stands in for a concurrently deleted source - - result = await _create_observation_directly( - conn=conn, - memory_engine=memory, - bank_id=bank_id, - source_memory_ids=[live, dead], - observation_text="Alice enjoys hiking regularly.", - ) + dead = uuid.uuid4() # never existed -- stands in for a concurrently deleted source + + result = await _create_observation_directly( + pool=backend, + memory_engine=memory, + bank_id=bank_id, + source_memory_ids=[live, dead], + observation_text="Alice enjoys hiking regularly.", + ) - assert result["action"] == "created" + assert result["action"] == "created" + async with acquire_with_retry(backend) as conn: stored = await conn.fetchval( "SELECT source_memory_ids FROM memory_units WHERE id = $1", uuid.UUID(result["observation_id"]), ) - stored_set = {str(s) for s in stored} - assert str(live) in stored_set - assert str(dead) not in stored_set, "Deleted source must not appear in stored observation" + stored_set = {str(s) for s in stored} + assert str(live) in stored_set + assert str(dead) not in stored_set, "Deleted source must not appear in stored observation" await memory.delete_bank(bank_id, request_context=request_context) @@ -853,21 +855,21 @@ async def test_create_observation_skipped_when_all_sources_deleted( bank_id = f"test-race-create-skip-{uuid.uuid4().hex[:8]}" await _ensure_bank(memory, bank_id, request_context) - pool = await memory._get_pool() - async with pool.acquire() as conn: - result = await _create_observation_directly( - conn=conn, - memory_engine=memory, - bank_id=bank_id, - source_memory_ids=[uuid.uuid4(), uuid.uuid4()], - observation_text="All sources gone.", - ) + backend = await memory._get_backend() + result = await _create_observation_directly( + pool=backend, + memory_engine=memory, + bank_id=bank_id, + source_memory_ids=[uuid.uuid4(), uuid.uuid4()], + observation_text="All sources gone.", + ) - assert result["action"] == "skipped" - assert result["reason"] == "sources_deleted" + assert result["action"] == "skipped" + assert result["reason"] == "sources_deleted" + async with acquire_with_retry(backend) as conn: obs_ids = await _get_observation_ids(conn, bank_id) - assert obs_ids == [], "No observation row should exist" + assert obs_ids == [], "No observation row should exist" await memory.delete_bank(bank_id, request_context=request_context) @@ -881,33 +883,137 @@ async def test_update_observation_skipped_when_all_new_sources_deleted( bank_id = f"test-race-update-skip-{uuid.uuid4().hex[:8]}" await _ensure_bank(memory, bank_id, request_context) - pool = await memory._get_pool() - async with pool.acquire() as conn: + backend = await memory._get_backend() + async with acquire_with_retry(backend) as conn: original_source = await _insert_memory(conn, bank_id, "Alice hikes.") obs_id = await _insert_observation(conn, bank_id, "Alice is a hiker.", [original_source]) - original_text = "Alice is a hiker." - - observation_model = MemoryFact( - id=str(obs_id), - text=original_text, - fact_type="observation", - source_fact_ids=[str(original_source)], - tags=[], - ) + original_text = "Alice is a hiker." + + observation_model = MemoryFact( + id=str(obs_id), + text=original_text, + fact_type="observation", + source_fact_ids=[str(original_source)], + tags=[], + ) + + result = await _execute_update_action( + pool=backend, + memory_engine=memory, + bank_id=bank_id, + source_memory_ids=[uuid.uuid4(), uuid.uuid4()], # all dead + observation_id=str(obs_id), + new_text="This update must not land.", + observations=[observation_model], + ) + assert result is None, "a skipped update returns None so the caller runs no follow-on dedup" + + async with acquire_with_retry(backend) as conn: + row = await conn.fetchrow("SELECT text, source_memory_ids FROM memory_units WHERE id = $1", obs_id) + assert row["text"] == original_text, "Observation text must not change" + stored_sources = {str(s) for s in row["source_memory_ids"]} + assert stored_sources == {str(original_source)}, "Dead sources must not be appended" + + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_snapshot_then_apply_sweeps_observation_after_source_delete( + self, memory: MemoryEngine, request_context: RequestContext + ): + # The retain full-replace path snapshots the affected observations BEFORE deleting their + # source memories, then applies the deletion AFTER (SOURCE -> OBSERVATION lock order, the P1 + # deadlock fix). Verify the split helpers still sweep the derived observation and reset the + # surviving co-source even when the source row is deleted between snapshot and apply. + from hindsight_api.engine.retain.fact_storage import ( + _apply_stale_observation_deletion, + _snapshot_stale_observations, + ) + + bank_id = f"test-snapshot-apply-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + backend = await memory._get_backend() + async with acquire_with_retry(backend) as conn: + s1 = await _insert_memory(conn, bank_id, "Alice hikes on Mondays.") + s2 = await _insert_memory(conn, bank_id, "Alice hikes on Fridays.") + obs = await _insert_observation(conn, bank_id, "Alice hikes regularly.", [s1, s2]) + + async with acquire_with_retry(backend) as conn: + async with conn.transaction(): + obs_ids, remaining = await _snapshot_stale_observations(conn, bank_id, [s1], ops=backend.ops) + assert obs in obs_ids, "the derived observation is snapshotted from the outgoing source" + assert s2 in remaining, "the surviving co-source is captured for consolidated_at reset" + assert s1 not in remaining, "the outgoing source is not in the reset set" + # Delete the source FIRST (the new lock order), then apply the observation deletion. + await conn.execute("DELETE FROM memory_units WHERE id = $1", s1) + invalidated = await _apply_stale_observation_deletion(conn, bank_id, obs_ids, remaining) + assert invalidated == 1 + + async with acquire_with_retry(backend) as conn: + obs_remaining = await conn.fetchval("SELECT count(*) FROM memory_units WHERE id = $1", obs) + s2_consolidated = await _get_consolidated_at(conn, s2) + assert obs_remaining == 0, "the derived observation was swept after the source delete" + assert s2_consolidated is None, "the surviving co-source was reset for re-consolidation" + + await memory.delete_bank(bank_id, request_context=request_context) + + @pytest.mark.asyncio + async def test_update_observation_skipped_when_source_deleted_after_preflight( + self, memory: MemoryEngine, request_context: RequestContext + ): + # The preflight sees the source live, but it is deleted while the embedder runs + # off-connection. The authoritative in-txn FOR SHARE liveness guard (not the preflight) must + # then skip the update, so a dead source is never written back into the observation. + from hindsight_api.engine.consolidation.consolidator import _execute_update_action + from hindsight_api.engine.response_models import MemoryFact + + bank_id = f"test-race-update-inflight-{uuid.uuid4().hex[:8]}" + await _ensure_bank(memory, bank_id, request_context) + + backend = await memory._get_backend() + async with acquire_with_retry(backend) as conn: + source = await _insert_memory(conn, bank_id, "Alice hikes.") + obs_id = await _insert_observation(conn, bank_id, "Alice is a hiker.", [source]) + original_text = "Alice is a hiker." + + observation_model = MemoryFact( + id=str(obs_id), + text=original_text, + fact_type="observation", + source_fact_ids=[str(source)], + tags=[], + ) - await _execute_update_action( - conn=conn, + deleted = {"done": False} + + async def _embed_spy(_embeddings, _texts): + # Delete the (preflight-live) source DURING the off-connection embed, so only the in-txn + # FOR SHARE guard can catch it. + if not deleted["done"]: + deleted["done"] = True + async with acquire_with_retry(backend) as c: + await c.execute("DELETE FROM memory_units WHERE id = $1", source) + return [[0.1, 0.2, 0.3]] + + with patch( + "hindsight_api.engine.retain.embedding_utils.generate_embeddings_batch", + new=_embed_spy, + ): + result = await _execute_update_action( + pool=backend, memory_engine=memory, bank_id=bank_id, - source_memory_ids=[uuid.uuid4(), uuid.uuid4()], # all dead + source_memory_ids=[source], observation_id=str(obs_id), new_text="This update must not land.", observations=[observation_model], ) + assert result is None, "the in-txn FOR SHARE liveness skip returns None (no follow-on dedup)" + async with acquire_with_retry(backend) as conn: row = await conn.fetchrow("SELECT text, source_memory_ids FROM memory_units WHERE id = $1", obs_id) - assert row["text"] == original_text, "Observation text must not change" - stored_sources = {str(s) for s in row["source_memory_ids"]} - assert stored_sources == {str(original_source)}, "Dead sources must not be appended" + assert row["text"] == original_text, "observation text unchanged after the skipped update" + stored_sources = {str(s) for s in row["source_memory_ids"]} + assert stored_sources == {str(source)}, "no dead source appended" await memory.delete_bank(bank_id, request_context=request_context)