Skip to content

Commit f3b0bff

Browse files
authored
test: Integration tests for upserting with shard keys (#32)
1 parent 24b09ef commit f3b0bff

File tree

2 files changed

+96
-19
lines changed

2 files changed

+96
-19
lines changed

src/test/python/conftest.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import pytest
2-
from testcontainers.qdrant import QdrantContainer
2+
from testcontainers.core.container import DockerContainer
33
from qdrant_client import QdrantClient, models
44
import uuid
55
from pyspark.sql import SparkSession
66
from typing import NamedTuple
77
from uuid import uuid4
8+
from testcontainers.core.waiting_utils import wait_for_logs
89

910

1011
QDRANT_GRPC_PORT = 6334
1112
QDRANT_EMBEDDING_DIM = 6
1213
QDRANT_DISTANCE = models.Distance.COSINE
1314
QDRANT_API_KEY = uuid4().hex
15+
STRING_SHARD_KEY = "string_shard_key"
16+
INTEGER_SHARD_KEY = 876
1417

1518

1619
class Qdrant(NamedTuple):
@@ -20,7 +23,15 @@ class Qdrant(NamedTuple):
2023
client: QdrantClient
2124

2225

23-
qdrant_container = QdrantContainer(image="qdrant/qdrant:latest", api_key=QDRANT_API_KEY)
26+
qdrant_container = (
27+
(
28+
DockerContainer(image="qdrant/qdrant:latest")
29+
.with_env("QDRANT__SERVICE__API_KEY", QDRANT_API_KEY)
30+
.with_env("QDRANT__CLUSTER__ENABLED", "true")
31+
)
32+
.with_command("./qdrant --uri http://qdrant_node_1:6335")
33+
.with_exposed_ports(QDRANT_GRPC_PORT)
34+
)
2435

2536

2637
# Reference: https://gist.github.com/dizzythinks/f3bb37fd8ab1484bfec79d39ad8a92d3
@@ -38,6 +49,7 @@ def get_pom_version():
3849
@pytest.fixture(scope="module", autouse=True)
3950
def setup_container(request):
4051
qdrant_container.start()
52+
wait_for_logs(qdrant_container, "Qdrant gRPC listening on 6334")
4153

4254
def remove_container():
4355
qdrant_container.stop()
@@ -92,14 +104,20 @@ def qdrant():
92104
"multi": models.VectorParams(
93105
size=QDRANT_EMBEDDING_DIM,
94106
distance=QDRANT_DISTANCE,
95-
multivector_config=models.MultiVectorConfig(comparator=models.MultiVectorComparator.MAX_SIM)
96-
)
107+
multivector_config=models.MultiVectorConfig(
108+
comparator=models.MultiVectorComparator.MAX_SIM
109+
),
110+
),
97111
},
98112
sparse_vectors_config={
99113
"sparse": models.SparseVectorParams(),
100114
"another_sparse": models.SparseVectorParams(),
101115
},
116+
sharding_method=models.ShardingMethod.CUSTOM,
102117
)
118+
119+
client.create_shard_key(collection_name, STRING_SHARD_KEY)
120+
client.create_shard_key(collection_name, INTEGER_SHARD_KEY)
103121

104122
yield Qdrant(
105123
url=f"http://{host}:{grpc_port}",

src/test/python/test_qdrant_ingest.py

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pyspark.sql import SparkSession
33

44
from .schema import schema
5-
from .conftest import Qdrant
5+
from .conftest import Qdrant, STRING_SHARD_KEY, INTEGER_SHARD_KEY
66

77
current_directory = os.path.dirname(__file__)
88
input_file_path = os.path.join(current_directory, "..", "resources", "users.json")
@@ -20,12 +20,16 @@ def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession):
2020
"embedding_field": "dense_vector",
2121
"api_key": qdrant.api_key,
2222
"schema": df.schema.json(),
23+
"shard_key_selector": STRING_SHARD_KEY,
2324
}
2425

2526
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
2627

2728
assert (
28-
qdrant.client.count(qdrant.collection_name).count == df.count()
29+
qdrant.client.count(
30+
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
31+
).count
32+
== df.count()
2933
), "Uploaded points count is not equal to the dataframe count"
3034

3135

@@ -41,12 +45,16 @@ def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession):
4145
"vector_name": "dense",
4246
"schema": df.schema.json(),
4347
"api_key": qdrant.api_key,
48+
"shard_key_selector": STRING_SHARD_KEY,
4449
}
4550

4651
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
4752

4853
assert (
49-
qdrant.client.count(qdrant.collection_name).count == df.count()
54+
qdrant.client.count(
55+
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
56+
).count
57+
== df.count()
5058
), "Uploaded points count is not equal to the dataframe count"
5159

5260

@@ -65,12 +73,16 @@ def test_upsert_multiple_named_dense_vectors(
6573
"vector_names": "dense,another_dense",
6674
"schema": df.schema.json(),
6775
"api_key": qdrant.api_key,
76+
"shard_key_selector": STRING_SHARD_KEY,
6877
}
6978

7079
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
7180

7281
assert (
73-
qdrant.client.count(qdrant.collection_name).count == df.count()
82+
qdrant.client.count(
83+
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
84+
).count
85+
== df.count()
7486
), "Uploaded points count is not equal to the dataframe count"
7587

7688

@@ -88,12 +100,16 @@ def test_upsert_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession):
88100
"sparse_vector_names": "sparse",
89101
"schema": df.schema.json(),
90102
"api_key": qdrant.api_key,
103+
"shard_key_selector": STRING_SHARD_KEY,
91104
}
92105

93106
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
94107

95108
assert (
96-
qdrant.client.count(qdrant.collection_name).count == df.count()
109+
qdrant.client.count(
110+
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
111+
).count
112+
== df.count()
97113
), "Uploaded points count is not equal to the dataframe count"
98114

99115

@@ -111,12 +127,16 @@ def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSess
111127
"sparse_vector_names": "sparse,another_sparse",
112128
"schema": df.schema.json(),
113129
"api_key": qdrant.api_key,
130+
"shard_key_selector": STRING_SHARD_KEY,
114131
}
115132

116133
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
117134

118135
assert (
119-
qdrant.client.count(qdrant.collection_name).count == df.count()
136+
qdrant.client.count(
137+
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
138+
).count
139+
== df.count()
120140
), "Uploaded points count is not equal to the dataframe count"
121141

122142

@@ -136,12 +156,16 @@ def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkS
136156
"sparse_vector_names": "sparse",
137157
"schema": df.schema.json(),
138158
"api_key": qdrant.api_key,
159+
"shard_key_selector": STRING_SHARD_KEY,
139160
}
140161

141162
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
142163

143164
assert (
144-
qdrant.client.count(qdrant.collection_name).count == df.count()
165+
qdrant.client.count(
166+
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
167+
).count
168+
== df.count()
145169
), "Uploaded points count is not equal to the dataframe count"
146170

147171

@@ -162,12 +186,16 @@ def test_upsert_sparse_unnamed_dense_vectors(
162186
"sparse_vector_names": "sparse",
163187
"schema": df.schema.json(),
164188
"api_key": qdrant.api_key,
189+
"shard_key_selector": INTEGER_SHARD_KEY,
165190
}
166191

167192
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
168193

169194
assert (
170-
qdrant.client.count(qdrant.collection_name).count == df.count()
195+
qdrant.client.count(
196+
qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY
197+
).count
198+
== df.count()
171199
), "Uploaded points count is not equal to the dataframe count"
172200

173201

@@ -189,17 +217,20 @@ def test_upsert_multiple_sparse_dense_vectors(
189217
"sparse_vector_names": "sparse,another_sparse",
190218
"schema": df.schema.json(),
191219
"api_key": qdrant.api_key,
220+
"shard_key_selector": INTEGER_SHARD_KEY,
192221
}
193222

194223
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
195224

196225
assert (
197-
qdrant.client.count(qdrant.collection_name).count == df.count()
226+
qdrant.client.count(
227+
qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY
228+
).count
229+
== df.count()
198230
), "Uploaded points count is not equal to the dataframe count"
199231

200-
def test_upsert_multi_vector(
201-
qdrant: Qdrant, spark_session: SparkSession
202-
):
232+
233+
def test_upsert_multi_vector(qdrant: Qdrant, spark_session: SparkSession):
203234
df = (
204235
spark_session.read.schema(schema)
205236
.option("multiline", "true")
@@ -212,12 +243,16 @@ def test_upsert_multi_vector(
212243
"multi_vector_names": "multi",
213244
"schema": df.schema.json(),
214245
"api_key": qdrant.api_key,
246+
"shard_key_selector": INTEGER_SHARD_KEY,
215247
}
216248

217249
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
218250

219251
assert (
220-
qdrant.client.count(qdrant.collection_name).count == df.count()
252+
qdrant.client.count(
253+
qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY
254+
).count
255+
== df.count()
221256
), "Uploaded points count is not equal to the dataframe count"
222257

223258

@@ -233,11 +268,15 @@ def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession):
233268
"collection_name": qdrant.collection_name,
234269
"schema": df.schema.json(),
235270
"api_key": qdrant.api_key,
271+
"shard_key_selector": INTEGER_SHARD_KEY,
236272
}
237273
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
238274

239275
assert (
240-
qdrant.client.count(qdrant.collection_name).count == df.count()
276+
qdrant.client.count(
277+
qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY
278+
).count
279+
== df.count()
241280
), "Uploaded points count is not equal to the dataframe count"
242281

243282

@@ -256,7 +295,27 @@ def test_custom_id_field(qdrant: Qdrant, spark_session: SparkSession):
256295
"id_field": "id",
257296
"schema": df.schema.json(),
258297
"api_key": qdrant.api_key,
298+
"shard_key_selector": f"{STRING_SHARD_KEY},{INTEGER_SHARD_KEY}",
259299
}
260300
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()
261301

262-
assert len(qdrant.client.retrieve(qdrant.collection_name, [1, 2, 3, 15, 18])) == 5
302+
assert (
303+
len(
304+
qdrant.client.retrieve(
305+
qdrant.collection_name,
306+
[1, 2, 3, 15, 18],
307+
shard_key_selector=INTEGER_SHARD_KEY,
308+
)
309+
)
310+
== 5
311+
)
312+
assert (
313+
len(
314+
qdrant.client.retrieve(
315+
qdrant.collection_name,
316+
[1, 2, 3, 15, 18],
317+
shard_key_selector=STRING_SHARD_KEY,
318+
)
319+
)
320+
== 5
321+
)

0 commit comments

Comments
 (0)