Skip to content

Commit 53d1d2e

Browse files
committed
refactor: Declarative datawriter, options implementation (#23)
1 parent a7ebd41 commit 53d1d2e

13 files changed

Lines changed: 160 additions & 250 deletions

README.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,31 @@
22

33
[Apache Spark](https://spark.apache.org/) is a distributed computing framework designed for big data processing and analytics. This connector enables [Qdrant](https://qdrant.tech/) to be a storage destination in Spark.
44

5-
## Installation 🚀
5+
## Installation
66

77
> [!IMPORTANT]
88
> Requires Java 8 or above.
99
10-
### GitHub Releases 📦
10+
### GitHub Releases
1111

1212
The packaged `jar` file can be found [here](https://github.com/qdrant/qdrant-spark/releases).
1313

14-
### Building from source 🛠️
14+
### Building from source
1515

1616
To build the `jar` from source, you need [JDK@8](https://www.azul.com/downloads/#zulu) and [Maven](https://maven.apache.org/) installed.
17-
Once the requirements have been satisfied, run the following command in the project root. 🛠️
17+
Once the requirements have been satisfied, run the following command in the project root.
1818

1919
```bash
2020
mvn package
2121
```
2222

2323
This will build and store the fat JAR in the `target` directory by default.
2424

25-
### Maven Central 📚
25+
### Maven Central
2626

2727
For use with Java and Scala projects, the package can be found [here](https://central.sonatype.com/artifact/io.qdrant/spark).
2828

29-
## Usage 📝
29+
## Usage
3030

3131
### Creating a Spark session (Single-node) with Qdrant support
3232

@@ -42,7 +42,7 @@ spark = SparkSession.builder.config(
4242
.getOrCreate()
4343
```
4444

45-
### Loading data 📊
45+
### Loading data
4646

4747
> [!IMPORTANT]
4848
> Before loading the data using this connector, a collection has to be [created](https://qdrant.tech/documentation/concepts/collections/#create-a-collection) in advance with the appropriate vector dimensions and configurations.
@@ -191,11 +191,11 @@ You can use the connector as a library in Databricks to ingest data into Qdrant.
191191

192192
<img width="1064" alt="Screenshot 2024-01-05 at 17 20 01 (1)" src="https://github.com/qdrant/qdrant-spark/assets/46051506/d95773e0-c5c6-4ff2-bf50-8055bb08fd1b">
193193

194-
## Datatype support 📋
194+
## Datatype support
195195

196-
Qdrant supports all the Spark data types. The appropriate types are mapped based on the provided `schema`.
196+
The appropriate Spark data types are mapped to the Qdrant payload based on the provided `schema`.
197197

198-
## Options and Spark types 🛠️
198+
## Options and Spark types
199199

200200
| Option | Description | Column DataType | Required |
201201
| :--------------------------- | :------------------------------------------------------------------ | :---------------------------- | :------- |
@@ -215,6 +215,6 @@ Qdrant supports all the Spark data types. The appropriate types are mapped based
215215
| `sparse_vector_names` | Comma-separated names of the sparse vectors in the collection. | - ||
216216
| `shard_key_selector` | Comma-separated names of custom shard keys to use during upsert. | - ||
217217

218-
## LICENSE 📜
218+
## LICENSE
219219

220220
Apache 2.0 © [2024](https://github.com/qdrant/qdrant-spark/blob/master/LICENSE)

pom.xml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,6 @@
190190
<configuration>
191191
<includeStale>false</includeStale>
192192
<style>GOOGLE</style>
193-
<formatMain>true</formatMain>
194-
<formatTest>true</formatTest>
195193
<filterModified>false</filterModified>
196194
<skip>false</skip>
197195
<fixImports>true</fixImports>

src/main/java/io/qdrant/spark/Qdrant.java

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,40 @@
88
import org.apache.spark.sql.types.StructType;
99
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
1010

11-
/** A class that implements the TableProvider and DataSourceRegister interfaces. */
11+
/** Qdrant datasource for Apache Spark. */
1212
public class Qdrant implements TableProvider, DataSourceRegister {
1313

14-
private final String[] requiredFields = new String[] {"schema", "collection_name", "qdrant_url"};
14+
private static final String[] REQUIRED_FIELDS = {"schema", "collection_name", "qdrant_url"};
1515

16-
/**
17-
* Returns the short name of the data source.
18-
*
19-
* @return The short name of the data source.
20-
*/
16+
/** Returns the short name of the data source. */
2117
@Override
2218
public String shortName() {
2319
return "qdrant";
2420
}
2521

2622
/**
27-
* Infers the schema of the data source based on the provided options.
23+
* Validates and infers the schema from the provided options.
2824
*
29-
* @param options The options used to infer the schema.
30-
* @return The inferred schema.
25+
* @throws IllegalArgumentException if required options are missing.
3126
*/
3227
@Override
3328
public StructType inferSchema(CaseInsensitiveStringMap options) {
34-
for (String fieldName : requiredFields) {
35-
if (!options.containsKey(fieldName)) {
36-
throw new IllegalArgumentException(fieldName.concat(" option is required"));
29+
validateOptions(options);
30+
return (StructType) StructType.fromJson(options.get("schema"));
31+
}
32+
33+
private void validateOptions(CaseInsensitiveStringMap options) {
34+
for (String field : REQUIRED_FIELDS) {
35+
if (!options.containsKey(field)) {
36+
throw new IllegalArgumentException(String.format("%s option is required", field));
3737
}
3838
}
39-
StructType schema = (StructType) StructType.fromJson(options.get("schema"));
40-
41-
return schema;
4239
}
4340

4441
/**
45-
* Returns a table for the data source based on the provided schema, partitioning, and properties.
42+
* Creates a Qdrant table instance with validated options.
4643
*
47-
* @param schema The schema of the table.
48-
* @param partitioning The partitioning of the table.
49-
* @param properties The properties of the table.
50-
* @return The table for the data source.
44+
* @throws IllegalArgumentException if options are invalid.
5145
*/
5246
@Override
5347
public Table getTable(

src/main/java/io/qdrant/spark/QdrantBatchWriter.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import org.apache.spark.sql.connector.write.WriterCommitMessage;
77
import org.apache.spark.sql.types.StructType;
88

9-
/** QdrantBatchWriter class implements the BatchWrite interface. */
9+
/** Qdrant batch writer for Apache Spark. */
1010
public class QdrantBatchWriter implements BatchWrite {
1111

1212
private final QdrantOptions options;
@@ -23,13 +23,8 @@ public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
2323
}
2424

2525
@Override
26-
public void commit(WriterCommitMessage[] messages) {
27-
// TODO Auto-generated method stub
28-
29-
}
26+
public void commit(WriterCommitMessage[] messages) {}
3027

3128
@Override
32-
public void abort(WriterCommitMessage[] messages) {
33-
// TODO Auto-generated method stub
34-
}
29+
public void abort(WriterCommitMessage[] messages) {}
3530
}
Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
package io.qdrant.spark;
22

3-
import java.util.Arrays;
43
import java.util.Collections;
5-
import java.util.HashSet;
4+
import java.util.EnumSet;
65
import java.util.Set;
76
import org.apache.spark.sql.connector.catalog.SupportsWrite;
87
import org.apache.spark.sql.connector.catalog.TableCapability;
98
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
109
import org.apache.spark.sql.connector.write.WriteBuilder;
1110
import org.apache.spark.sql.types.StructType;
1211

13-
/** QdrantCluster class implements the SupportsWrite interface. */
12+
/** Qdrant cluster implementation supporting batch writes. */
1413
public class QdrantCluster implements SupportsWrite {
1514

1615
private final StructType schema;
1716
private final QdrantOptions options;
1817

19-
private static final Set<TableCapability> TABLE_CAPABILITY_SET =
20-
Collections.unmodifiableSet(
21-
new HashSet<>(
22-
Arrays.asList(TableCapability.BATCH_WRITE, TableCapability.STREAMING_WRITE)));
18+
private static final Set<TableCapability> CAPABILITIES = EnumSet.of(TableCapability.BATCH_WRITE);
2319

2420
public QdrantCluster(QdrantOptions options, StructType schema) {
2521
this.options = options;
@@ -28,7 +24,7 @@ public QdrantCluster(QdrantOptions options, StructType schema) {
2824

2925
@Override
3026
public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
31-
return new QdrantWriteBuilder(this.options, this.schema);
27+
return new QdrantWriteBuilder(options, schema);
3228
}
3329

3430
@Override
@@ -38,11 +34,11 @@ public String name() {
3834

3935
@Override
4036
public StructType schema() {
41-
return this.schema;
37+
return schema;
4238
}
4339

4440
@Override
4541
public Set<TableCapability> capabilities() {
46-
return TABLE_CAPABILITY_SET;
42+
return Collections.unmodifiableSet(CAPABILITIES);
4743
}
4844
}

src/main/java/io/qdrant/spark/QdrantDataWriter.java

Lines changed: 43 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,88 @@
11
package io.qdrant.spark;
22

3-
import io.qdrant.client.grpc.JsonWithInt.Value;
4-
import io.qdrant.client.grpc.Points.PointId;
53
import io.qdrant.client.grpc.Points.PointStruct;
6-
import io.qdrant.client.grpc.Points.Vectors;
74
import java.io.Serializable;
85
import java.net.URL;
96
import java.util.ArrayList;
10-
import java.util.Map;
7+
import java.util.List;
118
import org.apache.spark.sql.catalyst.InternalRow;
129
import org.apache.spark.sql.connector.write.DataWriter;
1310
import org.apache.spark.sql.connector.write.WriterCommitMessage;
1411
import org.apache.spark.sql.types.StructType;
1512
import org.slf4j.Logger;
1613
import org.slf4j.LoggerFactory;
1714

18-
/** A DataWriter implementation that writes data to Qdrant. */
15+
/** DataWriter implementation for writing data to Qdrant. */
1916
public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {
17+
18+
private static final Logger LOG = LoggerFactory.getLogger(QdrantDataWriter.class);
19+
2020
private final QdrantOptions options;
2121
private final StructType schema;
22-
private final String qdrantUrl;
23-
private final String apiKey;
24-
private final Logger LOG = LoggerFactory.getLogger(QdrantDataWriter.class);
25-
26-
private final ArrayList<PointStruct> points = new ArrayList<>();
22+
private final List<PointStruct> pointsBuffer = new ArrayList<>();
2723

2824
public QdrantDataWriter(QdrantOptions options, StructType schema) {
2925
this.options = options;
3026
this.schema = schema;
31-
this.qdrantUrl = options.qdrantUrl;
32-
this.apiKey = options.apiKey;
3327
}
3428

3529
@Override
3630
public void write(InternalRow record) {
37-
PointStruct.Builder pointBuilder = PointStruct.newBuilder();
38-
39-
PointId pointId = QdrantPointIdHandler.preparePointId(record, this.schema, this.options);
40-
pointBuilder.setId(pointId);
41-
42-
Vectors vectors = QdrantVectorHandler.prepareVectors(record, this.schema, this.options);
43-
pointBuilder.setVectors(vectors);
44-
45-
Map<String, Value> payload =
46-
QdrantPayloadHandler.preparePayload(record, this.schema, this.options);
47-
pointBuilder.putAllPayload(payload);
48-
49-
this.points.add(pointBuilder.build());
31+
PointStruct point = createPointStruct(record);
32+
pointsBuffer.add(point);
5033

51-
if (this.points.size() >= this.options.batchSize) {
52-
this.write(this.options.retries);
34+
if (pointsBuffer.size() >= options.batchSize) {
35+
writeBatch(options.retries);
5336
}
5437
}
5538

56-
@Override
57-
public WriterCommitMessage commit() {
58-
this.write(this.options.retries);
59-
return new WriterCommitMessage() {
60-
@Override
61-
public String toString() {
62-
return "point committed to Qdrant";
63-
}
64-
};
39+
private PointStruct createPointStruct(InternalRow record) {
40+
PointStruct.Builder pointBuilder = PointStruct.newBuilder();
41+
pointBuilder.setId(QdrantPointIdHandler.preparePointId(record, schema, options));
42+
pointBuilder.setVectors(QdrantVectorHandler.prepareVectors(record, schema, options));
43+
pointBuilder.putAllPayload(QdrantPayloadHandler.preparePayload(record, schema, options));
44+
return pointBuilder.build();
6545
}
6646

67-
public void write(int retries) {
68-
LOG.info(
69-
String.join(
70-
"", "Uploading batch of ", Integer.toString(this.points.size()), " points to Qdrant"));
71-
72-
if (this.points.isEmpty()) {
47+
private void writeBatch(int retries) {
48+
if (pointsBuffer.isEmpty()) {
7349
return;
7450
}
51+
7552
try {
76-
// Instantiate a new QdrantGrpc object to maintain serializability
77-
QdrantGrpc qdrant = new QdrantGrpc(new URL(this.qdrantUrl), this.apiKey);
78-
qdrant.upsert(this.options.collectionName, this.points, this.options.shardKeySelector);
79-
qdrant.close();
80-
this.points.clear();
53+
doWriteBatch();
54+
pointsBuffer.clear();
8155
} catch (Exception e) {
82-
LOG.error(String.join("", "Exception while uploading batch to Qdrant: ", e.getMessage()));
56+
LOG.error("Exception while uploading batch to Qdrant: {}", e.getMessage());
8357
if (retries > 0) {
8458
LOG.info("Retrying upload batch to Qdrant");
85-
write(retries - 1);
59+
writeBatch(retries - 1);
8660
} else {
8761
throw new RuntimeException(e);
8862
}
8963
}
9064
}
9165

66+
private void doWriteBatch() throws Exception {
67+
LOG.info("Uploading batch of {} points to Qdrant", pointsBuffer.size());
68+
69+
// Instantiate QdrantGrpc client for each batch to maintain serializability
70+
QdrantGrpc qdrant = new QdrantGrpc(new URL(options.qdrantUrl), options.apiKey);
71+
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector);
72+
qdrant.close();
73+
}
74+
75+
@Override
76+
public WriterCommitMessage commit() {
77+
writeBatch(options.retries);
78+
return new WriterCommitMessage() {
79+
@Override
80+
public String toString() {
81+
return "point committed to Qdrant";
82+
}
83+
};
84+
}
85+
9286
@Override
9387
public void abort() {}
9488

src/main/java/io/qdrant/spark/QdrantDataWriterFactory.java

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,26 @@
44
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory;
55
import org.apache.spark.sql.types.StructType;
66

7-
/** Factory class for creating QdrantDataWriter instances for Spark Structured Streaming. */
7+
/** Factory class for creating QdrantDataWriter instances for Spark data sources. */
88
public class QdrantDataWriterFactory implements StreamingDataWriterFactory, DataWriterFactory {
9+
910
private final QdrantOptions options;
1011
private final StructType schema;
1112

12-
/**
13-
* Constructor for QdrantDataWriterFactory.
14-
*
15-
* @param options QdrantOptions instance containing configuration options for Qdrant.
16-
* @param schema StructType instance containing schema information for the data being written.
17-
*/
1813
public QdrantDataWriterFactory(QdrantOptions options, StructType schema) {
1914
this.options = options;
2015
this.schema = schema;
2116
}
2217

2318
@Override
2419
public QdrantDataWriter createWriter(int partitionId, long taskId, long epochId) {
25-
try {
26-
return new QdrantDataWriter(this.options, this.schema);
27-
} catch (Exception e) {
28-
throw new RuntimeException(e);
29-
}
20+
return createWriter(partitionId, taskId);
3021
}
3122

3223
@Override
3324
public QdrantDataWriter createWriter(int partitionId, long taskId) {
3425
try {
35-
return new QdrantDataWriter(this.options, this.schema);
26+
return new QdrantDataWriter(options, schema);
3627
} catch (Exception e) {
3728
throw new RuntimeException(e);
3829
}

0 commit comments

Comments
 (0)