Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["."]
include = ["vectordb_bench", "vectordb_bench.cli"]
include = ["vectordb_bench", "vectordb_bench.*"]

[project]
name = "vectordb-bench"
Expand All @@ -26,6 +26,7 @@ classifiers = [
]
dependencies = [
"click",
"pyyaml",
"pytz",
"streamlit>=1.47,<2", # 1.47 fixes streamlit#11660
"tqdm",
Expand Down
38 changes: 37 additions & 1 deletion vectordb_bench/backend/clients/oceanbase/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,40 @@ class OceanBaseTypedDict(CommonTypedDict):
]
database: Annotated[str, click.option("--database", type=str, help="DataBase name", required=True)]
port: Annotated[int, click.option("--port", type=int, help="OceanBase port", required=True)]
create_index_parallel: Annotated[
int,
click.option(
"--create-index-parallel",
type=int,
default=16,
show_default=True,
help="PARALLEL hint degree for CREATE VECTOR INDEX",
),
]
partitions: Annotated[
int,
click.option(
"--partitions",
type=int,
default=0,
show_default=True,
help="Number of KEY partitions for the table. 0 or 1 means no partitioning.",
),
]


class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4): ...
class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4):
extra_info_max_size: Annotated[
int | None,
click.option(
"--extra-info-max-size",
type=int,
default=32,
show_default=True,
help="extra_info_max_size for HNSW index. Set to 0 to omit.",
required=False,
),
]


@cli.command()
Expand All @@ -55,6 +86,9 @@ def OceanBaseHNSW(**parameters: Unpack[OceanBaseHNSWTypedDict]):
m=parameters["m"],
efConstruction=parameters["ef_construction"],
ef_search=parameters["ef_search"],
extra_info_max_size=parameters["extra_info_max_size"] or None,
create_index_parallel=parameters["create_index_parallel"],
partitions=parameters["partitions"],
index=parameters["index_type"],
),
**parameters,
Expand Down Expand Up @@ -94,6 +128,8 @@ def OceanBaseIVF(**parameters: Unpack[OceanBaseIVFTypedDict]):
nlist=parameters["nlist"],
sample_per_nlist=parameters["sample_per_nlist"],
nbits=parameters["nbits"],
create_index_parallel=parameters["create_index_parallel"],
partitions=parameters["partitions"],
index=input_index_type,
ivf_nprobes=parameters["ivf_nprobes"],
),
Expand Down
11 changes: 5 additions & 6 deletions vectordb_bench/backend/clients/oceanbase/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,18 @@ class OceanBaseIndexConfig(BaseModel):
index: IndexType
metric_type: MetricType | None = None
lib: str = "vsag"
create_index_parallel: int = 16
partitions: int = 0

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2 or (
self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
):
if self.metric_type == MetricType.L2:
return "l2"
if self.metric_type == MetricType.IP:
return "inner_product"
return "cosine"

def parse_metric_func_str(self) -> str:
if self.metric_type == MetricType.L2 or (
self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
):
if self.metric_type == MetricType.L2:
return "l2_distance"
if self.metric_type == MetricType.IP:
return "negative_inner_product"
Expand All @@ -60,6 +58,7 @@ class OceanBaseHNSWConfig(OceanBaseIndexConfig, DBCaseConfig):
m: int
efConstruction: int
ef_search: int | None = None
extra_info_max_size: int | None = 32
index: IndexType

def index_param(self) -> dict:
Expand Down
40 changes: 23 additions & 17 deletions vectordb_bench/backend/clients/oceanbase/oceanbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class OceanBase(VectorDB):
FilterOp.NumGE,
FilterOp.StrEqual,
]
# insert path is GIL-bound; multi-threading cannot improve load throughput
thread_safe: bool = False

def __init__(
self,
Expand Down Expand Up @@ -58,6 +60,15 @@ def __init__(
finally:
self._disconnect()

def __getstate__(self):
state = self.__dict__.copy()
state["_conn"] = None
state["_cursor"] = None
return state

def __setstate__(self, state: dict) -> None:
self.__dict__.update(state)

def _connect(self):
try:
self._conn = mysql.connect(
Expand Down Expand Up @@ -109,27 +120,28 @@ def _create_table(self):
if not self._cursor:
raise ValueError("Cursor is not initialized")

log.info(f"Creating table {self.table_name}")
create_table_query = f"""
CREATE TABLE {self.table_name} (
id INT PRIMARY KEY,
embedding VECTOR({self.dim})
);
"""
partitions = getattr(self.db_case_config, "partitions", 0)
log.info(f"Creating table {self.table_name} (partitions={partitions})")

create_table_query = f"CREATE TABLE {self.table_name} (id INT PRIMARY KEY, embedding VECTOR({self.dim}))"
if partitions > 1:
create_table_query += f" PARTITION BY KEY(id) PARTITIONS {partitions}"
create_table_query += ";"
self._cursor.execute(create_table_query)

def optimize(self, data_size: int):
index_params = self.db_case_config.index_param()
index_args = ", ".join(f"{k}={v}" for k, v in index_params["params"].items())
index_query = (
f"CREATE /*+ PARALLEL(18) */ VECTOR INDEX idx1 "
f"CREATE /*+ PARALLEL({self.db_case_config.create_index_parallel}) */ VECTOR INDEX idx1 "
f"ON {self.table_name}(embedding) "
f"WITH (distance={self.db_case_config.parse_metric()}, "
f"type={index_params['index_type']}, lib={index_params['lib']}, {index_args}"
)

if self.db_case_config.index in {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}:
index_query += ", extra_info_max_size=32"
extra_info = getattr(self.db_case_config, "extra_info_max_size", None)
if extra_info is not None:
index_query += f", extra_info_max_size={extra_info}"

index_query += ")"

Expand All @@ -153,10 +165,6 @@ def optimize(self, data_size: int):
raise

def need_normalize_cosine(self) -> bool:
if self.db_case_config.index == IndexType.HNSW_BQ:
log.info("current HNSW_BQ only supports L2, cosine dataset need normalize.")
return True

return False

def _wait_for_major_compaction(self):
Expand Down Expand Up @@ -185,9 +193,7 @@ def insert_embeddings(
batch_end = min(batch_start + self.load_batch_size, len(embeddings))
batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
self._cursor.execute(
f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"
)
self._cursor.execute(f"INSERT INTO {self.table_name} VALUES {values}")
insert_count += len(batch)
except mysql.Error:
log.exception("Failed to insert embeddings")
Expand Down
2 changes: 2 additions & 0 deletions vectordb_bench/backend/runner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .concurrent_runner import ConcurrentInsertRunner
from .mp_runner import MultiProcessingSearchRunner
from .multiprocess_load_runner import MultiprocessInsertRunner
from .read_write_runner import ReadWriteRunner
from .serial_runner import SerialInsertRunner, SerialSearchRunner

__all__ = [
"ConcurrentInsertRunner",
"MultiProcessingSearchRunner",
"MultiprocessInsertRunner",
"ReadWriteRunner",
"SerialInsertRunner",
"SerialSearchRunner",
Expand Down
5 changes: 4 additions & 1 deletion vectordb_bench/backend/runner/concurrent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def __init__(

effective_workers = max_workers or min(mp.cpu_count(), 4)
if not db.thread_safe:
log.info(f"DB {db.name} is not thread-safe, falling back to max_workers=1")
log.info(
f"DB {db.name} declared thread_safe=False, falling back to max_workers=1"
" (use --load-processes for parallel loading)"
)
effective_workers = 1
self.max_workers = effective_workers
assert db.thread_safe or self.max_workers == 1, (
Expand Down
Loading
Loading