1313import java .util .Map ;
1414import org .apache .spark .sql .catalyst .InternalRow ;
1515import org .apache .spark .sql .catalyst .util .ArrayData ;
16+ import org .apache .spark .sql .types .ArrayType ;
1617import org .apache .spark .sql .types .DataType ;
1718import org .apache .spark .sql .types .StructField ;
1819import org .apache .spark .sql .types .StructType ;
@@ -98,6 +99,12 @@ private static float[] extractFloatArray(InternalRow record, int fieldIndex, Dat
9899 throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
99100 }
100101
102+ ArrayType arrayType = (ArrayType ) dataType ;
103+
104+ if (!arrayType .elementType ().typeName ().equalsIgnoreCase ("float" )) {
105+ throw new IllegalArgumentException ("Expected array elements to be of FloatType" );
106+ }
107+
101108 return record .getArray (fieldIndex ).toFloatArray ();
102109 }
103110
@@ -107,14 +114,26 @@ private static int[] extractIntArray(InternalRow record, int fieldIndex, DataTyp
107114 throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
108115 }
109116
117+ ArrayType arrayType = (ArrayType ) dataType ;
118+
119+ if (!arrayType .elementType ().typeName ().equalsIgnoreCase ("integer" )) {
120+ throw new IllegalArgumentException ("Expected array elements to be of IntegerType" );
121+ }
122+
110123 return record .getArray (fieldIndex ).toIntArray ();
111124 }
112125
113126 private static float [][] extractMultiVecArray (
114127 InternalRow record , int fieldIndex , DataType dataType ) {
115128
116129 if (!dataType .typeName ().equalsIgnoreCase ("array" )) {
117- throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
130+ throw new IllegalArgumentException ("Multi Vector field must be of type ArrayType" );
131+ }
132+
133+ ArrayType arrayType = (ArrayType ) dataType ;
134+
135+ if (!arrayType .elementType ().typeName ().equalsIgnoreCase ("array" )) {
136+ throw new IllegalArgumentException ("Multi Vector elements must be of type ArrayType" );
118137 }
119138
120139 ArrayData arrayData = record .getArray (fieldIndex );
0 commit comments