diff --git a/lance-spark-3.4_2.12/pom.xml b/lance-spark-3.4_2.12/pom.xml index 7eb567a7..79291b07 100644 --- a/lance-spark-3.4_2.12/pom.xml +++ b/lance-spark-3.4_2.12/pom.xml @@ -25,6 +25,12 @@ ${spark34.version} provided + + org.apache.spark + spark-mllib_${scala.compat.version} + ${spark34.version} + test + org.lance lance-spark-base_${scala.compat.version} diff --git a/lance-spark-3.4_2.13/pom.xml b/lance-spark-3.4_2.13/pom.xml index 2a8195ae..316ee14f 100644 --- a/lance-spark-3.4_2.13/pom.xml +++ b/lance-spark-3.4_2.13/pom.xml @@ -27,6 +27,12 @@ ${spark34.version} provided + + org.apache.spark + spark-mllib_${scala.compat.version} + ${spark34.version} + test + org.lance lance-spark-base_${scala.compat.version} diff --git a/lance-spark-3.5_2.12/pom.xml b/lance-spark-3.5_2.12/pom.xml index 17ed5096..7c295a4d 100644 --- a/lance-spark-3.5_2.12/pom.xml +++ b/lance-spark-3.5_2.12/pom.xml @@ -22,6 +22,12 @@ ${spark35.version} provided + + org.apache.spark + spark-mllib_${scala.compat.version} + ${spark35.version} + test + org.lance lance-spark-base_${scala.compat.version} diff --git a/lance-spark-3.5_2.13/pom.xml b/lance-spark-3.5_2.13/pom.xml index d77e9ed6..348bcf70 100644 --- a/lance-spark-3.5_2.13/pom.xml +++ b/lance-spark-3.5_2.13/pom.xml @@ -26,6 +26,12 @@ ${spark35.version} provided + + org.apache.spark + spark-mllib_${scala.compat.version} + ${spark35.version} + test + org.lance lance-spark-base_${scala.compat.version} diff --git a/lance-spark-4.0_2.13/pom.xml b/lance-spark-4.0_2.13/pom.xml index 76faafe8..7d764659 100644 --- a/lance-spark-4.0_2.13/pom.xml +++ b/lance-spark-4.0_2.13/pom.xml @@ -29,6 +29,12 @@ ${spark40.version} provided + + org.apache.spark + spark-mllib_${scala.compat.version} + ${spark40.version} + test + org.lance lance-spark-base_${scala.compat.version} diff --git a/lance-spark-4.1_2.13/pom.xml b/lance-spark-4.1_2.13/pom.xml index 2a272b66..9f11944f 100644 --- a/lance-spark-4.1_2.13/pom.xml +++ b/lance-spark-4.1_2.13/pom.xml @@ -29,6 +29,12 @@ ${spark41.version} provided + + org.apache.spark + spark-mllib_${scala.compat.version} + ${spark41.version} + test + org.lance lance-spark-base_${scala.compat.version} diff --git a/lance-spark-base_2.12/pom.xml b/lance-spark-base_2.12/pom.xml index e8509456..9d4ce3a9 100644 --- a/lance-spark-base_2.12/pom.xml +++ b/lance-spark-base_2.12/pom.xml @@ -20,6 +20,12 @@ spark-sql_${scala.compat.version} provided + + org.apache.spark + spark-mllib_${scala.compat.version} + ${spark.version} + test + diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala index dfe82981..d0d22174 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala @@ -114,6 +114,8 @@ object LanceArrowWriter { case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => new IntervalMonthDayNanoWriter(vector) + case (udt: UserDefinedType[_], _) => + createFieldWriter(vector, udt.sqlType, metadata) case (dt, _) => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseSparkDataTypeRoundtripTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseSparkDataTypeRoundtripTest.java index 400551a2..e2088c65 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseSparkDataTypeRoundtripTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseSparkDataTypeRoundtripTest.java @@ -13,6 +13,10 @@ */ package org.lance.spark; +import org.apache.spark.ml.linalg.DenseVector; +import org.apache.spark.ml.linalg.SparseVector; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.VectorUDT; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -41,6 +45,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertTrue; /** @@ -402,4 +407,94 @@ public void testBinaryRoundtrip() { assertArrayEquals(b1, (byte[]) out.get(1).get(1)); assertTrue(out.get(2).isNullAt(1)); } + + // ---------------- UDT (UserDefinedType) ---------------- + + /** + * VectorUDT is MLlib's {@link org.apache.spark.ml.linalg.VectorUDT} — the most common UDT in the + * Spark ecosystem (used for ML feature vectors). Its {@code sqlType} is a 4-field StructType: + * {@code (type: ByteType, size: IntegerType, indices: ArrayType(IntegerType), values: + * ArrayType(DoubleType))}. + * + *

Arrow has no UDT concept, so the read-back schema loses the UDT wrapper and returns a plain + * {@code StructType}. Data values are intact; {@link VectorUDT#deserialize(Object)} expects an + * {@code InternalRow}, so in a test context (where Spark hands back public {@link Row}s) the + * helper reconstructs vectors manually. + * + *

Note: a null VectorUDT row is intentionally omitted. VectorUDT's sqlType marks the inner + * {@code type} field as non-nullable; when the parent struct is null, the StructWriter writes + * null placeholders to all children — including non-nullable ones — and Lance's native writer + * rejects the batch. This is a known limitation of structs with non-nullable children. + */ + @Test + public void testVectorUDTRoundtrip() { + VectorUDT vectorUDT = new VectorUDT(); + StructType schema = + new StructType().add("id", DataTypes.IntegerType, false).add("vec", vectorUDT, true); + + Vector dense = new DenseVector(new double[] {1.0, 2.0, 3.0}); + Vector sparse = new SparseVector(3, new int[] {0, 2}, new double[] {1.0, 3.0}); + + List data = Arrays.asList(RowFactory.create(0, dense), RowFactory.create(1, sparse)); + + Dataset result = writeAndRead(schema, data, "vector_udt"); + + // Read-back schema must lose the UDT wrapper — Arrow has no UDT concept, so the column comes + // back as VectorUDT's sqlType (a plain struct). Lock that contract here so a future change + // that accidentally preserves UDT in metadata is caught loudly. + assertFalse( + result.schema().apply("vec").dataType() instanceof VectorUDT, + "read-back column must not carry VectorUDT; expected the underlying struct sqlType"); + assertInstanceOf( + StructType.class, + result.schema().apply("vec").dataType(), + "read-back column must be VectorUDT.sqlType (StructType)"); + + List out = result.orderBy("id").collectAsList(); + assertEquals(2, out.size()); + + // Reconstruct vectors manually since VectorUDT.deserialize() expects InternalRow. + assertEquals(dense, reconstructVector(out.get(0).getStruct(1))); + assertEquals(sparse, reconstructVector(out.get(1).getStruct(1))); + } + + /** + * Reconstruct an MLlib {@link Vector} from the struct row returned by Lance read-back. The struct + * follows VectorUDT's sqlType: (type: byte, size: int, indices: array<int>, values: + * array<double>). Type 0 = sparse, type 1 = dense. + */ + private static Vector reconstructVector(Row struct) { + byte type = struct.getByte(0); + switch (type) { + case 1: + { + // Dense vector — values at index 3 + List vals = struct.getList(3); + double[] arr = new double[vals.size()]; + for (int i = 0; i < arr.length; i++) { + arr[i] = vals.get(i); + } + return new DenseVector(arr); + } + case 0: + { + // Sparse vector — size at 1, indices at 2, values at 3 + int size = struct.getInt(1); + List idxList = struct.getList(2); + List valList = struct.getList(3); + int[] indices = new int[idxList.size()]; + double[] values = new double[valList.size()]; + for (int i = 0; i < indices.length; i++) { + indices[i] = idxList.get(i); + } + for (int i = 0; i < values.length; i++) { + values[i] = valList.get(i); + } + return new SparseVector(size, indices, values); + } + default: + throw new IllegalArgumentException( + "unknown VectorUDT type byte: " + type + " (expected 0=sparse or 1=dense)"); + } + } } diff --git a/lance-spark-base_2.13/pom.xml b/lance-spark-base_2.13/pom.xml index d99189c3..de0f38d0 100644 --- a/lance-spark-base_2.13/pom.xml +++ b/lance-spark-base_2.13/pom.xml @@ -25,6 +25,12 @@ spark-sql_${scala.compat.version} provided + + org.apache.spark + spark-mllib_${scala.compat.version} + ${spark.version} + test +