diff --git a/vdb_benchmark/tests/tests/test_issue_375_chunked_insert_ids.py b/vdb_benchmark/tests/tests/test_issue_375_chunked_insert_ids.py
new file mode 100644
index 00000000..e1c23636
--- /dev/null
+++ b/vdb_benchmark/tests/tests/test_issue_375_chunked_insert_ids.py
@@ -0,0 +1,152 @@
+"""
+Regression tests for issue #375:
+
+    vdb benchmark shows a very low recall@10 because the flat_gt
+    collection size is too small.
+
+Root cause: ``load_vdb.insert_data`` computed primary keys as
+``range(batch_start, batch_end)`` based on the *chunk-local* index,
+so every chunk re-used IDs ``0..chunk_size-1``. With ``num_vectors=1M``
+and ``chunk_size=10k`` the source collection ended up with only 10k
+unique PKs (and 99 duplicates per PK), which in turn made the
+``flat_gt`` collection only 10k rows — about 1% of the source — and
+drove recall@10 down to ~0.009.
+
+The fix adds a ``start_id`` offset to ``insert_data`` and threads a
+running ``global_id_offset`` through the chunked path in ``main``.
+These tests verify the IDs are globally unique across chunks, and that
+the legacy default (``start_id=0``) still works for the single-chunk
+path.
+"""
+import os
+import sys
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+
+# Make the package importable regardless of where pytest is invoked from.
+ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if ROOT not in sys.path:
+    sys.path.insert(0, ROOT)
+
+# Import the function under test from the real module. This relies on
+# load_vdb keeping its argparse / Milvus connect code inside ``main()``
+# (not at import time), so the import itself is cheap.
+from vdbbench.load_vdb import insert_data  # noqa: E402
+
+
+def _captured_ids(mock_collection):
+    """Concatenate every IDs list passed to ``collection.insert``."""
+    captured = []
+    for call in mock_collection.insert.call_args_list:
+        # ``insert`` is called as ``collection.insert([ids, batch_vectors])``.
+        args, _kwargs = call
+        payload = args[0]
+        ids = payload[0]
+        captured.extend(list(ids))
+    return captured
+
+
+class TestInsertDataIdOffset:
+    """Verify primary-key uniqueness across chunked inserts."""
+
+    def test_default_start_id_preserves_legacy_behavior(self):
+        """When ``start_id`` is omitted, IDs start at 0 — same as before #375."""
+        collection = MagicMock()
+        vectors = np.zeros((100, 8), dtype=np.float32)
+
+        total, _elapsed = insert_data(collection, vectors, batch_size=25)
+
+        assert total == 100
+        ids = _captured_ids(collection)
+        assert ids == list(range(0, 100))
+
+    def test_start_id_offsets_all_batches(self):
+        """A non-zero ``start_id`` shifts every batch's IDs by that offset."""
+        collection = MagicMock()
+        vectors = np.zeros((50, 4), dtype=np.float32)
+
+        insert_data(collection, vectors, batch_size=10, start_id=1000)
+
+        ids = _captured_ids(collection)
+        assert ids == list(range(1000, 1050))
+
+    def test_three_chunks_produce_globally_unique_ids(self):
+        """
+        Exact reproduction of issue #375: simulate the chunked path in
+        ``main()`` with three chunks. Before the fix, every chunk re-used
+        IDs 0..chunk_size-1 and the union had only ``chunk_size`` unique
+        values; after the fix the union has ``3 * chunk_size`` unique values.
+        """
+        collection = MagicMock()
+        chunk_size = 1000
+        batch_size = 100
+        num_chunks = 3
+
+        global_offset = 0
+        for _ in range(num_chunks):
+            chunk = np.zeros((chunk_size, 4), dtype=np.float32)
+            insert_data(collection, chunk, batch_size=batch_size, start_id=global_offset)
+            global_offset += chunk_size
+
+        ids = _captured_ids(collection)
+        assert len(ids) == num_chunks * chunk_size
+        # The critical assertion the original code would fail:
+        assert len(set(ids)) == num_chunks * chunk_size, (
+            "Duplicate primary keys across chunks — issue #375 regression."
+        )
+        assert min(ids) == 0
+        assert max(ids) == num_chunks * chunk_size - 1
+
+    def test_uneven_final_chunk(self):
+        """The final chunk is usually smaller than ``chunk_size``."""
+        collection = MagicMock()
+        # 2500 vectors total, chunks of 1000 → 1000, 1000, 500
+        chunks = [1000, 1000, 500]
+        global_offset = 0
+        for n in chunks:
+            chunk = np.zeros((n, 4), dtype=np.float32)
+            insert_data(collection, chunk, batch_size=300, start_id=global_offset)
+            global_offset += n
+
+        ids = _captured_ids(collection)
+        assert ids == list(range(0, 2500))
+        assert len(set(ids)) == 2500
+
+    def test_batch_size_larger_than_chunk(self):
+        """``batch_size`` >= len(vectors) should still produce one batch with the offset applied."""
+        collection = MagicMock()
+        vectors = np.zeros((42, 4), dtype=np.float32)
+
+        insert_data(collection, vectors, batch_size=1000, start_id=500)
+
+        assert collection.insert.call_count == 1
+        ids = _captured_ids(collection)
+        assert ids == list(range(500, 542))
+
+
+class TestFlatGtCoverageGuard:
+    """
+    Sanity-check the *intent* of the coverage guard added to
+    ``enhanced_bench.create_flat_collection``: a flat_gt collection
+    that covers far fewer entities than the source should be flagged.
+
+    NOTE: this re-states the guard's 0.99 threshold instead of calling it
+    (so no Milvus is needed in CI); keep the two values in sync.
+    """
+
+    @pytest.mark.parametrize(
+        "flat_count,source_count,should_pass",
+        [
+            (1_000_000, 1_000_000, True),  # exact match
+            (995_000, 1_000_000, True),  # 99.5%, within tolerance
+            (10_000, 1_000_000, False),  # the issue #375 failure mode
+            (100_000, 1_000_000, False),  # only 10%, still wrong
+            (0, 1_000_000, False),  # empty
+        ],
+    )
+    def test_coverage_threshold(self, flat_count, source_count, should_pass):
+        coverage = flat_count / source_count if source_count else 0.0
+        passes = coverage >= 0.99
+        assert passes is should_pass
diff --git a/vdb_benchmark/vdbbench/enhanced_bench.py b/vdb_benchmark/vdbbench/enhanced_bench.py
index f813ef79..e7bee4ca 100755
--- a/vdb_benchmark/vdbbench/enhanced_bench.py
+++ b/vdb_benchmark/vdbbench/enhanced_bench.py
@@ -903,7 +903,13 @@ def create_flat_collection(
             pct = min(100.0, 100.0 * copied / total_vectors)
             print(f" Copied {copied}/{total_vectors} vectors ({pct:.1f}%)")
 
-        print(f" Copied {copied}/{total_vectors} vectors (100.0%)")
+        # Compute actual completion percentage rather than hardcoding 100%.
+        # When the source collection has duplicate primary keys, the
+        # query_iterator deduplicates and `copied` ends up well below
+        # `total_vectors` — printing "100.0%" here used to hide that
+        # (see issue #375).
+        final_pct = 100.0 * copied / total_vectors if total_vectors else 0.0
+        print(f" Copied {copied}/{total_vectors} vectors ({final_pct:.1f}%)")
         flat_coll.flush()
 
         # Wait for entity count to stabilize after flush
@@ -928,6 +934,24 @@ def create_flat_collection(
         flat_coll.load()
         print(f"FLAT collection '{flat_collection_name}' ready with "
               f"{flat_coll.num_entities} vectors.")
+
+        # Guard: the ground-truth FLAT collection must cover the source
+        # collection's vectors. If it doesn't, recall@k will be artificially
+        # low because most ANN-returned PKs simply won't exist in the GT
+        # set. This was the failure mode reported in issue #375 (caused by
+        # duplicate PKs in the source). Fail loudly rather than silently
+        # producing meaningless recall numbers.
+        coverage = (flat_coll.num_entities / total_vectors) if total_vectors else 0.0
+        if coverage < 0.99:
+            print(f"ERROR: FLAT ground-truth collection covers only "
+                  f"{flat_coll.num_entities}/{total_vectors} "
+                  f"({coverage * 100:.2f}%) of the source collection. "
+                  f"This will produce artificially low recall@k. "
+                  f"Common cause: duplicate primary keys in the source "
+                  f"collection (see issue #375). "
+                  f"Re-run the load step and verify unique PKs.")
+            return False
+
         return True
 
     except Exception as e:
diff --git a/vdb_benchmark/vdbbench/load_vdb.py b/vdb_benchmark/vdbbench/load_vdb.py
index 19444d32..89879b87 100755
--- a/vdb_benchmark/vdbbench/load_vdb.py
+++ b/vdb_benchmark/vdbbench/load_vdb.py
@@ -197,8 +197,23 @@ def generate_vectors(num_vectors, dim, distribution='uniform'):
     return vectors
 
 
-def insert_data(collection, vectors, batch_size=10000):
-    """Insert vectors into the collection in batches"""
+def insert_data(collection, vectors, batch_size=10000, start_id=0):
+    """Insert vectors into the collection in batches.
+
+    Args:
+        collection: The Milvus collection to insert into.
+        vectors: A numpy array (or list) of vectors to insert.
+        batch_size: Number of vectors to send per insert() call.
+        start_id: Global ID offset for the primary key; vector ``j`` of
+            ``vectors`` is inserted with PK ``start_id + j``. When calling
+            ``insert_data`` once per generated chunk, ``start_id`` MUST be
+            the number of vectors already inserted by previous chunks.
+            Without this offset, every chunk would re-use primary keys
+            0..len(vectors)-1, producing duplicate PKs (see issue #375).
+
+    Returns:
+        Tuple of (total_inserted, elapsed_seconds).
+    """
     total_vectors = len(vectors)
     num_batches = (total_vectors + batch_size - 1) // batch_size
 
@@ -210,8 +225,9 @@ def insert_data(collection, vectors, batch_size=10000):
         batch_end = min((i + 1) * batch_size, total_vectors)
         batch_size_actual = batch_end - batch_start
 
-        # Prepare batch data
-        ids = list(range(batch_start, batch_end))
+        # Prepare batch data — IDs must be globally unique across chunks,
+        # hence the start_id offset (see docstring + issue #375).
+        ids = list(range(start_id + batch_start, start_id + batch_end))
         batch_vectors = vectors[batch_start:batch_end]
 
         # Insert batch
@@ -225,7 +241,8 @@ def insert_data(collection, vectors, batch_size=10000):
             rate = total_inserted / elapsed if elapsed > 0 else 0
 
             logger.info(f"Inserted batch {i+1}/{num_batches}: {progress:.2f}% complete, "
-                        f"rate: {rate:.2f} vectors/sec")
+                        f"rate: {rate:.2f} vectors/sec, "
+                        f"id_range=[{ids[0]}, {ids[-1]}]")
 
         except Exception as e:
             logger.error(f"Error inserting batch {i+1}: {str(e)}")
@@ -332,7 +349,11 @@ def main():
         vectors = []
         remaining = args.num_vectors
         chunks_processed = 0
-        
+        # Running global offset for primary keys. This is the FIX for
+        # issue #375: without it each chunk re-inserts IDs 0..chunk_size-1,
+        # producing duplicate PKs and a too-small flat_gt collection.
+        global_id_offset = 0
+
         while remaining > 0:
             chunk_size = min(args.chunk_size, remaining)
             logger.info(f"Generating chunk {chunks_processed+1}: {chunk_size:,} vectors")
@@ -344,11 +365,17 @@ def main():
                         f"Progress: {(args.num_vectors - remaining):,}/{args.num_vectors:,} vectors "
                         f"({(args.num_vectors - remaining) / args.num_vectors * 100:.1f}%)")
 
-            # Insert data
-            logger.info(f"Inserting {args.num_vectors} vectors into collection '{args.collection_name}'")
-            total_inserted, insert_time = insert_data(collection, chunk_vectors, args.batch_size)
+            # Insert data — pass the running global_id_offset so PKs are
+            # unique across chunks.
+            logger.info(f"Inserting chunk {chunks_processed+1} ({chunk_size:,} vectors) into "
+                        f"collection '{args.collection_name}' starting at id={global_id_offset}")
+            total_inserted, insert_time = insert_data(
+                collection, chunk_vectors, args.batch_size,
+                start_id=global_id_offset,
+            )
             logger.info(f"Inserted {total_inserted} vectors in {insert_time:.2f} seconds")
 
+            global_id_offset += chunk_size
             remaining -= chunk_size
             chunks_processed += 1
     else:
@@ -356,7 +383,7 @@ def main():
         vectors = generate_vectors(args.num_vectors, args.dimension, args.distribution)
         # Insert data
         logger.info(f"Inserting {args.num_vectors} vectors into collection '{args.collection_name}'")
-        total_inserted, insert_time = insert_data(collection, vectors, args.batch_size)
+        total_inserted, insert_time = insert_data(collection, vectors, args.batch_size, start_id=0)
         logger.info(f"Inserted {total_inserted} vectors in {insert_time:.2f} seconds")
 
     gen_time = time.time() - start_gen_time