Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 74 additions & 72 deletions src/gaia/rag/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,7 +1600,17 @@ def _get_last_n_tokens(self, text: str, n_tokens: int) -> str:
return trimmed[first_space + 1 :]
return trimmed

def _create_vector_index(self, chunks: List[str]) -> tuple:
def _create_faiss_index(self, embeddings: "np.ndarray"):
"""Build a FAISS L2 index from precomputed embeddings."""
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
# pylint: disable=no-value-for-parameter
index.add(embeddings.astype("float32"))
return index

def _create_vector_index(
self, chunks: List[str], return_embeddings: bool = False
) -> tuple:
"""Create FAISS vector index from chunks with progress reporting."""
import time as time_module # pylint: disable=reimported

Expand Down Expand Up @@ -1635,10 +1645,7 @@ def _create_vector_index(self, chunks: List[str]) -> tuple:
print("\n 🏗️ Building FAISS search index...")

index_start = time_module.time()
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
# pylint: disable=no-value-for-parameter
index.add(embeddings.astype("float32"))
index = self._create_faiss_index(embeddings)
index_duration = time_module.time() - index_start

if self.config.show_stats:
Expand All @@ -1654,6 +1661,8 @@ def _create_vector_index(self, chunks: List[str]) -> tuple:
f"📚 Index ready with {index.ntotal} vectors "
f"(embed: {embed_duration:.2f}s, index: {index_duration:.2f}s)"
)
if return_embeddings:
return index, chunks, embeddings
return index, chunks

def _snapshot_query_state(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -2107,38 +2116,60 @@ def index_document(self, file_path: str) -> Dict[str, Any]:
stats["memory_limit_reached"] = True
return stats

# Track chunk indices for this file and publish only after the
# rebuilt index succeeds so partial state is never visible.
# Encode once; reuse for both the global and per-file index.
if self.index is None:
rebuilt_chunks_source = list(cached_chunks)
new_index, rebuilt_chunks, file_embeddings = (
self._create_vector_index(
list(cached_chunks), return_embeddings=True
)
)
file_chunk_indices = list(range(len(cached_chunks)))
rebuilt_chunk_to_file = {
idx: file_path for idx in file_chunk_indices
}
else:
old_chunks = list(self.chunks)
old_count = len(old_chunks)
rebuilt_chunks_source = old_chunks + list(cached_chunks)
start_idx = len(old_chunks)
file_chunk_indices = list(
range(start_idx, start_idx + len(cached_chunks))
)
rebuilt_chunks = old_chunks + list(cached_chunks)
rebuilt_chunk_to_file = dict(self.chunk_to_file)
for chunk_idx in file_chunk_indices:
rebuilt_chunk_to_file[chunk_idx] = file_path
if self.config.show_stats:
print(
f" 🔄 Rebuilding index ({old_count} + {len(cached_chunks)} = {len(rebuilt_chunks_source)} chunks)"
f" ➕ Appending {len(cached_chunks)} cached chunks to existing index ({old_count} -> {len(rebuilt_chunks)} total)"
)
self._load_embedder()
file_embeddings = self._encode_texts(
list(cached_chunks), show_progress=False
)

rebuilt_index, rebuilt_chunks = self._create_vector_index(
rebuilt_chunks_source
)
self.index = rebuilt_index
try:
file_index = self._create_faiss_index(file_embeddings)
except Exception as _e: # pylint: disable=broad-except
self.log.debug(
"Couldn't pre-build per-file index for %s: %s",
file_path,
_e,
)
file_index = None

if self.index is None:
self.index = new_index
else:
# Append after the per-file index is ready.
self.index.add(file_embeddings.astype("float32"))
self.chunks = rebuilt_chunks
self.chunk_to_file = rebuilt_chunk_to_file
self.file_to_chunk_indices[file_path] = file_chunk_indices

if file_index is not None:
self.file_indices[file_path] = file_index
self.file_embeddings[file_path] = file_embeddings

# Restore metadata in memory
if cached_full_text or cached_metadata:
self.file_metadata[file_path] = {
Expand All @@ -2148,32 +2179,6 @@ def index_document(self, file_path: str) -> Dict[str, Any]:

self.indexed_files.add(file_path)

# Build per-file FAISS index NOW so retrieval-time queries
# don't have to rebuild it on every call. Pre-#1030
# follow-up: this block was only present on the
# fresh-index path (~line 2289 below), so any document
# loaded from cache hit ``cached_file_index is None`` in
# _retrieve_chunks_from_file and rebuilt the FAISS index
# from scratch on every query — adding ~3 s × N queries
# of avoidable work. One-time cost on cache load instead.
try:
self._load_embedder()
_file_embeddings = self._encode_texts(
list(cached_chunks), show_progress=False
)
_file_dim = _file_embeddings.shape[1]
_file_index = faiss.IndexFlatL2(_file_dim)
# pylint: disable=no-value-for-parameter
_file_index.add(_file_embeddings.astype("float32"))
self.file_indices[file_path] = _file_index
self.file_embeddings[file_path] = _file_embeddings
except Exception as _e: # pylint: disable=broad-except
self.log.debug(
"Couldn't pre-build per-file index for %s: %s",
file_path,
_e,
)

# Track access time for LRU (was missing — pre-existing bug)
current_time = time.time()
self.file_index_times[file_path] = current_time
Expand Down Expand Up @@ -2276,57 +2281,54 @@ def index_document(self, file_path: str) -> Dict[str, Any]:
stats["memory_limit_reached"] = True
return stats

# Track which chunks belong to this file and only publish them after
# the rebuilt global index succeeds.
if self.chunks:
# Encode once; reuse for both the global and per-file index.
if self.index is None:
if self.config.show_stats:
print("🏗️ Building initial search index...")
new_index, rebuilt_chunks, file_embeddings = (
self._create_vector_index(
list(new_chunks), return_embeddings=True
)
)
file_chunk_indices = list(range(len(new_chunks)))
rebuilt_chunk_to_file = {
idx: file_path for idx in file_chunk_indices
}
else:
old_chunks = list(self.chunks)
old_count = len(old_chunks)
rebuilt_chunks_source = old_chunks + list(new_chunks)
start_idx = len(old_chunks)
file_chunk_indices = list(
range(start_idx, start_idx + len(new_chunks))
)
rebuilt_chunks = old_chunks + list(new_chunks)
rebuilt_chunk_to_file = dict(self.chunk_to_file)
for chunk_idx in file_chunk_indices:
rebuilt_chunk_to_file[chunk_idx] = file_path

if self.config.show_stats:
print(
f"🔄 Rebuilding search index ({old_count} + {len(new_chunks)} = {len(rebuilt_chunks_source)} total chunks)"
f"➕ Appending {len(new_chunks)} new chunks to existing index ({old_count} -> {len(rebuilt_chunks)} total)"
)
else:
file_chunk_indices = list(range(len(new_chunks)))
rebuilt_chunks_source = list(new_chunks)
rebuilt_chunk_to_file = {
idx: file_path for idx in file_chunk_indices
}

if self.config.show_stats:
print("🏗️ Building initial search index...")

rebuilt_index, rebuilt_chunks = self._create_vector_index(
rebuilt_chunks_source
)
self.index = rebuilt_index
self.chunks = rebuilt_chunks
self.chunk_to_file = rebuilt_chunk_to_file
self.file_to_chunk_indices[file_path] = file_chunk_indices
self._load_embedder()
file_embeddings = self._encode_texts(
new_chunks, show_progress=False
)

# Build and cache per-file FAISS index for fast file-specific searches
if self.config.show_stats:
print("🔍 Building per-file search index...")

self._load_embedder()
# Generate embeddings for this file's chunks only
file_embeddings = self._encode_texts(new_chunks, show_progress=False)

# Create FAISS index for this file
dimension = file_embeddings.shape[1]
file_index = faiss.IndexFlatL2(dimension)
# pylint: disable=no-value-for-parameter
file_index.add(file_embeddings.astype("float32"))
file_index = self._create_faiss_index(file_embeddings)

# Cache the index and embeddings for this file
if self.index is None:
self.index = new_index
else:
# Append after the per-file index is ready.
self.index.add(file_embeddings.astype("float32"))
self.chunks = rebuilt_chunks
self.chunk_to_file = rebuilt_chunk_to_file
self.file_to_chunk_indices[file_path] = file_chunk_indices
self.file_indices[file_path] = file_index
self.file_embeddings[file_path] = file_embeddings

Expand Down
103 changes: 103 additions & 0 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,109 @@ def test_cache_functionality(self, mock_dependencies):
assert isinstance(result2, dict)
assert result2.get("success") is True

def test_indexing_two_documents_embeds_each_file_once(self, mock_dependencies):
"""Test that each indexed file is embedded once on the steady-state path."""
if not RAG_AVAILABLE:
pytest.skip(f"RAG dependencies not available: {IMPORT_ERROR}")

with tempfile.TemporaryDirectory() as temp_dir:
config = RAGConfig(cache_dir=temp_dir, show_stats=False)

with patch("gaia.rag.sdk.RAGSDK._check_dependencies"):
rag = RAGSDK(config)

doc1 = Path(temp_dir) / "doc1.txt"
doc1.write_text("doc1")
doc2 = Path(temp_dir) / "doc2.txt"
doc2.write_text("doc2")

encode_calls = []

def fake_encode(texts, show_progress=False): # noqa: ARG001
encode_calls.append(list(texts))
return np.array(
[[0.1, 0.2, 0.3, 0.4] for _ in texts], dtype=np.float32
)

with (
patch.object(
rag,
"_extract_text_from_file",
side_effect=[("doc1 text", {}), ("doc2 text", {})],
),
patch.object(
rag,
"_split_text_into_chunks",
side_effect=[["doc1 chunk"], ["doc2 chunk"]],
),
patch.object(rag, "_encode_texts", side_effect=fake_encode),
):
result1 = rag.index_document(str(doc1))
result2 = rag.index_document(str(doc2))

assert result1["success"] is True
assert result2["success"] is True
assert encode_calls == [["doc1 chunk"], ["doc2 chunk"]]

def test_cache_load_embeds_each_cached_file_once(self, mock_dependencies):
"""Test that cache loads reuse one embedding pass per cached file."""
if not RAG_AVAILABLE:
pytest.skip(f"RAG dependencies not available: {IMPORT_ERROR}")

with tempfile.TemporaryDirectory() as temp_dir:
config = RAGConfig(cache_dir=temp_dir, show_stats=False)

doc1 = Path(temp_dir) / "doc1.txt"
doc1.write_text("doc1")
doc2 = Path(temp_dir) / "doc2.txt"
doc2.write_text("doc2")

with patch("gaia.rag.sdk.RAGSDK._check_dependencies"):
seed_rag = RAGSDK(config)

with (
patch.object(
seed_rag,
"_extract_text_from_file",
side_effect=[("doc1 text", {}), ("doc2 text", {})],
),
patch.object(
seed_rag,
"_split_text_into_chunks",
side_effect=[["cached doc1 chunk"], ["cached doc2 chunk"]],
),
):
assert seed_rag.index_document(str(doc1))["success"] is True
assert seed_rag.index_document(str(doc2))["success"] is True

with patch("gaia.rag.sdk.RAGSDK._check_dependencies"):
rag = RAGSDK(config)

encode_calls = []

def fake_encode(texts, show_progress=False): # noqa: ARG001
encode_calls.append(list(texts))
return np.array(
[[0.1, 0.2, 0.3, 0.4] for _ in texts], dtype=np.float32
)

with (
patch.object(
rag,
"_extract_text_from_file",
side_effect=AssertionError("cache load should skip extraction"),
),
patch.object(rag, "_encode_texts", side_effect=fake_encode),
):
result1 = rag.index_document(str(doc1))
result2 = rag.index_document(str(doc2))

assert result1["success"] is True
assert result1["from_cache"] is True
assert result2["success"] is True
assert result2["from_cache"] is True
assert encode_calls == [["cached doc1 chunk"], ["cached doc2 chunk"]]

# NOTE: The pickle-era cache recovery tests (test_corrupted_cache_recovery,
# test_cache_checksum_mismatch_recovery, test_oversized_cache_rejected,
# test_old_format_cache_migration) were removed when the cache format
Expand Down
Loading