Skip to content

Commit a9b8983

Browse files
authored
chore: Improved Input Validation (#44)
* chore: Misc. code improvements
  Signed-off-by: Anush008 <anushshetty90@gmail.com>
* chore: Bump version
  Signed-off-by: Anush008 <anushshetty90@gmail.com>
---------
Signed-off-by: Anush008 <anushshetty90@gmail.com>
1 parent b8fe403 commit a9b8983

File tree

6 files changed

+239
-28
lines changed

6 files changed

+239
-28
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
<modelVersion>4.0.0</modelVersion>
77
<groupId>io.qdrant</groupId>
88
<artifactId>spark</artifactId>
9-
<version>2.3.4</version>
9+
<version>2.3.5</version>
1010
<name>qdrant-spark</name>
1111
<url>https://github.com/qdrant/qdrant-spark</url>
1212
<description>An Apache Spark connector for the Qdrant vector database</description>

src/main/java/io/qdrant/spark/QdrantDataWriter.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,34 @@ private void writeBatch(int retries) {
5151

5252
try {
5353
doWriteBatch();
54-
pointsBuffer.clear();
5554
} catch (Exception e) {
56-
LOG.error("Exception while uploading batch to Qdrant: {}", e.getMessage());
55+
LOG.error("Exception while uploading batch to Qdrant", e);
5756
if (retries > 0) {
58-
LOG.info("Retrying upload batch to Qdrant");
57+
LOG.info("Retrying upload batch to Qdrant (retries remaining: {})", retries);
5958
writeBatch(retries - 1);
59+
return;
6060
} else {
61-
throw new RuntimeException(e);
61+
pointsBuffer.clear();
62+
throw new RuntimeException(
63+
"Failed to write batch after " + (options.retries + 1) + " attempts", e);
6264
}
6365
}
66+
pointsBuffer.clear();
6467
}
6568

6669
private void doWriteBatch() throws Exception {
6770
LOG.info("Uploading batch of {} points to Qdrant", pointsBuffer.size());
6871

6972
// Instantiate QdrantGrpc client for each batch to maintain serializability
70-
QdrantGrpc qdrant = new QdrantGrpc(new URL(options.qdrantUrl), options.apiKey);
71-
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector, options.wait);
72-
qdrant.close();
73+
QdrantGrpc qdrant = null;
74+
try {
75+
qdrant = new QdrantGrpc(new URL(options.qdrantUrl), options.apiKey);
76+
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector, options.wait);
77+
} finally {
78+
if (qdrant != null) {
79+
qdrant.close();
80+
}
81+
}
7382
}
7483

7584
@Override

src/main/java/io/qdrant/spark/QdrantGrpc.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import io.qdrant.client.grpc.Points.ShardKeySelector;
77
import io.qdrant.client.grpc.Points.UpsertPoints;
88
import java.io.Serializable;
9-
import java.net.MalformedURLException;
109
import java.net.URL;
1110
import java.util.List;
1211
import java.util.concurrent.ExecutionException;
@@ -16,8 +15,14 @@ public class QdrantGrpc implements Serializable {
1615

1716
private final QdrantClient client;
1817

19-
public QdrantGrpc(URL url, String apiKey) throws MalformedURLException {
18+
public QdrantGrpc(URL url, String apiKey) {
19+
if (url == null) {
20+
throw new IllegalArgumentException("URL cannot be null");
21+
}
2022
String host = url.getHost();
23+
if (host == null || host.isEmpty()) {
24+
throw new IllegalArgumentException("Invalid URL: host is missing. Provided URL: " + url);
25+
}
2126
int port = url.getPort() == -1 ? 6334 : url.getPort();
2227
boolean useTls = url.getProtocol().equalsIgnoreCase("https");
2328
client =

src/main/java/io/qdrant/spark/QdrantOptions.java

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,25 @@ public QdrantOptions(Map<String, String> options) {
4040
Objects.requireNonNull(options);
4141

4242
qdrantUrl = options.get("qdrant_url");
43+
if (qdrantUrl == null || qdrantUrl.isEmpty()) {
44+
throw new IllegalArgumentException("qdrant_url option is required and cannot be empty");
45+
}
46+
4347
collectionName = options.get("collection_name");
48+
if (collectionName == null || collectionName.isEmpty()) {
49+
throw new IllegalArgumentException("collection_name option is required and cannot be empty");
50+
}
51+
4452
batchSize = getIntOption(options, "batch_size", DEFAULT_BATCH_SIZE);
53+
if (batchSize <= 0) {
54+
throw new IllegalArgumentException("batch_size must be positive, got: " + batchSize);
55+
}
56+
4557
retries = getIntOption(options, "retries", DEFAULT_RETRIES);
58+
if (retries < 0) {
59+
throw new IllegalArgumentException("retries cannot be negative, got: " + retries);
60+
}
61+
4662
idField = options.getOrDefault("id_field", "");
4763
apiKey = options.getOrDefault("api_key", "");
4864
embeddingField = options.getOrDefault("embedding_field", "");
@@ -61,22 +77,33 @@ public QdrantOptions(Map<String, String> options) {
6177

6278
validateSparseVectorFields();
6379
validateVectorFields();
80+
validateMultiVectorFields();
6481

6582
payloadFieldsToSkip = buildPayloadFieldsToSkip();
6683
}
6784

6885
private int getIntOption(Map<String, String> options, String key, int defaultValue) {
69-
return Integer.parseInt(options.getOrDefault(key, String.valueOf(defaultValue)));
86+
String value = options.getOrDefault(key, String.valueOf(defaultValue));
87+
try {
88+
return Integer.parseInt(value);
89+
} catch (NumberFormatException e) {
90+
throw new IllegalArgumentException(
91+
"Invalid value for option '" + key + "': '" + value + "'. Expected an integer.", e);
92+
}
7093
}
7194

7295
private boolean getBooleanOption(Map<String, String> options, String key, boolean defaultValue) {
7396
return Boolean.parseBoolean(options.getOrDefault(key, String.valueOf(defaultValue)));
7497
}
7598

7699
private String[] parseArray(String input) {
77-
return input == null
78-
? new String[0]
79-
: Arrays.stream(input.split(",")).map(String::trim).toArray(String[]::new);
100+
if (input == null || input.trim().isEmpty()) {
101+
return new String[0];
102+
}
103+
return Arrays.stream(input.split(","))
104+
.map(String::trim)
105+
.filter(s -> !s.isEmpty())
106+
.toArray(String[]::new);
80107
}
81108

82109
private void validateSparseVectorFields() {
@@ -93,6 +120,13 @@ private void validateVectorFields() {
93120
}
94121
}
95122

123+
private void validateMultiVectorFields() {
124+
if (multiVectorFields.length != multiVectorNames.length) {
125+
throw new IllegalArgumentException(
126+
"Multi vector fields and names should have the same length");
127+
}
128+
}
129+
96130
private ShardKeySelector parseShardKeys(@Nullable String shardKeys) {
97131
if (shardKeys == null) {
98132
return null;

src/main/java/io/qdrant/spark/QdrantPointIdHandler.java

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,57 @@ static PointId preparePointId(InternalRow record, StructType schema, QdrantOptio
1616
return id(UUID.randomUUID());
1717
}
1818

19-
int idFieldIndex = schema.fieldIndex(idField);
19+
int idFieldIndex;
20+
try {
21+
idFieldIndex = schema.fieldIndex(idField);
22+
} catch (IllegalArgumentException e) {
23+
throw new IllegalArgumentException(
24+
"Field '"
25+
+ idField
26+
+ "' specified in 'id_field' does not exist in the schema. "
27+
+ "Available fields: "
28+
+ String.join(", ", schema.fieldNames()),
29+
e);
30+
}
31+
2032
DataType idFieldType = schema.fields()[idFieldIndex].dataType();
2133
switch (idFieldType.typeName()) {
2234
case "string":
23-
return id(UUID.fromString(record.getString(idFieldIndex)));
35+
String idString = record.getString(idFieldIndex);
36+
if (idString == null) {
37+
throw new IllegalArgumentException(
38+
"The 'id_field' contains a null value. IDs cannot be null. "
39+
+ "Either provide valid ID values or remove 'id_field' option to use"
40+
+ " auto-generated UUIDs.");
41+
}
42+
try {
43+
return id(UUID.fromString(idString));
44+
} catch (IllegalArgumentException e) {
45+
throw new IllegalArgumentException(
46+
"The 'id_field' value '"
47+
+ idString
48+
+ "' is not a valid UUID. "
49+
+ "String IDs must be in UUID format (e.g.,"
50+
+ " '550e8400-e29b-41d4-a716-446655440000'). "
51+
+ "For non-UUID string IDs, consider using integer IDs or hashing your strings to"
52+
+ " UUIDs.",
53+
e);
54+
}
2455

2556
case "integer":
26-
case "long":
2757
return id(record.getInt(idFieldIndex));
2858

59+
case "long":
60+
return id(record.getLong(idFieldIndex));
61+
2962
default:
30-
throw new IllegalArgumentException("Point ID should be of type string or integer");
63+
throw new IllegalArgumentException(
64+
"Point ID field '"
65+
+ idField
66+
+ "' has unsupported type '"
67+
+ idFieldType.typeName()
68+
+ "'. "
69+
+ "Supported types: string (UUID format), integer, long");
3170
}
3271
}
3372
}

0 commit comments

Comments
 (0)