22from pyspark .sql import SparkSession
33
44from .schema import schema
5- from .conftest import Qdrant
5+ from .conftest import Qdrant , STRING_SHARD_KEY , INTEGER_SHARD_KEY
66
77current_directory = os .path .dirname (__file__ )
88input_file_path = os .path .join (current_directory , ".." , "resources" , "users.json" )
@@ -20,12 +20,16 @@ def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession):
2020 "embedding_field" : "dense_vector" ,
2121 "api_key" : qdrant .api_key ,
2222 "schema" : df .schema .json (),
23+ "shard_key_selector" : STRING_SHARD_KEY ,
2324 }
2425
2526 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
2627
2728 assert (
28- qdrant .client .count (qdrant .collection_name ).count == df .count ()
29+ qdrant .client .count (
30+ qdrant .collection_name , shard_key_selector = STRING_SHARD_KEY
31+ ).count
32+ == df .count ()
2933 ), "Uploaded points count is not equal to the dataframe count"
3034
3135
@@ -41,12 +45,16 @@ def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession):
4145 "vector_name" : "dense" ,
4246 "schema" : df .schema .json (),
4347 "api_key" : qdrant .api_key ,
48+ "shard_key_selector" : STRING_SHARD_KEY ,
4449 }
4550
4651 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
4752
4853 assert (
49- qdrant .client .count (qdrant .collection_name ).count == df .count ()
54+ qdrant .client .count (
55+ qdrant .collection_name , shard_key_selector = STRING_SHARD_KEY
56+ ).count
57+ == df .count ()
5058 ), "Uploaded points count is not equal to the dataframe count"
5159
5260
@@ -65,12 +73,16 @@ def test_upsert_multiple_named_dense_vectors(
6573 "vector_names" : "dense,another_dense" ,
6674 "schema" : df .schema .json (),
6775 "api_key" : qdrant .api_key ,
76+ "shard_key_selector" : STRING_SHARD_KEY ,
6877 }
6978
7079 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
7180
7281 assert (
73- qdrant .client .count (qdrant .collection_name ).count == df .count ()
82+ qdrant .client .count (
83+ qdrant .collection_name , shard_key_selector = STRING_SHARD_KEY
84+ ).count
85+ == df .count ()
7486 ), "Uploaded points count is not equal to the dataframe count"
7587
7688
@@ -88,12 +100,16 @@ def test_upsert_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession):
88100 "sparse_vector_names" : "sparse" ,
89101 "schema" : df .schema .json (),
90102 "api_key" : qdrant .api_key ,
103+ "shard_key_selector" : STRING_SHARD_KEY ,
91104 }
92105
93106 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
94107
95108 assert (
96- qdrant .client .count (qdrant .collection_name ).count == df .count ()
109+ qdrant .client .count (
110+ qdrant .collection_name , shard_key_selector = STRING_SHARD_KEY
111+ ).count
112+ == df .count ()
97113 ), "Uploaded points count is not equal to the dataframe count"
98114
99115
@@ -111,12 +127,16 @@ def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSess
111127 "sparse_vector_names" : "sparse,another_sparse" ,
112128 "schema" : df .schema .json (),
113129 "api_key" : qdrant .api_key ,
130+ "shard_key_selector" : STRING_SHARD_KEY ,
114131 }
115132
116133 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
117134
118135 assert (
119- qdrant .client .count (qdrant .collection_name ).count == df .count ()
136+ qdrant .client .count (
137+ qdrant .collection_name , shard_key_selector = STRING_SHARD_KEY
138+ ).count
139+ == df .count ()
120140 ), "Uploaded points count is not equal to the dataframe count"
121141
122142
@@ -136,12 +156,16 @@ def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkS
136156 "sparse_vector_names" : "sparse" ,
137157 "schema" : df .schema .json (),
138158 "api_key" : qdrant .api_key ,
159+ "shard_key_selector" : STRING_SHARD_KEY ,
139160 }
140161
141162 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
142163
143164 assert (
144- qdrant .client .count (qdrant .collection_name ).count == df .count ()
165+ qdrant .client .count (
166+ qdrant .collection_name , shard_key_selector = STRING_SHARD_KEY
167+ ).count
168+ == df .count ()
145169 ), "Uploaded points count is not equal to the dataframe count"
146170
147171
@@ -162,12 +186,16 @@ def test_upsert_sparse_unnamed_dense_vectors(
162186 "sparse_vector_names" : "sparse" ,
163187 "schema" : df .schema .json (),
164188 "api_key" : qdrant .api_key ,
189+ "shard_key_selector" : INTEGER_SHARD_KEY ,
165190 }
166191
167192 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
168193
169194 assert (
170- qdrant .client .count (qdrant .collection_name ).count == df .count ()
195+ qdrant .client .count (
196+ qdrant .collection_name , shard_key_selector = INTEGER_SHARD_KEY
197+ ).count
198+ == df .count ()
171199 ), "Uploaded points count is not equal to the dataframe count"
172200
173201
@@ -189,17 +217,20 @@ def test_upsert_multiple_sparse_dense_vectors(
189217 "sparse_vector_names" : "sparse,another_sparse" ,
190218 "schema" : df .schema .json (),
191219 "api_key" : qdrant .api_key ,
220+ "shard_key_selector" : INTEGER_SHARD_KEY ,
192221 }
193222
194223 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
195224
196225 assert (
197- qdrant .client .count (qdrant .collection_name ).count == df .count ()
226+ qdrant .client .count (
227+ qdrant .collection_name , shard_key_selector = INTEGER_SHARD_KEY
228+ ).count
229+ == df .count ()
198230 ), "Uploaded points count is not equal to the dataframe count"
199231
200- def test_upsert_multi_vector (
201- qdrant : Qdrant , spark_session : SparkSession
202- ):
232+
233+ def test_upsert_multi_vector (qdrant : Qdrant , spark_session : SparkSession ):
203234 df = (
204235 spark_session .read .schema (schema )
205236 .option ("multiline" , "true" )
@@ -212,12 +243,16 @@ def test_upsert_multi_vector(
212243 "multi_vector_names" : "multi" ,
213244 "schema" : df .schema .json (),
214245 "api_key" : qdrant .api_key ,
246+ "shard_key_selector" : INTEGER_SHARD_KEY ,
215247 }
216248
217249 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
218250
219251 assert (
220- qdrant .client .count (qdrant .collection_name ).count == df .count ()
252+ qdrant .client .count (
253+ qdrant .collection_name , shard_key_selector = INTEGER_SHARD_KEY
254+ ).count
255+ == df .count ()
221256 ), "Uploaded points count is not equal to the dataframe count"
222257
223258
@@ -233,11 +268,15 @@ def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession):
233268 "collection_name" : qdrant .collection_name ,
234269 "schema" : df .schema .json (),
235270 "api_key" : qdrant .api_key ,
271+ "shard_key_selector" : INTEGER_SHARD_KEY ,
236272 }
237273 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
238274
239275 assert (
240- qdrant .client .count (qdrant .collection_name ).count == df .count ()
276+ qdrant .client .count (
277+ qdrant .collection_name , shard_key_selector = INTEGER_SHARD_KEY
278+ ).count
279+ == df .count ()
241280 ), "Uploaded points count is not equal to the dataframe count"
242281
243282
@@ -256,7 +295,27 @@ def test_custom_id_field(qdrant: Qdrant, spark_session: SparkSession):
256295 "id_field" : "id" ,
257296 "schema" : df .schema .json (),
258297 "api_key" : qdrant .api_key ,
298+ "shard_key_selector" : f"{ STRING_SHARD_KEY } ,{ INTEGER_SHARD_KEY } " ,
259299 }
260300 df .write .format ("io.qdrant.spark.Qdrant" ).options (** opts ).mode ("append" ).save ()
261301
262- assert len (qdrant .client .retrieve (qdrant .collection_name , [1 , 2 , 3 , 15 , 18 ])) == 5
302+ assert (
303+ len (
304+ qdrant .client .retrieve (
305+ qdrant .collection_name ,
306+ [1 , 2 , 3 , 15 , 18 ],
307+ shard_key_selector = INTEGER_SHARD_KEY ,
308+ )
309+ )
310+ == 5
311+ )
312+ assert (
313+ len (
314+ qdrant .client .retrieve (
315+ qdrant .collection_name ,
316+ [1 , 2 , 3 , 15 , 18 ],
317+ shard_key_selector = STRING_SHARD_KEY ,
318+ )
319+ )
320+ == 5
321+ )
0 commit comments