Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions vdb_benchmark/tests/tests/test_issue_375_chunked_insert_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
Regression tests for issue #375:

vdb benchmark shows a very low recall@10 because the flat_gt
collection size is too small.

Root cause: ``load_vdb.insert_data`` computed primary keys as
``range(batch_start, batch_end)`` based on the *chunk-local* index,
so every chunk re-used IDs ``0..chunk_size-1``. With ``num_vectors=1M``
and ``chunk_size=10k`` the source collection ended up with only 10k
unique PKs (and 99 duplicates per PK), which in turn made the
``flat_gt`` collection only 10k rows — about 1% of the source — and
drove recall@10 down to ~0.009.

The fix adds a ``start_id`` offset to ``insert_data`` and threads a
running ``global_id_offset`` through the chunked path in ``main``.
These tests verify the IDs are globally unique across chunks, and that
the legacy default (``start_id=0``) still works for the single-chunk
path.
"""
import os
import sys
from unittest.mock import MagicMock

import numpy as np
import pytest

# Put the directory two levels above this file's folder on ``sys.path`` so
# that ``vdbbench`` resolves no matter which working directory pytest uses.
ROOT = os.path.abspath(
    os.path.join(__file__, os.pardir, os.pardir, os.pardir)
)
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

# Import only the function under test from the real module. Importing
# ``vdbbench.load_vdb`` does not run the argparse / Milvus connect code —
# that runs only inside ``main()`` — so the import itself is cheap.
from vdbbench.load_vdb import insert_data # noqa: E402


def _captured_ids(mock_collection):
    """Return, in call order, every primary key handed to ``collection.insert``.

    Each recorded call is expected to look like
    ``collection.insert([ids, batch_vectors])``: the first positional
    argument is the payload and its first element is the IDs sequence.
    """
    return [
        pk
        for positional, _keyword in mock_collection.insert.call_args_list
        for pk in positional[0][0]
    ]


class TestInsertDataIdOffset:
    """Primary keys must stay unique when ``insert_data`` runs once per chunk."""

    def test_default_start_id_preserves_legacy_behavior(self):
        """Leaving ``start_id`` out numbers rows from 0, exactly as before #375."""
        sink = MagicMock()
        data = np.zeros((100, 8), dtype=np.float32)

        inserted, _seconds = insert_data(sink, data, batch_size=25)

        assert inserted == 100
        seen = _captured_ids(sink)
        assert seen == list(range(0, 100))

    def test_start_id_offsets_all_batches(self):
        """Every batch is shifted by ``start_id``, not just the first one."""
        sink = MagicMock()
        data = np.zeros((50, 4), dtype=np.float32)

        insert_data(sink, data, batch_size=10, start_id=1000)

        seen = _captured_ids(sink)
        assert seen == list(range(1000, 1050))

    def test_three_chunks_produce_globally_unique_ids(self):
        """
        Reproduce issue #375 by mimicking the chunked path in ``main()``
        with three equally sized chunks. The pre-fix code restarted at ID 0
        for each chunk, so the union held only ``chunk_size`` distinct
        values; with the fix it holds ``3 * chunk_size``.
        """
        sink = MagicMock()
        chunk_size = 1000
        batch_size = 100
        num_chunks = 3
        expected_total = num_chunks * chunk_size

        for chunk_index in range(num_chunks):
            # Offset == number of rows inserted by the preceding chunks.
            rows = np.zeros((chunk_size, 4), dtype=np.float32)
            insert_data(
                sink,
                rows,
                batch_size=batch_size,
                start_id=chunk_index * chunk_size,
            )

        seen = _captured_ids(sink)
        assert len(seen) == expected_total
        # This is the assertion the original implementation would fail:
        assert len(set(seen)) == expected_total, (
            "Duplicate primary keys across chunks — issue #375 regression."
        )
        assert min(seen) == 0
        assert max(seen) == expected_total - 1

    def test_uneven_final_chunk(self):
        """The last chunk is typically shorter than ``chunk_size``."""
        sink = MagicMock()
        # 2500 vectors split into chunks of at most 1000 → 1000, 1000, 500
        chunk_lengths = [1000, 1000, 500]
        next_id = 0
        for length in chunk_lengths:
            rows = np.zeros((length, 4), dtype=np.float32)
            insert_data(sink, rows, batch_size=300, start_id=next_id)
            next_id += length

        seen = _captured_ids(sink)
        assert seen == list(range(0, 2500))
        assert len(set(seen)) == 2500

    def test_batch_size_larger_than_chunk(self):
        """A ``batch_size`` >= len(vectors) yields a single, offset batch."""
        sink = MagicMock()
        data = np.zeros((42, 4), dtype=np.float32)

        insert_data(sink, data, batch_size=1000, start_id=500)

        assert sink.insert.call_count == 1
        seen = _captured_ids(sink)
        assert seen == list(range(500, 542))


class TestFlatGtCoverageGuard:
    """
    Sanity-check the *intent* of the coverage guard added to
    ``enhanced_bench.create_flat_collection``: a flat_gt collection
    that covers far fewer entities than the source should be flagged.

    The guard's arithmetic is re-stated here rather than invoking Milvus,
    so this test runs in CI with no external dependencies. Besides the
    headline cases it pins both sides of the threshold boundary and the
    zero-sized source, which takes the ``else 0.0`` fallback instead of
    dividing by zero.
    """

    # Minimum flat_gt/source ratio accepted by the guard.
    MIN_COVERAGE = 0.99

    @pytest.mark.parametrize(
        "flat_count,source_count,should_pass",
        [
            (1_000_000, 1_000_000, True),   # exact match
            (995_000, 1_000_000, True),     # 99.5%, within tolerance
            (990_000, 1_000_000, True),     # exactly on the threshold
            (989_999, 1_000_000, False),    # one row below the threshold
            (10_000, 1_000_000, False),     # the issue #375 failure mode
            (100_000, 1_000_000, False),    # only 10%, still wrong
            (0, 1_000_000, False),          # empty flat_gt
            (0, 0, False),                  # empty source: fallback coverage 0.0
        ],
    )
    def test_coverage_threshold(self, flat_count, source_count, should_pass):
        """Coverage at or above ``MIN_COVERAGE`` passes; anything less fails."""
        coverage = flat_count / source_count if source_count else 0.0
        passes = coverage >= self.MIN_COVERAGE
        assert passes is should_pass
26 changes: 25 additions & 1 deletion vdb_benchmark/vdbbench/enhanced_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,13 @@ def create_flat_collection(
pct = min(100.0, 100.0 * copied / total_vectors)
print(f" Copied {copied}/{total_vectors} vectors ({pct:.1f}%)")

print(f" Copied {copied}/{total_vectors} vectors (100.0%)")
# Compute actual completion percentage rather than hardcoding 100%.
# When the source collection has duplicate primary keys, the
# query_iterator deduplicates and `copied` ends up well below
# `total_vectors` — printing "100.0%" here used to hide that
# (see issue #375).
final_pct = 100.0 * copied / total_vectors if total_vectors else 0.0
print(f" Copied {copied}/{total_vectors} vectors ({final_pct:.1f}%)")
flat_coll.flush()

# Wait for entity count to stabilize after flush
Expand All @@ -928,6 +934,24 @@ def create_flat_collection(
flat_coll.load()
print(f"FLAT collection '{flat_collection_name}' ready with "
f"{flat_coll.num_entities} vectors.")

# Guard: the ground-truth FLAT collection must cover the source
# collection's vectors. If it doesn't, recall@k will be artificially
# low because most ANN-returned PKs simply won't exist in the GT
# set. This was the failure mode reported in issue #375 (caused by
# duplicate PKs in the source). Fail loudly rather than silently
# producing meaningless recall numbers.
coverage = (flat_coll.num_entities / total_vectors) if total_vectors else 0.0
if coverage < 0.99:
print(f"ERROR: FLAT ground-truth collection covers only "
f"{flat_coll.num_entities}/{total_vectors} "
f"({coverage * 100:.2f}%) of the source collection. "
f"This will produce artificially low recall@k. "
f"Common cause: duplicate primary keys in the source "
f"collection (see issue #375). "
f"Re-run the load step and verify unique PKs.")
return False

return True

except Exception as e:
Expand Down
47 changes: 37 additions & 10 deletions vdb_benchmark/vdbbench/load_vdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,23 @@ def generate_vectors(num_vectors, dim, distribution='uniform'):
return vectors


def insert_data(collection, vectors, batch_size=10000):
"""Insert vectors into the collection in batches"""
def insert_data(collection, vectors, batch_size=10000, start_id=0):
"""Insert vectors into the collection in batches.

Args:
collection: The Milvus collection to insert into.
vectors: A numpy array (or list) of vectors to insert.
batch_size: Number of vectors to send per insert() call.
start_id: Global ID offset for the primary key. When the caller is
generating data in chunks and calling ``insert_data`` once per
chunk, ``start_id`` MUST be set to the number of vectors already
inserted across previous chunks. Without this offset, every chunk
would re-use primary keys 0..len(vectors)-1, producing duplicate
PKs (see issue #375).

Returns:
Tuple of (total_inserted, elapsed_seconds).
"""
total_vectors = len(vectors)
num_batches = (total_vectors + batch_size - 1) // batch_size

Expand All @@ -210,8 +225,9 @@ def insert_data(collection, vectors, batch_size=10000):
batch_end = min((i + 1) * batch_size, total_vectors)
batch_size_actual = batch_end - batch_start

# Prepare batch data
ids = list(range(batch_start, batch_end))
# Prepare batch data — IDs must be globally unique across chunks,
# hence the start_id offset (see docstring + issue #375).
ids = list(range(start_id + batch_start, start_id + batch_end))
batch_vectors = vectors[batch_start:batch_end]

# Insert batch
Expand All @@ -225,7 +241,8 @@ def insert_data(collection, vectors, batch_size=10000):
rate = total_inserted / elapsed if elapsed > 0 else 0

logger.info(f"Inserted batch {i+1}/{num_batches}: {progress:.2f}% complete, "
f"rate: {rate:.2f} vectors/sec")
f"rate: {rate:.2f} vectors/sec, "
f"id_range=[{ids[0]}, {ids[-1]}]")

except Exception as e:
logger.error(f"Error inserting batch {i+1}: {str(e)}")
Expand Down Expand Up @@ -332,7 +349,11 @@ def main():
vectors = []
remaining = args.num_vectors
chunks_processed = 0

# Running global offset for primary keys. This is the FIX for
# issue #375: without it each chunk re-inserts IDs 0..chunk_size-1,
# producing duplicate PKs and a too-small flat_gt collection.
global_id_offset = 0

while remaining > 0:
chunk_size = min(args.chunk_size, remaining)
logger.info(f"Generating chunk {chunks_processed+1}: {chunk_size:,} vectors")
Expand All @@ -344,19 +365,25 @@ def main():
f"Progress: {(args.num_vectors - remaining):,}/{args.num_vectors:,} vectors "
f"({(args.num_vectors - remaining) / args.num_vectors * 100:.1f}%)")

# Insert data
logger.info(f"Inserting {args.num_vectors} vectors into collection '{args.collection_name}'")
total_inserted, insert_time = insert_data(collection, chunk_vectors, args.batch_size)
# Insert data — pass the running global_id_offset so PKs are
# unique across chunks.
logger.info(f"Inserting chunk {chunks_processed+1} ({chunk_size:,} vectors) into "
f"collection '{args.collection_name}' starting at id={global_id_offset}")
total_inserted, insert_time = insert_data(
collection, chunk_vectors, args.batch_size,
start_id=global_id_offset,
)
logger.info(f"Inserted {total_inserted} vectors in {insert_time:.2f} seconds")

global_id_offset += chunk_size
remaining -= chunk_size
chunks_processed += 1
else:
# For smaller vector counts, generate all at once
vectors = generate_vectors(args.num_vectors, args.dimension, args.distribution)
# Insert data
logger.info(f"Inserting {args.num_vectors} vectors into collection '{args.collection_name}'")
total_inserted, insert_time = insert_data(collection, vectors, args.batch_size)
total_inserted, insert_time = insert_data(collection, vectors, args.batch_size, start_id=0)
logger.info(f"Inserted {total_inserted} vectors in {insert_time:.2f} seconds")

gen_time = time.time() - start_gen_time
Expand Down
Loading