diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala
index dfe82981..d0d22174 100644
--- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala
+++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala
@@ -114,6 +114,8 @@ object LanceArrowWriter {
case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector)
case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) =>
new IntervalMonthDayNanoWriter(vector)
+ case (udt: UserDefinedType[_], _) =>
+ createFieldWriter(vector, udt.sqlType, metadata)
case (dt, _) =>
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}
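
Why delegating to udt.sqlType works: Spark stores UDT column values already serialized through the UDT into their sqlType representation, so the writer selected for the underlying type sees plain struct data and no extra conversion is needed. A minimal sketch for MLlib's VectorUDT (the class name VectorUdtSqlTypeDemo is invented for this sketch; the Spark ML calls are real):

    import org.apache.spark.ml.linalg.VectorUDT;
    import org.apache.spark.ml.linalg.Vectors;
    import org.apache.spark.sql.catalyst.InternalRow;

    public class VectorUdtSqlTypeDemo {
      public static void main(String[] args) {
        VectorUDT udt = new VectorUDT();

        // The struct the new writer case recurses into:
        // struct<type:tinyint,size:int,indices:array<int>,values:array<double>>
        System.out.println(udt.sqlType().catalogString());

        // The already-serialized form the Arrow writer receives: an InternalRow
        // matching sqlType. Dense vectors carry type byte 1 and null size/indices.
        InternalRow row = udt.serialize(Vectors.dense(new double[] {1.0, 2.0, 3.0}));
        System.out.println(row.getByte(0)); // 1 = dense
      }
    }
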
diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseSparkDataTypeRoundtripTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseSparkDataTypeRoundtripTest.java
index 400551a2..e2088c65 100644
--- a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseSparkDataTypeRoundtripTest.java
+++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseSparkDataTypeRoundtripTest.java
@@ -13,6 +13,10 @@
*/
package org.lance.spark;
+import org.apache.spark.ml.linalg.DenseVector;
+import org.apache.spark.ml.linalg.SparseVector;
+import org.apache.spark.ml.linalg.Vector;
+import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
@@ -41,6 +45,7 @@
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
@@ -402,4 +407,94 @@ public void testBinaryRoundtrip() {
assertArrayEquals(b1, (byte[]) out.get(1).get(1));
assertTrue(out.get(2).isNullAt(1));
}
+
+ // ---------------- UDT (UserDefinedType) ----------------
+
+ /**
+ * {@link org.apache.spark.ml.linalg.VectorUDT} is the most common UDT in the Spark ecosystem
+ * (used for ML feature vectors). Its {@code sqlType} is a four-field StructType:
+ * {@code (type: ByteType, size: IntegerType, indices: ArrayType(IntegerType), values:
+ * ArrayType(DoubleType))}.
+ *
+ * Arrow has no UDT concept, so the read-back schema loses the UDT wrapper and returns a plain
+ * {@code StructType}. Data values are intact; {@link VectorUDT#deserialize(Object)} expects an
+ * {@code InternalRow}, so in a test context (where Spark hands back public {@link Row}s) the
+ * helper reconstructs vectors manually.
+ *
+ * Note: a null VectorUDT row is intentionally omitted. VectorUDT's sqlType marks the inner
+ * {@code type} field as non-nullable; when the parent struct is null, the StructWriter writes
+ * null placeholders to all children — including non-nullable ones — and Lance's native writer
+ * rejects the batch. This is a known limitation of structs with non-nullable children.
+ */
+ @Test
+ public void testVectorUDTRoundtrip() {
+ VectorUDT vectorUDT = new VectorUDT();
+ StructType schema =
+ new StructType().add("id", DataTypes.IntegerType, false).add("vec", vectorUDT, true);
+
+ Vector dense = new DenseVector(new double[] {1.0, 2.0, 3.0});
+ Vector sparse = new SparseVector(3, new int[] {0, 2}, new double[] {1.0, 3.0});
+
+ List<Row> data = Arrays.asList(RowFactory.create(0, dense), RowFactory.create(1, sparse));
+
+ Dataset<Row> result = writeAndRead(schema, data, "vector_udt");
+
+ // Read-back schema must lose the UDT wrapper — Arrow has no UDT concept, so the column comes
+ // back as VectorUDT's sqlType (a plain struct). Lock that contract here so a future change
+ // that accidentally preserves UDT in metadata is caught loudly.
+ assertFalse(
+ result.schema().apply("vec").dataType() instanceof VectorUDT,
+ "read-back column must not carry VectorUDT; expected the underlying struct sqlType");
+ assertInstanceOf(
+ StructType.class,
+ result.schema().apply("vec").dataType(),
+ "read-back column must be VectorUDT.sqlType (StructType)");
+
+ List<Row> out = result.orderBy("id").collectAsList();
+ assertEquals(2, out.size());
+
+ // Reconstruct vectors manually since VectorUDT.deserialize() expects InternalRow.
+ assertEquals(dense, reconstructVector(out.get(0).getStruct(1)));
+ assertEquals(sparse, reconstructVector(out.get(1).getStruct(1)));
+ }
+
+ /**
+ * Reconstruct an MLlib {@link Vector} from the struct row returned by Lance read-back. The struct
+ * follows VectorUDT's sqlType: (type: byte, size: int, indices: array<int>, values:
+ * array<double>). Type 0 = sparse, type 1 = dense.
+ */
+ private static Vector reconstructVector(Row struct) {
+ byte type = struct.getByte(0);
+ switch (type) {
+ case 1:
+ {
+ // Dense vector — values at index 3
+ List<Double> vals = struct.getList(3);
+ double[] arr = new double[vals.size()];
+ for (int i = 0; i < arr.length; i++) {
+ arr[i] = vals.get(i);
+ }
+ return new DenseVector(arr);
+ }
+ case 0:
+ {
+ // Sparse vector — size at 1, indices at 2, values at 3
+ int size = struct.getInt(1);
+ List<Integer> idxList = struct.getList(2);
+ List<Double> valList = struct.getList(3);
+ int[] indices = new int[idxList.size()];
+ double[] values = new double[valList.size()];
+ for (int i = 0; i < indices.length; i++) {
+ indices[i] = idxList.get(i);
+ }
+ for (int i = 0; i < values.length; i++) {
+ values[i] = valList.get(i);
+ }
+ return new SparseVector(size, indices, values);
+ }
+ default:
+ throw new IllegalArgumentException(
+ "unknown VectorUDT type byte: " + type + " (expected 0=sparse or 1=dense)");
+ }
+ }
}
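
For downstream users, the practical consequence mirrors the assertions above: vector columns write cleanly, but a read-back DataFrame exposes the plain struct rather than a VectorUDT, so ML code must rebuild vectors before use. A hedged sketch, assuming the connector is registered under the Spark format name "lance" and using a made-up path:

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class ReadBackSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[1]").getOrCreate();

        // Format name and path are assumptions for illustration only.
        Dataset<Row> df = spark.read().format("lance").load("/tmp/features.lance");

        // "vec" prints as struct<type:tinyint,...>, not vector; rebuild it with a
        // helper like reconstructVector(...) above before handing it to MLlib.
        df.printSchema();
      }
    }
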
diff --git a/lance-spark-base_2.13/pom.xml b/lance-spark-base_2.13/pom.xml
index d99189c3..de0f38d0 100644
--- a/lance-spark-base_2.13/pom.xml
+++ b/lance-spark-base_2.13/pom.xml
@@ -25,6 +25,12 @@
      <artifactId>spark-sql_${scala.compat.version}</artifactId>
      <scope>provided</scope>
    </dependency>
+   <dependency>
+     <groupId>org.apache.spark</groupId>
+     <artifactId>spark-mllib_${scala.compat.version}</artifactId>
+     <version>${spark.version}</version>
+     <scope>test</scope>
+   </dependency>
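
The writer case is generic over UserDefinedType, not specific to VectorUDT: any UDT column is flattened to its sqlType on write and comes back as that type on read. A hypothetical custom UDT showing the shape of the contract (Point2D and Point2DUDT are invented for this sketch; the overridden methods are the standard Spark UserDefinedType API):

    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;
    import org.apache.spark.sql.types.UserDefinedType;

    class Point2D {
      final double x, y;
      Point2D(double x, double y) { this.x = x; this.y = y; }
    }

    // Lance would store a Point2D column as this UDT's sqlType,
    // struct<x:double,y:double>; the UDT wrapper is gone on read-back.
    class Point2DUDT extends UserDefinedType<Point2D> {
      @Override public DataType sqlType() {
        return new StructType()
            .add("x", DataTypes.DoubleType, false)
            .add("y", DataTypes.DoubleType, false);
      }

      @Override public Object serialize(Point2D p) {
        // Spark calls this before the Arrow writer ever sees the value.
        return new GenericInternalRow(new Object[] {p.x, p.y});
      }

      @Override public Point2D deserialize(Object datum) {
        InternalRow row = (InternalRow) datum;
        return new Point2D(row.getDouble(0), row.getDouble(1));
      }

      @Override public Class<Point2D> userClass() { return Point2D.class; }
    }
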