lance-spark-3.4_2.12/pom.xml (+6, -0)
@@ -25,6 +25,12 @@
<version>${spark34.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.compat.version}</artifactId>
<version>${spark34.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.lance</groupId>
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
lance-spark-3.4_2.13/pom.xml (+6, -0)
@@ -27,6 +27,12 @@
<version>${spark34.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.compat.version}</artifactId>
<version>${spark34.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.lance</groupId>
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
lance-spark-3.5_2.12/pom.xml (+6, -0)
@@ -22,6 +22,12 @@
<version>${spark35.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.compat.version}</artifactId>
<version>${spark35.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.lance</groupId>
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
lance-spark-3.5_2.13/pom.xml (+6, -0)
@@ -26,6 +26,12 @@
<version>${spark35.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.compat.version}</artifactId>
<version>${spark35.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.lance</groupId>
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
lance-spark-4.0_2.13/pom.xml (+6, -0)
@@ -29,6 +29,12 @@
<version>${spark40.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.compat.version}</artifactId>
<version>${spark40.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.lance</groupId>
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
lance-spark-4.1_2.13/pom.xml (+6, -0)
@@ -29,6 +29,12 @@
<version>${spark41.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.compat.version}</artifactId>
<version>${spark41.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.lance</groupId>
<artifactId>lance-spark-base_${scala.compat.version}</artifactId>
lance-spark-base_2.12/pom.xml (+6, -0)
@@ -20,6 +20,12 @@
<artifactId>spark-sql_${scala.compat.version}</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.compat.version}</artifactId>
<version>${spark.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
@@ -114,6 +114,8 @@ object LanceArrowWriter {
case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector)
case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) =>
new IntervalMonthDayNanoWriter(vector)
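// Arrow has no UDT concept: unwrap to the UDT's sqlType and delegate to the
// writer for that underlying type (e.g. VectorUDT becomes its struct sqlType).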
case (udt: UserDefinedType[_], _) =>
createFieldWriter(vector, udt.sqlType, metadata)
case (dt, _) =>
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}
@@ -13,6 +13,10 @@
*/
package org.lance.spark;

import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
@@ -41,6 +45,7 @@
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
@@ -402,4 +407,94 @@ public void testBinaryRoundtrip() {
assertArrayEquals(b1, (byte[]) out.get(1).get(1));
assertTrue(out.get(2).isNullAt(1));
}

// ---------------- UDT (UserDefinedType) ----------------

/**
* {@link org.apache.spark.ml.linalg.VectorUDT} is MLlib's vector type and the most common UDT in
* the Spark ecosystem (used for ML feature vectors). Its {@code sqlType} is a 4-field StructType:
* {@code (type: ByteType, size: IntegerType, indices: ArrayType(IntegerType), values:
* ArrayType(DoubleType))}.
*
* <p>Arrow has no UDT concept, so the read-back schema loses the UDT wrapper and returns a plain
* {@code StructType}. Data values are intact; {@link VectorUDT#deserialize(Object)} expects an
* {@code InternalRow}, so in a test context (where Spark hands back public {@link Row}s) the
* {@code reconstructVector} helper below rebuilds the vectors manually.
*
* <p>Note: a null VectorUDT row is intentionally omitted. VectorUDT's sqlType marks the inner
* {@code type} field as non-nullable; when the parent struct is null, the StructWriter writes
* null placeholders to all children — including non-nullable ones — and Lance's native writer
* rejects the batch. This is a known limitation of structs with non-nullable children.
*/
@Test
public void testVectorUDTRoundtrip() {
VectorUDT vectorUDT = new VectorUDT();
StructType schema =
new StructType().add("id", DataTypes.IntegerType, false).add("vec", vectorUDT, true);

Vector dense = new DenseVector(new double[] {1.0, 2.0, 3.0});
Vector sparse = new SparseVector(3, new int[] {0, 2}, new double[] {1.0, 3.0});

List<Row> data = Arrays.asList(RowFactory.create(0, dense), RowFactory.create(1, sparse));

Dataset<Row> result = writeAndRead(schema, data, "vector_udt");

// Read-back schema must lose the UDT wrapper — Arrow has no UDT concept, so the column comes
// back as VectorUDT's sqlType (a plain struct). Lock that contract here so a future change
// that accidentally preserves UDT in metadata is caught loudly.
assertFalse(
result.schema().apply("vec").dataType() instanceof VectorUDT,
"read-back column must not carry VectorUDT; expected the underlying struct sqlType");
assertInstanceOf(
StructType.class,
result.schema().apply("vec").dataType(),
"read-back column must be VectorUDT.sqlType (StructType)");

List<Row> out = result.orderBy("id").collectAsList();
assertEquals(2, out.size());

// Reconstruct vectors manually since VectorUDT.deserialize() expects InternalRow.
assertEquals(dense, reconstructVector(out.get(0).getStruct(1)));
assertEquals(sparse, reconstructVector(out.get(1).getStruct(1)));
}
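
// Illustrative sketch only (not exercised by the assertions above): the shape the
// "vec" column comes back as, mirrored from Spark's VectorUDT#sqlType(). The method
// name is ours; field names, order, and nullability follow the Spark source.
private static StructType vectorSqlTypeSketch() {
return new StructType()
.add("type", DataTypes.ByteType, false) // 0 = sparse, 1 = dense
.add("size", DataTypes.IntegerType, true) // meaningful for sparse vectors only
.add("indices", DataTypes.createArrayType(DataTypes.IntegerType, false), true)
.add("values", DataTypes.createArrayType(DataTypes.DoubleType, false), true);
}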

/**
* Reconstruct an MLlib {@link Vector} from the struct row returned by Lance read-back. The struct
* follows VectorUDT's sqlType: (type: byte, size: int, indices: array&lt;int&gt;, values:
* array&lt;double&gt;). Type 0 = sparse, type 1 = dense.
*/
private static Vector reconstructVector(Row struct) {
byte type = struct.getByte(0);
switch (type) {
case 1:
{
// Dense vector — values at index 3
List<Double> vals = struct.getList(3);
double[] arr = new double[vals.size()];
for (int i = 0; i < arr.length; i++) {
arr[i] = vals.get(i);
}
return new DenseVector(arr);
}
case 0:
{
// Sparse vector — size at 1, indices at 2, values at 3
int size = struct.getInt(1);
List<Integer> idxList = struct.getList(2);
List<Double> valList = struct.getList(3);
int[] indices = new int[idxList.size()];
double[] values = new double[valList.size()];
for (int i = 0; i < indices.length; i++) {
indices[i] = idxList.get(i);
}
for (int i = 0; i < values.length; i++) {
values[i] = valList.get(i);
}
return new SparseVector(size, indices, values);
}
default:
throw new IllegalArgumentException(
"unknown VectorUDT type byte: " + type + " (expected 0=sparse or 1=dense)");
}
}
}
lance-spark-base_2.13/pom.xml (+6, -0)
@@ -25,6 +25,12 @@
<artifactId>spark-sql_${scala.compat.version}</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.compat.version}</artifactId>
<version>${spark.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>