diff --git a/src/gaia/rag/sdk.py b/src/gaia/rag/sdk.py index 5accbdd53..3efd2e86a 100644 --- a/src/gaia/rag/sdk.py +++ b/src/gaia/rag/sdk.py @@ -1600,7 +1600,17 @@ def _get_last_n_tokens(self, text: str, n_tokens: int) -> str: return trimmed[first_space + 1 :] return trimmed - def _create_vector_index(self, chunks: List[str]) -> tuple: + def _create_faiss_index(self, embeddings: "np.ndarray"): + """Build a FAISS L2 index from precomputed embeddings.""" + dimension = embeddings.shape[1] + index = faiss.IndexFlatL2(dimension) + # pylint: disable=no-value-for-parameter + index.add(embeddings.astype("float32")) + return index + + def _create_vector_index( + self, chunks: List[str], return_embeddings: bool = False + ) -> tuple: """Create FAISS vector index from chunks with progress reporting.""" import time as time_module # pylint: disable=reimported @@ -1635,10 +1645,7 @@ def _create_vector_index(self, chunks: List[str]) -> tuple: print("\n 🏗️ Building FAISS search index...") index_start = time_module.time() - dimension = embeddings.shape[1] - index = faiss.IndexFlatL2(dimension) - # pylint: disable=no-value-for-parameter - index.add(embeddings.astype("float32")) + index = self._create_faiss_index(embeddings) index_duration = time_module.time() - index_start if self.config.show_stats: @@ -1654,6 +1661,8 @@ def _create_vector_index(self, chunks: List[str]) -> tuple: f"📚 Index ready with {index.ntotal} vectors " f"(embed: {embed_duration:.2f}s, index: {index_duration:.2f}s)" ) + if return_embeddings: + return index, chunks, embeddings return index, chunks def _snapshot_query_state(self) -> Dict[str, Any]: @@ -2107,10 +2116,13 @@ def index_document(self, file_path: str) -> Dict[str, Any]: stats["memory_limit_reached"] = True return stats - # Track chunk indices for this file and publish only after the - # rebuilt index succeeds so partial state is never visible. + # Encode once; reuse for both the global and per-file index. if self.index is None: - rebuilt_chunks_source = list(cached_chunks) + new_index, rebuilt_chunks, file_embeddings = ( + self._create_vector_index( + list(cached_chunks), return_embeddings=True + ) + ) file_chunk_indices = list(range(len(cached_chunks))) rebuilt_chunk_to_file = { idx: file_path for idx in file_chunk_indices @@ -2118,27 +2130,37 @@ def index_document(self, file_path: str) -> Dict[str, Any]: else: old_chunks = list(self.chunks) old_count = len(old_chunks) - rebuilt_chunks_source = old_chunks + list(cached_chunks) start_idx = len(old_chunks) file_chunk_indices = list( range(start_idx, start_idx + len(cached_chunks)) ) + rebuilt_chunks = old_chunks + list(cached_chunks) rebuilt_chunk_to_file = dict(self.chunk_to_file) for chunk_idx in file_chunk_indices: rebuilt_chunk_to_file[chunk_idx] = file_path if self.config.show_stats: print( - f" 🔄 Rebuilding index ({old_count} + {len(cached_chunks)} = {len(rebuilt_chunks_source)} chunks)" + f" ➕ Appending {len(cached_chunks)} cached chunks to existing index ({old_count} -> {len(rebuilt_chunks)} total)" ) + self._load_embedder() + file_embeddings = self._encode_texts( + list(cached_chunks), show_progress=False + ) - rebuilt_index, rebuilt_chunks = self._create_vector_index( - rebuilt_chunks_source - ) - self.index = rebuilt_index + file_index = self._create_faiss_index(file_embeddings) + + if self.index is None: + self.index = new_index + else: + # Append after the per-file index is ready. + self.index.add(file_embeddings.astype("float32")) self.chunks = rebuilt_chunks self.chunk_to_file = rebuilt_chunk_to_file self.file_to_chunk_indices[file_path] = file_chunk_indices + self.file_indices[file_path] = file_index + self.file_embeddings[file_path] = file_embeddings + # Restore metadata in memory if cached_full_text or cached_metadata: self.file_metadata[file_path] = { @@ -2148,32 +2170,6 @@ def index_document(self, file_path: str) -> Dict[str, Any]: self.indexed_files.add(file_path) - # Build per-file FAISS index NOW so retrieval-time queries - # don't have to rebuild it on every call. Pre-#1030 - # follow-up: this block was only present on the - # fresh-index path (~line 2289 below), so any document - # loaded from cache hit ``cached_file_index is None`` in - # _retrieve_chunks_from_file and rebuilt the FAISS index - # from scratch on every query — adding ~3 s × N queries - # of avoidable work. One-time cost on cache load instead. - try: - self._load_embedder() - _file_embeddings = self._encode_texts( - list(cached_chunks), show_progress=False - ) - _file_dim = _file_embeddings.shape[1] - _file_index = faiss.IndexFlatL2(_file_dim) - # pylint: disable=no-value-for-parameter - _file_index.add(_file_embeddings.astype("float32")) - self.file_indices[file_path] = _file_index - self.file_embeddings[file_path] = _file_embeddings - except Exception as _e: # pylint: disable=broad-except - self.log.debug( - "Couldn't pre-build per-file index for %s: %s", - file_path, - _e, - ) - # Track access time for LRU (was missing — pre-existing bug) current_time = time.time() self.file_index_times[file_path] = current_time @@ -2276,57 +2272,54 @@ def index_document(self, file_path: str) -> Dict[str, Any]: stats["memory_limit_reached"] = True return stats - # Track which chunks belong to this file and only publish them after - # the rebuilt global index succeeds. - if self.chunks: + # Encode once; reuse for both the global and per-file index. + if self.index is None: + if self.config.show_stats: + print("🏗️ Building initial search index...") + new_index, rebuilt_chunks, file_embeddings = ( + self._create_vector_index( + list(new_chunks), return_embeddings=True + ) + ) + file_chunk_indices = list(range(len(new_chunks))) + rebuilt_chunk_to_file = { + idx: file_path for idx in file_chunk_indices + } + else: old_chunks = list(self.chunks) old_count = len(old_chunks) - rebuilt_chunks_source = old_chunks + list(new_chunks) start_idx = len(old_chunks) file_chunk_indices = list( range(start_idx, start_idx + len(new_chunks)) ) + rebuilt_chunks = old_chunks + list(new_chunks) rebuilt_chunk_to_file = dict(self.chunk_to_file) for chunk_idx in file_chunk_indices: rebuilt_chunk_to_file[chunk_idx] = file_path if self.config.show_stats: print( - f"🔄 Rebuilding search index ({old_count} + {len(new_chunks)} = {len(rebuilt_chunks_source)} total chunks)" + f"➕ Appending {len(new_chunks)} new chunks to existing index ({old_count} -> {len(rebuilt_chunks)} total)" ) - else: - file_chunk_indices = list(range(len(new_chunks))) - rebuilt_chunks_source = list(new_chunks) - rebuilt_chunk_to_file = { - idx: file_path for idx in file_chunk_indices - } - if self.config.show_stats: - print("🏗️ Building initial search index...") - - rebuilt_index, rebuilt_chunks = self._create_vector_index( - rebuilt_chunks_source - ) - self.index = rebuilt_index - self.chunks = rebuilt_chunks - self.chunk_to_file = rebuilt_chunk_to_file - self.file_to_chunk_indices[file_path] = file_chunk_indices + self._load_embedder() + file_embeddings = self._encode_texts( + new_chunks, show_progress=False + ) - # Build and cache per-file FAISS index for fast file-specific searches if self.config.show_stats: print("🔍 Building per-file search index...") - self._load_embedder() - # Generate embeddings for this file's chunks only - file_embeddings = self._encode_texts(new_chunks, show_progress=False) - - # Create FAISS index for this file - dimension = file_embeddings.shape[1] - file_index = faiss.IndexFlatL2(dimension) - # pylint: disable=no-value-for-parameter - file_index.add(file_embeddings.astype("float32")) + file_index = self._create_faiss_index(file_embeddings) - # Cache the index and embeddings for this file + if self.index is None: + self.index = new_index + else: + # Append after the per-file index is ready. + self.index.add(file_embeddings.astype("float32")) + self.chunks = rebuilt_chunks + self.chunk_to_file = rebuilt_chunk_to_file + self.file_to_chunk_indices[file_path] = file_chunk_indices self.file_indices[file_path] = file_index self.file_embeddings[file_path] = file_embeddings diff --git a/tests/test_rag.py b/tests/test_rag.py index 73ab4ef2a..ae9a9c3b1 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -380,6 +380,109 @@ def test_cache_functionality(self, mock_dependencies): assert isinstance(result2, dict) assert result2.get("success") is True + def test_indexing_two_documents_embeds_each_file_once(self, mock_dependencies): + """Test that each indexed file is embedded once on the steady-state path.""" + if not RAG_AVAILABLE: + pytest.skip(f"RAG dependencies not available: {IMPORT_ERROR}") + + with tempfile.TemporaryDirectory() as temp_dir: + config = RAGConfig(cache_dir=temp_dir, show_stats=False) + + with patch("gaia.rag.sdk.RAGSDK._check_dependencies"): + rag = RAGSDK(config) + + doc1 = Path(temp_dir) / "doc1.txt" + doc1.write_text("doc1") + doc2 = Path(temp_dir) / "doc2.txt" + doc2.write_text("doc2") + + encode_calls = [] + + def fake_encode(texts, show_progress=False): # noqa: ARG001 + encode_calls.append(list(texts)) + return np.array( + [[0.1, 0.2, 0.3, 0.4] for _ in texts], dtype=np.float32 + ) + + with ( + patch.object( + rag, + "_extract_text_from_file", + side_effect=[("doc1 text", {}), ("doc2 text", {})], + ), + patch.object( + rag, + "_split_text_into_chunks", + side_effect=[["doc1 chunk"], ["doc2 chunk"]], + ), + patch.object(rag, "_encode_texts", side_effect=fake_encode), + ): + result1 = rag.index_document(str(doc1)) + result2 = rag.index_document(str(doc2)) + + assert result1["success"] is True + assert result2["success"] is True + assert encode_calls == [["doc1 chunk"], ["doc2 chunk"]] + + def test_cache_load_embeds_each_cached_file_once(self, mock_dependencies): + """Test that cache loads reuse one embedding pass per cached file.""" + if not RAG_AVAILABLE: + pytest.skip(f"RAG dependencies not available: {IMPORT_ERROR}") + + with tempfile.TemporaryDirectory() as temp_dir: + config = RAGConfig(cache_dir=temp_dir, show_stats=False) + + doc1 = Path(temp_dir) / "doc1.txt" + doc1.write_text("doc1") + doc2 = Path(temp_dir) / "doc2.txt" + doc2.write_text("doc2") + + with patch("gaia.rag.sdk.RAGSDK._check_dependencies"): + seed_rag = RAGSDK(config) + + with ( + patch.object( + seed_rag, + "_extract_text_from_file", + side_effect=[("doc1 text", {}), ("doc2 text", {})], + ), + patch.object( + seed_rag, + "_split_text_into_chunks", + side_effect=[["cached doc1 chunk"], ["cached doc2 chunk"]], + ), + ): + assert seed_rag.index_document(str(doc1))["success"] is True + assert seed_rag.index_document(str(doc2))["success"] is True + + with patch("gaia.rag.sdk.RAGSDK._check_dependencies"): + rag = RAGSDK(config) + + encode_calls = [] + + def fake_encode(texts, show_progress=False): # noqa: ARG001 + encode_calls.append(list(texts)) + return np.array( + [[0.1, 0.2, 0.3, 0.4] for _ in texts], dtype=np.float32 + ) + + with ( + patch.object( + rag, + "_extract_text_from_file", + side_effect=AssertionError("cache load should skip extraction"), + ), + patch.object(rag, "_encode_texts", side_effect=fake_encode), + ): + result1 = rag.index_document(str(doc1)) + result2 = rag.index_document(str(doc2)) + + assert result1["success"] is True + assert result1["from_cache"] is True + assert result2["success"] is True + assert result2["from_cache"] is True + assert encode_calls == [["cached doc1 chunk"], ["cached doc2 chunk"]] + # NOTE: The pickle-era cache recovery tests (test_corrupted_cache_recovery, # test_cache_checksum_mismatch_recovery, test_oversized_cache_rejected, # test_old_format_cache_migration) were removed when the cache format