1313import java .util .Map ;
1414import org .apache .spark .sql .catalyst .InternalRow ;
1515import org .apache .spark .sql .catalyst .util .ArrayData ;
16+ import org .apache .spark .sql .types .DataType ;
17+ import org .apache .spark .sql .types .StructField ;
1618import org .apache .spark .sql .types .StructType ;
1719
1820public class QdrantVectorHandler {
1921
2022 public static Vectors prepareVectors (
2123 InternalRow record , StructType schema , QdrantOptions options ) {
2224 Vectors .Builder vectorsBuilder = Vectors .newBuilder ();
23-
2425 // Combine sparse, dense and multi vectors
2526 vectorsBuilder .mergeFrom (prepareSparseVectors (record , schema , options ));
2627 vectorsBuilder .mergeFrom (prepareDenseVectors (record , schema , options ));
2728 vectorsBuilder .mergeFrom (prepareMultiVectors (record , schema , options ));
2829
2930 // Maintaining support for the "embedding_field" and "vector_name" options
3031 if (!options .embeddingField .isEmpty ()) {
31- float [] embeddings = extractFloatArray (record , schema , options .embeddingField );
32+ int fieldIndex = schema .fieldIndex (options .embeddingField );
33+ StructField field = schema .fields ()[fieldIndex ];
34+ float [] embeddings = extractFloatArray (record , fieldIndex , field .dataType ());
3235 // 'options.vectorName' defaults to ""
3336 vectorsBuilder .mergeFrom (
3437 namedVectors (Collections .singletonMap (options .vectorName , vector (embeddings ))));
@@ -42,10 +45,15 @@ private static Vectors prepareSparseVectors(
4245 Map <String , Vector > sparseVectors = new HashMap <>();
4346
4447 for (int i = 0 ; i < options .sparseVectorNames .length ; i ++) {
45- String name = options .sparseVectorNames [i ];
46- float [] values = extractFloatArray ( record , schema , options . sparseVectorValueFields [ i ]) ;
47- int [] indices = extractIntArray (record , schema , options . sparseVectorIndexFields [ i ] );
48+ int fieldIndex = schema . fieldIndex ( options .sparseVectorValueFields [i ]) ;
49+ StructField field = schema . fields ()[ fieldIndex ] ;
50+ float [] values = extractFloatArray (record , fieldIndex , field . dataType () );
4851
52+ fieldIndex = schema .fieldIndex (options .sparseVectorIndexFields [i ]);
53+ field = schema .fields ()[fieldIndex ];
54+ int [] indices = extractIntArray (record , fieldIndex , field .dataType ());
55+
56+ String name = options .sparseVectorNames [i ];
4957 sparseVectors .put (name , vector (Floats .asList (values ), Ints .asList (indices )));
5058 }
5159
@@ -57,8 +65,11 @@ private static Vectors prepareDenseVectors(
5765 Map <String , Vector > denseVectors = new HashMap <>();
5866
5967 for (int i = 0 ; i < options .vectorNames .length ; i ++) {
68+ int fieldIndex = schema .fieldIndex (options .vectorFields [i ]);
69+ StructField field = schema .fields ()[fieldIndex ];
70+ float [] values = extractFloatArray (record , fieldIndex , field .dataType ());
71+
6072 String name = options .vectorNames [i ];
61- float [] values = extractFloatArray (record , schema , options .vectorFields [i ]);
6273 denseVectors .put (name , vector (values ));
6374 }
6475
@@ -70,29 +81,42 @@ private static Vectors prepareMultiVectors(
7081 Map <String , Vector > multiVectors = new HashMap <>();
7182
7283 for (int i = 0 ; i < options .multiVectorNames .length ; i ++) {
73- String name = options .multiVectorNames [i ];
74- float [][] vectors = extractMultiVecArray (record , schema , options .multiVectorFields [i ]);
84+ int fieldIndex = schema .fieldIndex (options .multiVectorFields [i ]);
85+ StructField field = schema .fields ()[fieldIndex ];
86+ float [][] vectors = extractMultiVecArray (record , fieldIndex , field .dataType ());
7587
88+ String name = options .multiVectorNames [i ];
7689 multiVectors .put (name , multiVector (vectors ));
7790 }
7891
7992 return namedVectors (multiVectors );
8093 }
8194
82- private static float [] extractFloatArray (
83- InternalRow record , StructType schema , String fieldName ) {
84- int fieldIndex = schema .fieldIndex (fieldName );
95+ private static float [] extractFloatArray (InternalRow record , int fieldIndex , DataType dataType ) {
96+
97+ if (!dataType .typeName ().equalsIgnoreCase ("array" )) {
98+ throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
99+ }
100+
85101 return record .getArray (fieldIndex ).toFloatArray ();
86102 }
87103
88- private static int [] extractIntArray (InternalRow record , StructType schema , String fieldName ) {
89- int fieldIndex = schema .fieldIndex (fieldName );
104+ private static int [] extractIntArray (InternalRow record , int fieldIndex , DataType dataType ) {
105+
106+ if (!dataType .typeName ().equalsIgnoreCase ("array" )) {
107+ throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
108+ }
109+
90110 return record .getArray (fieldIndex ).toIntArray ();
91111 }
92112
93113 private static float [][] extractMultiVecArray (
94- InternalRow record , StructType schema , String fieldName ) {
95- int fieldIndex = schema .fieldIndex (fieldName );
114+ InternalRow record , int fieldIndex , DataType dataType ) {
115+
116+ if (!dataType .typeName ().equalsIgnoreCase ("array" )) {
117+ throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
118+ }
119+
96120 ArrayData arrayData = record .getArray (fieldIndex );
97121 int numRows = arrayData .numElements ();
98122 ArrayData firstRow = arrayData .getArray (0 );
0 commit comments