From 22aaac06a109d53c9ced67578134586732e47c38 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Tue, 16 Dec 2025 16:16:22 -0500 Subject: [PATCH 01/32] Single value schema works --- Cargo.lock | 3 + Cargo.toml | 1 + src/lib.rs | 2 + src/variant_schema.rs | 425 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 431 insertions(+) create mode 100644 src/variant_schema.rs diff --git a/Cargo.lock b/Cargo.lock index 4e52035..cf7bf68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -684,7 +684,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ "iana-time-zone", + "js-sys", "num-traits", + "wasm-bindgen", "windows-link", ] @@ -1643,6 +1645,7 @@ dependencies = [ "arrow", "arrow-cast", "arrow-schema", + "chrono", "datafusion", "datafusion-sqllogictest", "env_logger", diff --git a/Cargo.toml b/Cargo.toml index 769b36c..ccd6e2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ arrow-schema = { git = "https://github.com/apache/arrow-rs", rev = "ca4a0ae5e412 parquet-variant-compute = { git = "https://github.com/apache/arrow-rs", rev = "ca4a0ae5e4122e905686f3b7538b5308503cb770" } parquet-variant-json = { git = "https://github.com/apache/arrow-rs", rev = "ca4a0ae5e4122e905686f3b7538b5308503cb770" } parquet-variant = { git = "https://github.com/apache/arrow-rs", rev = "ca4a0ae5e4122e905686f3b7538b5308503cb770" } +chrono = "0.4.42" [patch.crates-io] arrow = { git = "https://github.com/apache/arrow-rs", rev = "ca4a0ae5e4122e905686f3b7538b5308503cb770" } diff --git a/src/lib.rs b/src/lib.rs index 236b554..853eef0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ mod variant_list_insert; mod variant_object_construct; mod variant_object_insert; mod variant_pretty; +mod variant_schema; mod variant_to_json; pub use cast_to_variant::*; @@ -22,4 +23,5 @@ pub use variant_list_insert::*; pub use variant_object_construct::*; pub use variant_object_insert::*; pub use variant_pretty::*; +pub use variant_schema::*; pub use variant_to_json::*; diff --git a/src/variant_schema.rs b/src/variant_schema.rs new file mode 100644 index 0000000..9773970 --- /dev/null +++ b/src/variant_schema.rs @@ -0,0 +1,425 @@ +use std::collections::BTreeMap; + +use arrow::array::AsArray; +use arrow_schema::DataType; +use datafusion::{ + common::{exec_datafusion_err, exec_err}, + error::Result, + logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility}, + scalar::ScalarValue, +}; +use parquet_variant::Variant; +use parquet_variant_compute::VariantArray; + +#[derive(Debug, Hash, PartialEq, Eq)] +pub struct VariantSchemaUDF { + signature: Signature, +} + +impl Default for VariantSchemaUDF { + fn default() -> Self { + Self { + signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +enum VariantSchema { + Primitive(String), + Object(BTreeMap), + Array(Box), + Variant, +} + +fn schema_from_variant(v: &Variant) -> VariantSchema { + match v { + Variant::Object(obj) => { + let mut fields = BTreeMap::new(); + for (k, v) in obj.iter() { + fields.insert(k.to_string(), schema_from_variant(&v)); + } + VariantSchema::Object(fields) + } + Variant::List(list) => { + let mut schemas: Vec = + list.iter().map(|v| schema_from_variant(v)).collect(); + + schemas.sort(); + schemas.dedup(); + + if schemas.len() == 1 { + VariantSchema::Array(Box::new(schemas.pop().unwrap())) + } else { + VariantSchema::Array(Box::new(VariantSchema::Variant)) + } + } + // primitives + _ => VariantSchema::Primitive(variant_schema_str(v)) + } + } + +fn variant_schema_str<'m, 'v>(v: &Variant<'m, 'v>) -> String { + match v { + Variant::Null => "NULL".to_string(), + Variant::Int8(_) => "INT(8, SIGNED)".to_string(), + Variant::Int16(_) => "INT(16, SIGNED)".to_string(), + Variant::Int32(_) => "INT(32, SIGNED)".to_string(), + Variant::Int64(_) => "INT(64, SIGNED)".to_string(), + Variant::Float(_) => "FLOAT".to_string(), + Variant::Double(_) => "DOUBLE".to_string(), + Variant::Decimal4(d) => { + format!("DECIMAL({}, {})", d.integer().to_string().len(), d.scale()) + } + Variant::Decimal8(d) => { + format!("DECIMAL({}, {})", d.integer().to_string().len(), d.scale()) + } + Variant::Decimal16(d) => { + format!("DECIMAL({}, {})", d.integer().to_string().len(), d.scale()) + } + Variant::BooleanTrue | Variant::BooleanFalse => "BOOLEAN".to_string(), + Variant::String(_) | Variant::ShortString(_) => "STRING".to_string(), + Variant::Binary(_) => "BINARY".to_string(), + Variant::Date(_) => "DATE".to_string(), + Variant::Time(_) => "TIME".to_string(), + Variant::TimestampMicros(_) => "TIMESTAMP(isAdjustedToUTC=true, MICROS)".to_string(), + Variant::TimestampNtzMicros(_) => "TIMESTAMP(isAdjustedToUTC=false, MICROS)".to_string(), + Variant::TimestampNanos(_) => "TIMESTAMP(isAdjustedToUTC=true, NANOS)".to_string(), + Variant::TimestampNtzNanos(_) => "TIMESTAMP(isAdjustedToUTC=false, NANOS)".to_string(), + Variant::Uuid(_) => "UUID".to_string(), + + Variant::Object(obj) => { + let fields: Vec = obj + .iter() + .map(|(k, v)| format!("{k}: {}", variant_schema_str(&v))) + .collect(); + format!("OBJECT<{}>", fields.join(", ")) + } + + Variant::List(list) => { + let mut item_types: Vec = list.iter().map(|v| variant_schema_str(&v)).collect(); + item_types.sort(); + item_types.dedup(); + let array_type = if item_types.len() == 1 { + item_types[0].clone() + } else { + "VARIANT".to_string() + }; + format!("ARRAY<{array_type}>") + } + } +} + +fn infer_variant_schema(variant: &ColumnarValue) -> Result { + match variant { + ColumnarValue::Scalar(scalar) => { + let ScalarValue::Struct(struct_array) = scalar else { + return exec_err!("Unsupported data type: {}", scalar.data_type()); + }; + let variant_array = VariantArray::try_new(struct_array.as_ref())?; + let v = variant_array.value(0); + let schema_str = variant_schema_str(&v); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + schema_str, + )))) + } + ColumnarValue::Array(arr) => { + let variant_array = + VariantArray::try_new(arr.as_struct()).expect("Expect VariantArray"); + let mut item_types: Vec = variant_array + .iter() + .filter_map(|v| v.as_ref().map(|v| variant_schema_str(v))) + .collect(); + item_types.sort(); + item_types.dedup(); + let array_type = if item_types.len() == 1 { + item_types[0].clone() + } else { + "VARIANT".to_string() + }; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(format!( + "ARRAY<{array_type}>" + ))))) + } + } +} + +impl ScalarUDFImpl for VariantSchemaUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "variant_schema" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args( + &self, + args: datafusion::logical_expr::ScalarFunctionArgs, + ) -> Result { + let arg = args + .args + .first() + .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?; + infer_variant_schema(arg) + } +} + +#[cfg(test)] +mod tests { + use arrow::array::StructArray; + use arrow_schema::{DataType, Field, Fields}; + use chrono::{DateTime, NaiveDate, NaiveTime}; + use datafusion::{ + logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}, + scalar::ScalarValue, + }; + use parquet_variant::{Variant, VariantDecimal4}; + use parquet_variant_compute::{VariantArray, VariantType}; + use std::sync::Arc; + + use crate::{VariantSchemaUDF, shared::{build_variant_array_from_json, build_variant_array_from_json_array}}; + + fn build_scalar_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { + let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); + let arg_field = Arc::new( + Field::new("input", DataType::Struct(Fields::empty()), true) + .with_extension_type(VariantType), + ); + ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Struct(Arc::new( + struct_array, + )))], + arg_fields: vec![arg_field], + number_rows: Default::default(), + return_field, + config_options: Default::default(), + } + } + + fn build_array_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { + let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); + let arg_field = Arc::new( + Field::new("input", DataType::Struct(Fields::empty()), true) + .with_extension_type(VariantType), + ); + ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Struct(Arc::new( + struct_array, + )))], + arg_fields: vec![arg_field], + number_rows: Default::default(), + return_field, + config_options: Default::default(), + } + } + + #[test] + fn test_get_single_typed_null_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::Null; + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "NULL") + } + + #[test] + fn test_get_single_typed_int32_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::from(1234i32); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "INT(32, SIGNED)") + } + + #[test] + fn test_get_single_typed_date_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::from(NaiveDate::from_ymd_opt(1990, 1, 1).expect("Expect NaiveDate")); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "DATE") + } + + #[test] + fn test_get_single_typed_timestamp_micro_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = + Variant::from(DateTime::from_timestamp(1431648000, 0).expect("Expect TimeStamp")); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "TIMESTAMP(isAdjustedToUTC=true, MICROS)") + } + + #[test] + fn test_get_single_typed_decimal_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::Decimal4(VariantDecimal4::try_new(1234, 1).expect("Expect decimal")); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "DECIMAL(4, 1)") + } + + #[test] + fn test_get_single_typed_float_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::from(123.4f32); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "FLOAT") + } + + #[test] + fn test_get_single_typed_double_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::from(123.4f64); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "DOUBLE") + } + + #[test] + fn test_get_single_typed_bool_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::BooleanTrue; + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "BOOLEAN") + } + + #[test] + fn test_get_single_typed_binary_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::Binary(&[1u8, 2, 3]); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "BINARY") + } + + #[test] + fn test_get_single_typed_string_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::from("foo"); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "STRING") + } + + #[test] + fn test_get_single_typed_time_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant = Variant::from(NaiveTime::from_hms_opt(0, 0, 0).expect("Expect NaiveTime")); + let variant_array = VariantArray::from_iter(vec![variant]); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "TIME") + } + + #[test] + fn test_get_single_struct_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant_array = build_variant_array_from_json(&serde_json::json!({ + "key": 123, "data": [4, 5] + })); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!( + schema, + "OBJECT, key: INT(8, SIGNED)>" + ) + } + + #[test] + fn test_get_single_struct_variant_conflicting_schema() { + let udf = VariantSchemaUDF::default(); + let variant_array = build_variant_array_from_json(&serde_json::json!({ + "data": [{"a":"a"}, 5] + })); + let struct_array = variant_array.into_inner(); + let args = build_scalar_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "OBJECT>") + } + + #[test] + fn test_get_array_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant_array = build_variant_array_from_json_array(&[Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), None, Some(serde_json::json!({"wing": 123}))]); + let struct_array = variant_array.into_inner(); + let args = build_array_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "OBJECT>") + } +} From 8eea36a3d39401ada719597c6e511e2a448e1adb Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Wed, 17 Dec 2025 17:03:05 -0500 Subject: [PATCH 02/32] ehh... --- src/variant_schema.rs | 345 +++++++++++++++++++++++++++++++----------- 1 file changed, 258 insertions(+), 87 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 9773970..b4f3593 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,10 +1,10 @@ use std::collections::BTreeMap; use arrow::array::AsArray; -use arrow_schema::DataType; +use arrow_schema::{DataType, TimeUnit}; use datafusion::{ - common::{exec_datafusion_err, exec_err}, - error::Result, + common::exec_err, + error::{DataFusionError, Result}, logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility}, scalar::ScalarValue, }; @@ -24,89 +24,237 @@ impl Default for VariantSchemaUDF { } } -#[derive(Clone, Debug, PartialEq, Eq)] +/// For schema_from_variant Schema representation +/// there are 4 possible types available depending on the Variant +/// For ColumnarValue::Scalar: +/// Primitive -> Just the corresponding SQL data type based on Variant type +/// Array -> Should be an > given that type is the same, if not > +/// Object -> Each object field should keep track of its type > +/// +/// For ColumnarValue::Array we get the type for each individual Variant value and compare it with the rest. +/// - If one of the values is type => we can early terminate and call everything +/// +/// For different Variant types we will use different ways of keeping track of types between the rows: +/// - If the outer/inner type differs, we call it and early terminate this level. +/// - Primitive -> create a set to store Primitive types and if the set.len() > 1, we call everything +/// - Array -> array type is flat, so the same implementation as Primitive should work for this Array type +/// - Object -> is the difficult one in this case. Different rows can include or exlude certain fields, and we +/// need to keep track of each field between different rows separately. To keep it the fields sorted we will use +/// a BTree with fields as keys and set of types as values. We will add values to each field's set and if set.len() +/// \> 1, we call this field and we no longer need to keep track of it. +/// +/// +/// Later we should also implement Databricks' coersion into a similiar type: +/// > "The schema of each VARIANT value is merged together by field name. When two fields with the same name have +/// > a different type across records, Databricks uses the least common type. When no such type exists, the type +/// > is derived as a VARIANT. For example, INT and DOUBLE become DOUBLE, while TIMESTAMP and STRING become VARIANT." \ +/// > https://docs.databricks.com/gcp/en/sql/language-manual/functions/schema_of_variant_agg +/// +#[derive(Debug, PartialEq, Eq, Clone)] enum VariantSchema { - Primitive(String), - Object(BTreeMap), + Primitive(PrimitiveType), Array(Box), + Object(BTreeMap), Variant, } +#[derive(Debug, PartialEq, Eq, Clone)] +enum PrimitiveType { + Int { bits: u8, signed: bool }, + Float, + Double, + Decimal { precision: u8, scale: u8 }, + Boolean, + String, + Binary, + Date, + Time, + Timestamp { utc: bool, unit: TimeUnit }, + Uuid, + Null, +} + +/// This function extracts the schema from a single Variant scalar fn schema_from_variant(v: &Variant) -> VariantSchema { match v { Variant::Object(obj) => { - let mut fields = BTreeMap::new(); - for (k, v) in obj.iter() { - fields.insert(k.to_string(), schema_from_variant(&v)); - } + let fields = obj + .iter() + .map(|(k, v)| (k.to_string(), schema_from_variant(&v))) + .collect(); + VariantSchema::Object(fields) } + Variant::List(list) => { - let mut schemas: Vec = - list.iter().map(|v| schema_from_variant(v)).collect(); - - schemas.sort(); - schemas.dedup(); - - if schemas.len() == 1 { - VariantSchema::Array(Box::new(schemas.pop().unwrap())) - } else { - VariantSchema::Array(Box::new(VariantSchema::Variant)) - } - } - // primitives - _ => VariantSchema::Primitive(variant_schema_str(v)) + let inner = list + .iter() + .map(|v| schema_from_variant(&v)) + .reduce(merge_variant_schema) + .unwrap_or(VariantSchema::Variant); + + VariantSchema::Array(Box::new(inner)) } + // primitives + _ => VariantSchema::Primitive(primitive_from_variant(v)), } +} -fn variant_schema_str<'m, 'v>(v: &Variant<'m, 'v>) -> String { - match v { - Variant::Null => "NULL".to_string(), - Variant::Int8(_) => "INT(8, SIGNED)".to_string(), - Variant::Int16(_) => "INT(16, SIGNED)".to_string(), - Variant::Int32(_) => "INT(32, SIGNED)".to_string(), - Variant::Int64(_) => "INT(64, SIGNED)".to_string(), - Variant::Float(_) => "FLOAT".to_string(), - Variant::Double(_) => "DOUBLE".to_string(), - Variant::Decimal4(d) => { - format!("DECIMAL({}, {})", d.integer().to_string().len(), d.scale()) - } - Variant::Decimal8(d) => { - format!("DECIMAL({}, {})", d.integer().to_string().len(), d.scale()) - } - Variant::Decimal16(d) => { - format!("DECIMAL({}, {})", d.integer().to_string().len(), d.scale()) +fn schema_to_string(schema: &VariantSchema) -> String { + match schema { + VariantSchema::Primitive(s) => primitive_to_string(s), + + VariantSchema::Variant => "VARIANT".to_string(), + + VariantSchema::Array(inner) => { + format!("ARRAY<{}>", schema_to_string(inner)) } - Variant::BooleanTrue | Variant::BooleanFalse => "BOOLEAN".to_string(), - Variant::String(_) | Variant::ShortString(_) => "STRING".to_string(), - Variant::Binary(_) => "BINARY".to_string(), - Variant::Date(_) => "DATE".to_string(), - Variant::Time(_) => "TIME".to_string(), - Variant::TimestampMicros(_) => "TIMESTAMP(isAdjustedToUTC=true, MICROS)".to_string(), - Variant::TimestampNtzMicros(_) => "TIMESTAMP(isAdjustedToUTC=false, MICROS)".to_string(), - Variant::TimestampNanos(_) => "TIMESTAMP(isAdjustedToUTC=true, NANOS)".to_string(), - Variant::TimestampNtzNanos(_) => "TIMESTAMP(isAdjustedToUTC=false, NANOS)".to_string(), - Variant::Uuid(_) => "UUID".to_string(), - Variant::Object(obj) => { - let fields: Vec = obj + VariantSchema::Object(fields) => { + let parts: Vec = fields .iter() - .map(|(k, v)| format!("{k}: {}", variant_schema_str(&v))) + .map(|(k, v)| format!("{k}: {}", schema_to_string(v))) .collect(); - format!("OBJECT<{}>", fields.join(", ")) + format!("OBJECT<{}>", parts.join(", ")) } + } +} - Variant::List(list) => { - let mut item_types: Vec = list.iter().map(|v| variant_schema_str(&v)).collect(); - item_types.sort(); - item_types.dedup(); - let array_type = if item_types.len() == 1 { - item_types[0].clone() - } else { - "VARIANT".to_string() - }; - format!("ARRAY<{array_type}>") +fn primitive_to_string(p: &PrimitiveType) -> String { + match p { + PrimitiveType::Int { bits, signed } => format!( + "INT({bits}, {})", + if *signed { "SIGNED" } else { "UNSIGNED" } + ), + PrimitiveType::Float => "FLOAT".to_string(), + PrimitiveType::Double => "DOUBLE".to_string(), + PrimitiveType::Decimal { precision, scale } => format!("DECIMAL({precision}, {scale})"), + PrimitiveType::Boolean => "BOOLEAN".to_string(), + PrimitiveType::String => "STRING".to_string(), + PrimitiveType::Binary => "BINARY".to_string(), + PrimitiveType::Date => "DATE".to_string(), + PrimitiveType::Time => "TIME".to_string(), + PrimitiveType::Timestamp { utc, unit } => { + format!("TIMESTAMP(isAdjustedToUTC={utc}, {unit:?})") + } + PrimitiveType::Uuid => "UUID".to_string(), + PrimitiveType::Null => "NULL".to_string(), + } +} + +fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> PrimitiveType { + match v { + Variant::Null => PrimitiveType::Null, + Variant::Int8(_) => PrimitiveType::Int { + bits: 8, + signed: true, + }, + Variant::Int16(_) => PrimitiveType::Int { + bits: 16, + signed: true, + }, + Variant::Int32(_) => PrimitiveType::Int { + bits: 32, + signed: true, + }, + Variant::Int64(_) => PrimitiveType::Int { + bits: 64, + signed: true, + }, + Variant::Float(_) => PrimitiveType::Float, + Variant::Double(_) => PrimitiveType::Double, + Variant::Decimal4(d) => PrimitiveType::Decimal { + precision: d.integer().to_string().len() as u8, + scale: d.scale(), + }, + Variant::Decimal8(d) => PrimitiveType::Decimal { + precision: d.integer().to_string().len() as u8, + scale: d.scale(), + }, + Variant::Decimal16(d) => PrimitiveType::Decimal { + precision: d.integer().to_string().len() as u8, + scale: d.scale(), + }, + Variant::BooleanTrue | Variant::BooleanFalse => PrimitiveType::Boolean, + Variant::String(_) | Variant::ShortString(_) => PrimitiveType::String, + Variant::Binary(_) => PrimitiveType::Binary, + Variant::Date(_) => PrimitiveType::Date, + Variant::Time(_) => PrimitiveType::Time, + Variant::TimestampMicros(_) => PrimitiveType::Timestamp { + utc: true, + unit: TimeUnit::Microsecond, + }, + Variant::TimestampNtzMicros(_) => PrimitiveType::Timestamp { + utc: false, + unit: TimeUnit::Microsecond, + }, + Variant::TimestampNanos(_) => PrimitiveType::Timestamp { + utc: true, + unit: TimeUnit::Nanosecond, + }, + Variant::TimestampNtzNanos(_) => PrimitiveType::Timestamp { + utc: false, + unit: TimeUnit::Nanosecond, + }, + Variant::Uuid(_) => PrimitiveType::Uuid, + _ => unreachable!("Should be only applied to Primitive Variant"), + } +} + +fn merge_primitives(a: PrimitiveType, b: PrimitiveType) -> Option { + use PrimitiveType::*; + + match (a, b) { + // null handling + (Null, x) | (x, Null) => Some(x), + // normal case + (x, y) if x == y => Some(x), + // numeric widening + (Int { .. }, Double) | (Double, Int { .. }) => Some(Double), + (Int { .. }, Float) | (Float, Int { .. }) => Some(Float), + (Float, Double) | (Double, Float) => Some(Double), + + // decimal rules (simplified) + ( + Decimal { + precision: p1, + scale: s1, + }, + Decimal { + precision: p2, + scale: s2, + }, + ) => Some(Decimal { + precision: p1.max(p2), + scale: s1.max(s2), + }), + + _ => None, + } +} + +fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { + use VariantSchema::*; + + match (a, b) { + (Variant, _) | (_, Variant) => Variant, + + (Primitive(p1), Primitive(p2)) => { + merge_primitives(p1, p2).map(Primitive).unwrap_or(Variant) } + + (Array(a), Array(b)) => Array(Box::new(merge_variant_schema(*a, *b))), + + (Object(mut a), Object(b)) => { + for (k, v_b) in b { + a.entry(k) + .and_modify(|v_a| *v_a = merge_variant_schema(v_a.clone(), v_b.clone())) + .or_insert(v_b); + } + Object(a) + } + + _ => Variant, } } @@ -118,7 +266,10 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { }; let variant_array = VariantArray::try_new(struct_array.as_ref())?; let v = variant_array.value(0); - let schema_str = variant_schema_str(&v); + + let schema = schema_from_variant(&v); + let schema_str = schema_to_string(&schema); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( schema_str, )))) @@ -126,20 +277,17 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { ColumnarValue::Array(arr) => { let variant_array = VariantArray::try_new(arr.as_struct()).expect("Expect VariantArray"); - let mut item_types: Vec = variant_array + + let final_schema = variant_array .iter() - .filter_map(|v| v.as_ref().map(|v| variant_schema_str(v))) - .collect(); - item_types.sort(); - item_types.dedup(); - let array_type = if item_types.len() == 1 { - item_types[0].clone() - } else { - "VARIANT".to_string() - }; - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(format!( - "ARRAY<{array_type}>" - ))))) + .flatten() + .map(|v| schema_from_variant(&v)) + .reduce(merge_variant_schema) + .unwrap_or(VariantSchema::Variant); + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + schema_to_string(&final_schema), + )))) } } } @@ -165,10 +313,9 @@ impl ScalarUDFImpl for VariantSchemaUDF { &self, args: datafusion::logical_expr::ScalarFunctionArgs, ) -> Result { - let arg = args - .args - .first() - .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?; + let arg = args.args.first().ok_or_else(|| { + DataFusionError::Execution("empty argument, expected 1 argument".to_string()) + })?; infer_variant_schema(arg) } } @@ -186,7 +333,10 @@ mod tests { use parquet_variant_compute::{VariantArray, VariantType}; use std::sync::Arc; - use crate::{VariantSchemaUDF, shared::{build_variant_array_from_json, build_variant_array_from_json_array}}; + use crate::{ + VariantSchemaUDF, + shared::{build_variant_array_from_json, build_variant_array_from_json_array}, + }; fn build_scalar_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); @@ -276,7 +426,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "TIMESTAMP(isAdjustedToUTC=true, MICROS)") + assert_eq!(schema, "TIMESTAMP(isAdjustedToUTC=true, Microsecond)") } #[test] @@ -413,13 +563,34 @@ mod tests { #[test] fn test_get_array_variant_schema() { let udf = VariantSchemaUDF::default(); - let variant_array = build_variant_array_from_json_array(&[Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), None, Some(serde_json::json!({"wing": 123}))]); + let variant_array = build_variant_array_from_json_array(&[ + Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), + None, + Some(serde_json::json!({"wing": {"ding": "man"}})), + ]); let struct_array = variant_array.into_inner(); let args = build_array_udf_args(struct_array); let result = udf.invoke_with_args(args).unwrap(); let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "OBJECT>") + assert_eq!(schema, "OBJECT>") + } + + #[test] + fn test_get_array_variant_conflicting_schema() { + let udf = VariantSchemaUDF::default(); + let variant_array = build_variant_array_from_json_array(&[ + Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), + None, + Some(serde_json::json!({"wing": 123})), + ]); + let struct_array = variant_array.into_inner(); + let args = build_array_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + panic!() + }; + assert_eq!(schema, "OBJECT") } } From 2563ce830ea01c0122ce58f7a0ca1891d6b913b5 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Thu, 18 Dec 2025 15:21:22 -0500 Subject: [PATCH 03/32] sort of works --- src/variant_schema.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index b4f3593..c776c4f 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; - use arrow::array::AsArray; use arrow_schema::{DataType, TimeUnit}; use datafusion::{ @@ -201,6 +200,7 @@ fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> PrimitiveType { } } +// Todo: needs more work on type coercing fn merge_primitives(a: PrimitiveType, b: PrimitiveType) -> Option { use PrimitiveType::*; @@ -235,7 +235,6 @@ fn merge_primitives(a: PrimitiveType, b: PrimitiveType) -> Option fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { use VariantSchema::*; - match (a, b) { (Variant, _) | (_, Variant) => Variant, @@ -362,9 +361,9 @@ mod tests { .with_extension_type(VariantType), ); ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::Struct(Arc::new( + args: vec![ColumnarValue::Array(Arc::new( struct_array, - )))], + ))], arg_fields: vec![arg_field], number_rows: Default::default(), return_field, @@ -582,7 +581,7 @@ mod tests { let udf = VariantSchemaUDF::default(); let variant_array = build_variant_array_from_json_array(&[ Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), - None, + // None, Some(serde_json::json!({"wing": 123})), ]); let struct_array = variant_array.into_inner(); From 4bb2f5a481fdbd311f4da32b055c9959e0c374a4 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Thu, 18 Dec 2025 16:11:25 -0500 Subject: [PATCH 04/32] replaced primitive enum with arrow DataType --- src/variant_schema.rs | 195 +++++++++++++++--------------------------- 1 file changed, 68 insertions(+), 127 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index c776c4f..c09c44b 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap}; use arrow::array::AsArray; use arrow_schema::{DataType, TimeUnit}; use datafusion::{ @@ -51,28 +51,12 @@ impl Default for VariantSchemaUDF { /// #[derive(Debug, PartialEq, Eq, Clone)] enum VariantSchema { - Primitive(PrimitiveType), + Primitive(DataType), Array(Box), Object(BTreeMap), Variant, } -#[derive(Debug, PartialEq, Eq, Clone)] -enum PrimitiveType { - Int { bits: u8, signed: bool }, - Float, - Double, - Decimal { precision: u8, scale: u8 }, - Boolean, - String, - Binary, - Date, - Time, - Timestamp { utc: bool, unit: TimeUnit }, - Uuid, - Null, -} - /// This function extracts the schema from a single Variant scalar fn schema_from_variant(v: &Variant) -> VariantSchema { match v { @@ -101,7 +85,7 @@ fn schema_from_variant(v: &Variant) -> VariantSchema { fn schema_to_string(schema: &VariantSchema) -> String { match schema { - VariantSchema::Primitive(s) => primitive_to_string(s), + VariantSchema::Primitive(s) => format!("{s}"), VariantSchema::Variant => "VARIANT".to_string(), @@ -119,115 +103,72 @@ fn schema_to_string(schema: &VariantSchema) -> String { } } -fn primitive_to_string(p: &PrimitiveType) -> String { - match p { - PrimitiveType::Int { bits, signed } => format!( - "INT({bits}, {})", - if *signed { "SIGNED" } else { "UNSIGNED" } - ), - PrimitiveType::Float => "FLOAT".to_string(), - PrimitiveType::Double => "DOUBLE".to_string(), - PrimitiveType::Decimal { precision, scale } => format!("DECIMAL({precision}, {scale})"), - PrimitiveType::Boolean => "BOOLEAN".to_string(), - PrimitiveType::String => "STRING".to_string(), - PrimitiveType::Binary => "BINARY".to_string(), - PrimitiveType::Date => "DATE".to_string(), - PrimitiveType::Time => "TIME".to_string(), - PrimitiveType::Timestamp { utc, unit } => { - format!("TIMESTAMP(isAdjustedToUTC={utc}, {unit:?})") - } - PrimitiveType::Uuid => "UUID".to_string(), - PrimitiveType::Null => "NULL".to_string(), +fn decimal_precision>(val: T) -> u8 { + let mut n = val.into(); + if n == 0 { return 1; } + if n < 0 { n = -n } + + let mut digits = 0; + while n != 0 { + digits += 1; + n /= 10; } + digits } -fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> PrimitiveType { +fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { match v { - Variant::Null => PrimitiveType::Null, - Variant::Int8(_) => PrimitiveType::Int { - bits: 8, - signed: true, - }, - Variant::Int16(_) => PrimitiveType::Int { - bits: 16, - signed: true, - }, - Variant::Int32(_) => PrimitiveType::Int { - bits: 32, - signed: true, - }, - Variant::Int64(_) => PrimitiveType::Int { - bits: 64, - signed: true, - }, - Variant::Float(_) => PrimitiveType::Float, - Variant::Double(_) => PrimitiveType::Double, - Variant::Decimal4(d) => PrimitiveType::Decimal { - precision: d.integer().to_string().len() as u8, - scale: d.scale(), - }, - Variant::Decimal8(d) => PrimitiveType::Decimal { - precision: d.integer().to_string().len() as u8, - scale: d.scale(), - }, - Variant::Decimal16(d) => PrimitiveType::Decimal { - precision: d.integer().to_string().len() as u8, - scale: d.scale(), - }, - Variant::BooleanTrue | Variant::BooleanFalse => PrimitiveType::Boolean, - Variant::String(_) | Variant::ShortString(_) => PrimitiveType::String, - Variant::Binary(_) => PrimitiveType::Binary, - Variant::Date(_) => PrimitiveType::Date, - Variant::Time(_) => PrimitiveType::Time, - Variant::TimestampMicros(_) => PrimitiveType::Timestamp { - utc: true, - unit: TimeUnit::Microsecond, - }, - Variant::TimestampNtzMicros(_) => PrimitiveType::Timestamp { - utc: false, - unit: TimeUnit::Microsecond, - }, - Variant::TimestampNanos(_) => PrimitiveType::Timestamp { - utc: true, - unit: TimeUnit::Nanosecond, - }, - Variant::TimestampNtzNanos(_) => PrimitiveType::Timestamp { - utc: false, - unit: TimeUnit::Nanosecond, - }, - Variant::Uuid(_) => PrimitiveType::Uuid, + Variant::Null => DataType::Null, + Variant::Int8(_) => DataType::Int8, + Variant::Int16(_) => DataType::Int16, + Variant::Int32(_) => DataType::Int32, + Variant::Int64(_) => DataType::Int64, + Variant::Float(_) => DataType::Float32, + Variant::Double(_) => DataType::Float64, + Variant::Decimal4(d) => DataType::Decimal32(decimal_precision(d.integer()), d.scale() as i8), + Variant::Decimal8(d) => DataType::Decimal64(decimal_precision(d.integer()), d.scale()as i8), + Variant::Decimal16(d) => DataType::Decimal128(decimal_precision(d.integer()), d.scale() as i8), + Variant::BooleanTrue | Variant::BooleanFalse => DataType::Boolean, + Variant::String(_) | Variant::ShortString(_) | Variant::Uuid(_) => DataType::Utf8, + Variant::Binary(_) => DataType::Binary, + Variant::Date(_) => DataType::Date32, + Variant::Time(_) => DataType::Time64(TimeUnit::Microsecond), + Variant::TimestampMicros(_) => DataType::Timestamp(TimeUnit::Microsecond, Some("utc".into())), + Variant::TimestampNtzMicros(_) => DataType::Timestamp(TimeUnit::Microsecond, None), + Variant::TimestampNanos(_) => DataType::Timestamp(TimeUnit::Nanosecond, Some("utc".into())), + Variant::TimestampNtzNanos(_) => DataType::Timestamp(TimeUnit::Nanosecond, None), _ => unreachable!("Should be only applied to Primitive Variant"), } } // Todo: needs more work on type coercing -fn merge_primitives(a: PrimitiveType, b: PrimitiveType) -> Option { - use PrimitiveType::*; +fn merge_primitives(a: DataType, b: DataType) -> Option { + use DataType::*; match (a, b) { // null handling (Null, x) | (x, Null) => Some(x), // normal case (x, y) if x == y => Some(x), - // numeric widening - (Int { .. }, Double) | (Double, Int { .. }) => Some(Double), - (Int { .. }, Float) | (Float, Int { .. }) => Some(Float), - (Float, Double) | (Double, Float) => Some(Double), - - // decimal rules (simplified) - ( - Decimal { - precision: p1, - scale: s1, - }, - Decimal { - precision: p2, - scale: s2, - }, - ) => Some(Decimal { - precision: p1.max(p2), - scale: s1.max(s2), - }), + // // numeric widening + // (Int { .. }, Double) | (Double, Int { .. }) => Some(Double), + // (Int { .. }, Float) | (Float, Int { .. }) => Some(Float), + // (Float, Double) | (Double, Float) => Some(Double), + + // // decimal rules (simplified) + // ( + // Decimal { + // precision: p1, + // scale: s1, + // }, + // Decimal { + // precision: p2, + // scale: s2, + // }, + // ) => Some(Decimal { + // precision: p1.max(p2), + // scale: s1.max(s2), + // }), _ => None, } @@ -382,7 +323,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "NULL") + assert_eq!(schema, "Null") } #[test] @@ -396,7 +337,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "INT(32, SIGNED)") + assert_eq!(schema, "Int32") } #[test] @@ -410,7 +351,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "DATE") + assert_eq!(schema, "Date32") } #[test] @@ -425,7 +366,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "TIMESTAMP(isAdjustedToUTC=true, Microsecond)") + assert_eq!(schema, "Timestamp(µs, \"utc\")") } #[test] @@ -439,7 +380,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "DECIMAL(4, 1)") + assert_eq!(schema, "Decimal32(4, 1)") } #[test] @@ -453,7 +394,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "FLOAT") + assert_eq!(schema, "Float32") } #[test] @@ -467,7 +408,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "DOUBLE") + assert_eq!(schema, "Float64") } #[test] @@ -481,7 +422,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "BOOLEAN") + assert_eq!(schema, "Boolean") } #[test] @@ -495,7 +436,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "BINARY") + assert_eq!(schema, "Binary") } #[test] @@ -509,7 +450,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "STRING") + assert_eq!(schema, "Utf8") } #[test] @@ -523,7 +464,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "TIME") + assert_eq!(schema, "Time64(µs)") } #[test] @@ -540,7 +481,7 @@ mod tests { }; assert_eq!( schema, - "OBJECT, key: INT(8, SIGNED)>" + "OBJECT, key: Int8>" ) } @@ -573,7 +514,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "OBJECT>") + assert_eq!(schema, "OBJECT>") } #[test] @@ -590,6 +531,6 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "OBJECT") + assert_eq!(schema, "OBJECT") } } From 7a2188fc17bdf1de79412ee8680103b9b8e2c7a7 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Fri, 19 Dec 2025 17:36:57 -0500 Subject: [PATCH 05/32] Fully getting rid of custom structs/enums --- src/variant_schema.rs | 177 +++++++++++++++++++++--------------------- 1 file changed, 87 insertions(+), 90 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index c09c44b..75c422c 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,6 +1,7 @@ -use std::{collections::BTreeMap}; +use std::{ops::Deref, sync::Arc}; + use arrow::array::AsArray; -use arrow_schema::{DataType, TimeUnit}; +use arrow_schema::{DataType, Field, Fields, TimeUnit}; use datafusion::{ common::exec_err, error::{DataFusionError, Result}, @@ -8,7 +9,7 @@ use datafusion::{ scalar::ScalarValue, }; use parquet_variant::Variant; -use parquet_variant_compute::VariantArray; +use parquet_variant_compute::{VariantArray, VariantType}; #[derive(Debug, Hash, PartialEq, Eq)] pub struct VariantSchemaUDF { @@ -29,6 +30,7 @@ impl Default for VariantSchemaUDF { /// Primitive -> Just the corresponding SQL data type based on Variant type /// Array -> Should be an > given that type is the same, if not > /// Object -> Each object field should keep track of its type > +/// Variant /// /// For ColumnarValue::Array we get the type for each individual Variant value and compare it with the rest. /// - If one of the values is type => we can early terminate and call everything @@ -49,64 +51,42 @@ impl Default for VariantSchemaUDF { /// > is derived as a VARIANT. For example, INT and DOUBLE become DOUBLE, while TIMESTAMP and STRING become VARIANT." \ /// > https://docs.databricks.com/gcp/en/sql/language-manual/functions/schema_of_variant_agg /// -#[derive(Debug, PartialEq, Eq, Clone)] -enum VariantSchema { - Primitive(DataType), - Array(Box), - Object(BTreeMap), - Variant, -} - /// This function extracts the schema from a single Variant scalar -fn schema_from_variant(v: &Variant) -> VariantSchema { +fn schema_from_variant(v: &Variant) -> DataType { match v { Variant::Object(obj) => { let fields = obj .iter() - .map(|(k, v)| (k.to_string(), schema_from_variant(&v))) + .map(|(k, v)| Field::new(k.to_string(), schema_from_variant(&v), true)) .collect(); - VariantSchema::Object(fields) + DataType::Struct(fields) } Variant::List(list) => { let inner = list .iter() - .map(|v| schema_from_variant(&v)) - .reduce(merge_variant_schema) - .unwrap_or(VariantSchema::Variant); + .map(|v| Field::new("", schema_from_variant(&v), true)) + .reduce(merge_fields) + .unwrap_or( + Field::new("item", DataType::Binary, true).with_extension_type(VariantType), + ); - VariantSchema::Array(Box::new(inner)) + DataType::List(Arc::new(inner)) } // primitives - _ => VariantSchema::Primitive(primitive_from_variant(v)), - } -} - -fn schema_to_string(schema: &VariantSchema) -> String { - match schema { - VariantSchema::Primitive(s) => format!("{s}"), - - VariantSchema::Variant => "VARIANT".to_string(), - - VariantSchema::Array(inner) => { - format!("ARRAY<{}>", schema_to_string(inner)) - } - - VariantSchema::Object(fields) => { - let parts: Vec = fields - .iter() - .map(|(k, v)| format!("{k}: {}", schema_to_string(v))) - .collect(); - format!("OBJECT<{}>", parts.join(", ")) - } + _ => primitive_from_variant(v), } } fn decimal_precision>(val: T) -> u8 { let mut n = val.into(); - if n == 0 { return 1; } - if n < 0 { n = -n } + if n == 0 { + return 1; + } + if n < 0 { + n = -n + } let mut digits = 0; while n != 0 { @@ -125,15 +105,23 @@ fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { Variant::Int64(_) => DataType::Int64, Variant::Float(_) => DataType::Float32, Variant::Double(_) => DataType::Float64, - Variant::Decimal4(d) => DataType::Decimal32(decimal_precision(d.integer()), d.scale() as i8), - Variant::Decimal8(d) => DataType::Decimal64(decimal_precision(d.integer()), d.scale()as i8), - Variant::Decimal16(d) => DataType::Decimal128(decimal_precision(d.integer()), d.scale() as i8), + Variant::Decimal4(d) => { + DataType::Decimal32(decimal_precision(d.integer()), d.scale() as i8) + } + Variant::Decimal8(d) => { + DataType::Decimal64(decimal_precision(d.integer()), d.scale() as i8) + } + Variant::Decimal16(d) => { + DataType::Decimal128(decimal_precision(d.integer()), d.scale() as i8) + } Variant::BooleanTrue | Variant::BooleanFalse => DataType::Boolean, Variant::String(_) | Variant::ShortString(_) | Variant::Uuid(_) => DataType::Utf8, Variant::Binary(_) => DataType::Binary, Variant::Date(_) => DataType::Date32, Variant::Time(_) => DataType::Time64(TimeUnit::Microsecond), - Variant::TimestampMicros(_) => DataType::Timestamp(TimeUnit::Microsecond, Some("utc".into())), + Variant::TimestampMicros(_) => { + DataType::Timestamp(TimeUnit::Microsecond, Some("utc".into())) + } Variant::TimestampNtzMicros(_) => DataType::Timestamp(TimeUnit::Microsecond, None), Variant::TimestampNanos(_) => DataType::Timestamp(TimeUnit::Nanosecond, Some("utc".into())), Variant::TimestampNtzNanos(_) => DataType::Timestamp(TimeUnit::Nanosecond, None), @@ -142,18 +130,23 @@ fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { } // Todo: needs more work on type coercing -fn merge_primitives(a: DataType, b: DataType) -> Option { +fn merge_datatypes(a: DataType, b: DataType) -> DataType { use DataType::*; match (a, b) { // null handling - (Null, x) | (x, Null) => Some(x), + (Null, x) | (x, Null) => x.clone(), // normal case - (x, y) if x == y => Some(x), - // // numeric widening - // (Int { .. }, Double) | (Double, Int { .. }) => Some(Double), - // (Int { .. }, Float) | (Float, Int { .. }) => Some(Float), - // (Float, Double) | (Double, Float) => Some(Double), + (x, y) if x == y => x.clone(), + // numeric widening + // docs.databricks.com/aws/en/sql/language-manual/sql-ref-datatype-rules#type-precedence-list + // For least common type resolution FLOAT is skipped to avoid loss of precision. + (Int8 | Int16 | Int32 | Int64 | Float32, Float64) + | (Float64, Int8 | Int16 | Int32 | Int64 | Float32) => Float64, + (Int8 | Int16 | Int32, Int64) | (Int64, Int8 | Int16 | Int32) => Int64, + (Int8 | Int16, Int32) | (Int32, Int8 | Int16) => Int32, + + (Date32, Timestamp(tu, tz)) | (Timestamp(tu, tz), Date32) => Timestamp(tu, tz), // // decimal rules (simplified) // ( @@ -169,33 +162,43 @@ fn merge_primitives(a: DataType, b: DataType) -> Option { // precision: p1.max(p2), // scale: s1.max(s2), // }), - - _ => None, - } -} - -fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { - use VariantSchema::*; - match (a, b) { - (Variant, _) | (_, Variant) => Variant, - - (Primitive(p1), Primitive(p2)) => { - merge_primitives(p1, p2).map(Primitive).unwrap_or(Variant) + (List(a), List(b)) => { + DataType::List(Arc::new(merge_fields(a.deref().clone(), b.deref().clone()))) } - (Array(a), Array(b)) => Array(Box::new(merge_variant_schema(*a, *b))), + (Struct(a), Struct(b)) => { + // Step 1: extract Fields into Vec + let mut merged_fields: Vec = a + .iter() // iterates over &Arc + .map(|f| f.as_ref().clone()) // clone Field out of Arc + .collect(); - (Object(mut a), Object(b)) => { - for (k, v_b) in b { - a.entry(k) - .and_modify(|v_a| *v_a = merge_variant_schema(v_a.clone(), v_b.clone())) - .or_insert(v_b); + // Step 2: merge b_fields + for b_field in b.iter() { + if let Some(existing) = merged_fields + .iter_mut() + .find(|f| f.name() == b_field.name()) + { + *existing = merge_fields(existing.clone(), b_field.deref().clone()); + } else { + merged_fields.push((**b_field).clone()); // clone b_field Field + } } - Object(a) + + // Step 3: build new Struct + DataType::Struct(Fields::from(merged_fields)) } + _ => unreachable!("the cases above should cover everything"), + } +} - _ => Variant, +fn merge_fields(a: Field, b: Field) -> Field { + if a.extension_type_name() == Some("VARIANT") && b.extension_type_name() == Some("VARIANT") { + return Field::new("merged_field", DataType::Binary, true).with_extension_type(VariantType); } + let merged_type = merge_datatypes(a.data_type().clone(), b.data_type().clone()); + + Field::new(a.name(), merged_type, a.is_nullable() || b.is_nullable()) } fn infer_variant_schema(variant: &ColumnarValue) -> Result { @@ -207,12 +210,11 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { let variant_array = VariantArray::try_new(struct_array.as_ref())?; let v = variant_array.value(0); - let schema = schema_from_variant(&v); - let schema_str = schema_to_string(&schema); + let data_type = schema_from_variant(&v); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - schema_str, - )))) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(format!( + "{data_type}" + ))))) } ColumnarValue::Array(arr) => { let variant_array = @@ -222,12 +224,11 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { .iter() .flatten() .map(|v| schema_from_variant(&v)) - .reduce(merge_variant_schema) - .unwrap_or(VariantSchema::Variant); + .reduce(merge_datatypes); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - schema_to_string(&final_schema), - )))) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(format!( + "{final_schema:?}" + ))))) } } } @@ -302,9 +303,7 @@ mod tests { .with_extension_type(VariantType), ); ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new( - struct_array, - ))], + args: vec![ColumnarValue::Array(Arc::new(struct_array))], arg_fields: vec![arg_field], number_rows: Default::default(), return_field, @@ -479,10 +478,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!( - schema, - "OBJECT, key: Int8>" - ) + assert_eq!(schema, "OBJECT, key: Int8>") } #[test] @@ -514,7 +510,8 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - assert_eq!(schema, "OBJECT>") + // assert_eq!(schema, "OBJECT>") + assert_eq!(schema, "Some(Struct([Field { name: \"foo\", data_type: Utf8, nullable: true }, Field { name: \"wing\", data_type: Struct([Field { name: \"ding\", data_type: Utf8, nullable: true }]), nullable: true }]))") } #[test] From 09e0b922d40b4b69275fd8312db272243bddd264 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Thu, 18 Dec 2025 16:11:25 -0500 Subject: [PATCH 06/32] replaced primitive enum with arrow DataType --- src/variant_schema.rs | 88 ++++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 43 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 75c422c..6872dad 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,5 +1,6 @@ use std::{ops::Deref, sync::Arc}; +use std::{collections::BTreeMap}; use arrow::array::AsArray; use arrow_schema::{DataType, Field, Fields, TimeUnit}; use datafusion::{ @@ -51,6 +52,14 @@ impl Default for VariantSchemaUDF { /// > is derived as a VARIANT. For example, INT and DOUBLE become DOUBLE, while TIMESTAMP and STRING become VARIANT." \ /// > https://docs.databricks.com/gcp/en/sql/language-manual/functions/schema_of_variant_agg /// +#[derive(Debug, PartialEq, Eq, Clone)] +enum VariantSchema { + Primitive(DataType), + Array(Box), + Object(BTreeMap), + Variant, +} + /// This function extracts the schema from a single Variant scalar fn schema_from_variant(v: &Variant) -> DataType { match v { @@ -84,16 +93,36 @@ fn decimal_precision>(val: T) -> u8 { if n == 0 { return 1; } - if n < 0 { - n = -n + if n < 0 { + n = -n + } + + let mut digits = 0; + while n != 0 { + digits += 1; + n /= 10; + } + digits } - let mut digits = 0; - while n != 0 { - digits += 1; - n /= 10; +fn schema_to_string(schema: &VariantSchema) -> String { + match schema { + VariantSchema::Primitive(s) => format!("{s}"), + + VariantSchema::Variant => "VARIANT".to_string(), + + VariantSchema::Array(inner) => { + format!("ARRAY<{}>", schema_to_string(inner)) + } + + VariantSchema::Object(fields) => { + let parts: Vec = fields + .iter() + .map(|(k, v)| format!("{k}: {}", schema_to_string(v))) + .collect(); + format!("OBJECT<{}>", parts.join(", ")) + } } - digits } fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { @@ -130,23 +159,23 @@ fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { } // Todo: needs more work on type coercing -fn merge_datatypes(a: DataType, b: DataType) -> DataType { +fn merge_primitives(a: DataType, b: DataType) -> Option { use DataType::*; match (a, b) { // null handling - (Null, x) | (x, Null) => x.clone(), + (Null, x) | (x, Null) => Some(x), // normal case - (x, y) if x == y => x.clone(), + (x, y) if x == y => Some(x), // numeric widening // docs.databricks.com/aws/en/sql/language-manual/sql-ref-datatype-rules#type-precedence-list // For least common type resolution FLOAT is skipped to avoid loss of precision. (Int8 | Int16 | Int32 | Int64 | Float32, Float64) - | (Float64, Int8 | Int16 | Int32 | Int64 | Float32) => Float64, - (Int8 | Int16 | Int32, Int64) | (Int64, Int8 | Int16 | Int32) => Int64, - (Int8 | Int16, Int32) | (Int32, Int8 | Int16) => Int32, + | (Float64, Int8 | Int16 | Int32 | Int64 | Float32) => Some(Float64), + (Int8 | Int16 | Int32, Int64) | (Int64, Int8 | Int16 | Int32) => Some(Int64), + (Int8 | Int16, Int32) | (Int32, Int8 | Int16) => Some(Int32), - (Date32, Timestamp(tu, tz)) | (Timestamp(tu, tz), Date32) => Timestamp(tu, tz), + (Date32, Timestamp(tu, tz)) | (Timestamp(tu, tz), Date32) => Some(Timestamp(tu, tz)), // // decimal rules (simplified) // ( @@ -162,33 +191,7 @@ fn merge_datatypes(a: DataType, b: DataType) -> DataType { // precision: p1.max(p2), // scale: s1.max(s2), // }), - (List(a), List(b)) => { - DataType::List(Arc::new(merge_fields(a.deref().clone(), b.deref().clone()))) - } - - (Struct(a), Struct(b)) => { - // Step 1: extract Fields into Vec - let mut merged_fields: Vec = a - .iter() // iterates over &Arc - .map(|f| f.as_ref().clone()) // clone Field out of Arc - .collect(); - - // Step 2: merge b_fields - for b_field in b.iter() { - if let Some(existing) = merged_fields - .iter_mut() - .find(|f| f.name() == b_field.name()) - { - *existing = merge_fields(existing.clone(), b_field.deref().clone()); - } else { - merged_fields.push((**b_field).clone()); // clone b_field Field - } - } - - // Step 3: build new Struct - DataType::Struct(Fields::from(merged_fields)) - } - _ => unreachable!("the cases above should cover everything"), + _ => unreachable!("Not primitive types {}, {}", a,b), } } @@ -510,8 +513,7 @@ mod tests { let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { panic!() }; - // assert_eq!(schema, "OBJECT>") - assert_eq!(schema, "Some(Struct([Field { name: \"foo\", data_type: Utf8, nullable: true }, Field { name: \"wing\", data_type: Struct([Field { name: \"ding\", data_type: Utf8, nullable: true }]), nullable: true }]))") + assert_eq!(schema, "OBJECT>") } #[test] From 1a10082e69e316b4f23b7a8251b6ce36792fd1cd Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Thu, 18 Dec 2025 15:21:22 -0500 Subject: [PATCH 07/32] sort of works --- src/variant_schema.rs | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 6872dad..0477ec4 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -195,9 +195,27 @@ fn merge_primitives(a: DataType, b: DataType) -> Option { } } -fn merge_fields(a: Field, b: Field) -> Field { - if a.extension_type_name() == Some("VARIANT") && b.extension_type_name() == Some("VARIANT") { - return Field::new("merged_field", DataType::Binary, true).with_extension_type(VariantType); +fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { + use VariantSchema::*; + match (a, b) { + (Variant, _) | (_, Variant) => Variant, + + (Primitive(p1), Primitive(p2)) => { + merge_primitives(p1, p2).map(Primitive).unwrap_or(Variant) + } + + (Array(a), Array(b)) => Array(Box::new(merge_variant_schema(*a, *b))), + + (Object(mut a), Object(b)) => { + for (k, v_b) in b { + a.entry(k) + .and_modify(|v_a| *v_a = merge_variant_schema(v_a.clone(), v_b.clone())) + .or_insert(v_b); + } + Object(a) + } + + _ => Variant, } let merged_type = merge_datatypes(a.data_type().clone(), b.data_type().clone()); From 800c0ba1df8caae0ce86285702e4e51eadd8ce2c Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Wed, 17 Dec 2025 17:03:05 -0500 Subject: [PATCH 08/32] Somewhat working code --- src/variant_schema.rs | 185 ++++++++++++++++++++---------------------- 1 file changed, 89 insertions(+), 96 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 0477ec4..72012e1 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,8 +1,5 @@ -use std::{ops::Deref, sync::Arc}; - -use std::{collections::BTreeMap}; -use arrow::array::AsArray; -use arrow_schema::{DataType, Field, Fields, TimeUnit}; +use arrow::{array::AsArray}; +use arrow_schema::{DataType, TimeUnit}; use datafusion::{ common::exec_err, error::{DataFusionError, Result}, @@ -10,7 +7,8 @@ use datafusion::{ scalar::ScalarValue, }; use parquet_variant::Variant; -use parquet_variant_compute::{VariantArray, VariantType}; +use parquet_variant_compute::VariantArray; +use std::collections::BTreeMap; #[derive(Debug, Hash, PartialEq, Eq)] pub struct VariantSchemaUDF { @@ -25,32 +23,35 @@ impl Default for VariantSchemaUDF { } } -/// For schema_from_variant Schema representation -/// there are 4 possible types available depending on the Variant -/// For ColumnarValue::Scalar: -/// Primitive -> Just the corresponding SQL data type based on Variant type -/// Array -> Should be an > given that type is the same, if not > -/// Object -> Each object field should keep track of its type > -/// Variant +/// Schema inference rules for VARIANT values. +/// +/// The inferred schema can be one of four logical forms: +/// - Primitive: a concrete SQL / Arrow data type +/// - Array: ARRAY, where `inner` is the merged element schema +/// - Object: OBJECT, merged field-wise by name +/// - Variant: fallback when no common schema can be determined /// -/// For ColumnarValue::Array we get the type for each individual Variant value and compare it with the rest. -/// - If one of the values is type => we can early terminate and call everything +/// ## Scalar input +/// When the input is a single VARIANT value: +/// - Primitive values map directly to their corresponding data type +/// - Arrays infer a common element schema across all elements +/// - Objects infer schemas per field recursively +/// - Mixed or incompatible types resolve to VARIANT /// -/// For different Variant types we will use different ways of keeping track of types between the rows: -/// - If the outer/inner type differs, we call it and early terminate this level. -/// - Primitive -> create a set to store Primitive types and if the set.len() > 1, we call everything -/// - Array -> array type is flat, so the same implementation as Primitive should work for this Array type -/// - Object -> is the difficult one in this case. Different rows can include or exlude certain fields, and we -/// need to keep track of each field between different rows separately. To keep it the fields sorted we will use -/// a BTree with fields as keys and set of types as values. We will add values to each field's set and if set.len() -/// \> 1, we call this field and we no longer need to keep track of it. -/// +/// ## Array input +/// When the input is an array of VARIANT values: +/// - Each element is inferred independently +/// - Schemas are merged across rows +/// - If any merge step resolves to VARIANT, inference short-circuits /// -/// Later we should also implement Databricks' coersion into a similiar type: -/// > "The schema of each VARIANT value is merged together by field name. When two fields with the same name have -/// > a different type across records, Databricks uses the least common type. When no such type exists, the type -/// > is derived as a VARIANT. For example, INT and DOUBLE become DOUBLE, while TIMESTAMP and STRING become VARIANT." \ -/// > https://docs.databricks.com/gcp/en/sql/language-manual/functions/schema_of_variant_agg +/// ## Merge rules +/// - If outer (or inner) kinds differ, the result is VARIANT +/// - Primitive types are merged using widening / least-common-type rules +/// - Arrays merge by merging their element schemas +/// - Objects merge field-by-field: +/// - Missing fields are allowed +/// - Field schemas are merged independently +/// - A field becomes VARIANT if its values are incompatible /// #[derive(Debug, PartialEq, Eq, Clone)] enum VariantSchema { @@ -61,70 +62,49 @@ enum VariantSchema { } /// This function extracts the schema from a single Variant scalar -fn schema_from_variant(v: &Variant) -> DataType { +fn schema_from_variant(v: &Variant) -> VariantSchema { match v { Variant::Object(obj) => { let fields = obj .iter() - .map(|(k, v)| Field::new(k.to_string(), schema_from_variant(&v), true)) + .map(|(k, v)| (k.to_string(), schema_from_variant(&v))) .collect(); - DataType::Struct(fields) + VariantSchema::Object(fields) } - Variant::List(list) => { let inner = list .iter() - .map(|v| Field::new("", schema_from_variant(&v), true)) - .reduce(merge_fields) - .unwrap_or( - Field::new("item", DataType::Binary, true).with_extension_type(VariantType), - ); + .map(|v| schema_from_variant(&v)) + .reduce(merge_variant_schema) + .unwrap_or(VariantSchema::Variant); - DataType::List(Arc::new(inner)) + VariantSchema::Array(Box::new(inner)) } - // primitives - _ => primitive_from_variant(v), + _ => VariantSchema::Primitive(primitive_from_variant(v)), } } +/// This helper function is used to calculate decimal precision +/// for [primitive_from_variant] decimal Variants conversion fn decimal_precision>(val: T) -> u8 { let mut n = val.into(); if n == 0 { return 1; } - if n < 0 { - n = -n - } - - let mut digits = 0; - while n != 0 { - digits += 1; - n /= 10; - } - digits + if n < 0 { + n = -n } -fn schema_to_string(schema: &VariantSchema) -> String { - match schema { - VariantSchema::Primitive(s) => format!("{s}"), - - VariantSchema::Variant => "VARIANT".to_string(), - - VariantSchema::Array(inner) => { - format!("ARRAY<{}>", schema_to_string(inner)) - } - - VariantSchema::Object(fields) => { - let parts: Vec = fields - .iter() - .map(|(k, v)| format!("{k}: {}", schema_to_string(v))) - .collect(); - format!("OBJECT<{}>", parts.join(", ")) - } + let mut digits = 0; + while n != 0 { + digits += 1; + n /= 10; } + digits } +/// This function is used to extract datatype from a primitive Variant fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { match v { Variant::Null => DataType::Null, @@ -154,11 +134,16 @@ fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { Variant::TimestampNtzMicros(_) => DataType::Timestamp(TimeUnit::Microsecond, None), Variant::TimestampNanos(_) => DataType::Timestamp(TimeUnit::Nanosecond, Some("utc".into())), Variant::TimestampNtzNanos(_) => DataType::Timestamp(TimeUnit::Nanosecond, None), - _ => unreachable!("Should be only applied to Primitive Variant"), + _ => unreachable!("Should be only applied to Primitive Variant, not Object or List"), } } -// Todo: needs more work on type coercing +/// This function is used to merge types between schemas +/// and coerce them into a common type when possible if types +/// are different +/// +/// Todo: needs more work on type coercing +/// - add decimal coercion rules fn merge_primitives(a: DataType, b: DataType) -> Option { use DataType::*; @@ -174,29 +159,17 @@ fn merge_primitives(a: DataType, b: DataType) -> Option { | (Float64, Int8 | Int16 | Int32 | Int64 | Float32) => Some(Float64), (Int8 | Int16 | Int32, Int64) | (Int64, Int8 | Int16 | Int32) => Some(Int64), (Int8 | Int16, Int32) | (Int32, Int8 | Int16) => Some(Int32), - (Date32, Timestamp(tu, tz)) | (Timestamp(tu, tz), Date32) => Some(Timestamp(tu, tz)), - // // decimal rules (simplified) - // ( - // Decimal { - // precision: p1, - // scale: s1, - // }, - // Decimal { - // precision: p2, - // scale: s2, - // }, - // ) => Some(Decimal { - // precision: p1.max(p2), - // scale: s1.max(s2), - // }), - _ => unreachable!("Not primitive types {}, {}", a,b), + _ => None, } } +/// Merges two inferred Variant schemas into a common schema. +/// Returns VARIANT if no common schema can be determined. fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { use VariantSchema::*; + match (a, b) { (Variant, _) | (_, Variant) => Variant, @@ -217,11 +190,30 @@ fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { _ => Variant, } - let merged_type = merge_datatypes(a.data_type().clone(), b.data_type().clone()); +} + +/// Prints schema in a presentable manner +fn print_schema(schema: &VariantSchema) -> String { + match schema { + VariantSchema::Primitive(s) => format!("{s}"), - Field::new(a.name(), merged_type, a.is_nullable() || b.is_nullable()) + VariantSchema::Variant => "VARIANT".to_string(), + + VariantSchema::Array(inner) => { + format!("ARRAY<{}>", print_schema(inner)) + } + + VariantSchema::Object(fields) => { + let parts: Vec = fields + .iter() + .map(|(k, v)| format!("{k}: {}", print_schema(v))) + .collect(); + format!("OBJECT<{}>", parts.join(", ")) + } + } } +/// Final function used to retrieve schema from a single Variant or VariantArray fn infer_variant_schema(variant: &ColumnarValue) -> Result { match variant { ColumnarValue::Scalar(scalar) => { @@ -231,11 +223,11 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { let variant_array = VariantArray::try_new(struct_array.as_ref())?; let v = variant_array.value(0); - let data_type = schema_from_variant(&v); + let schema = schema_from_variant(&v); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(format!( - "{data_type}" - ))))) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + print_schema(&schema), + )))) } ColumnarValue::Array(arr) => { let variant_array = @@ -245,11 +237,12 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { .iter() .flatten() .map(|v| schema_from_variant(&v)) - .reduce(merge_datatypes); + .reduce(merge_variant_schema) + .unwrap_or(VariantSchema::Variant); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(format!( - "{final_schema:?}" - ))))) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + print_schema(&final_schema), + )))) } } } From fdf831eb83bba395f5fbf0f3f747e7dc02671d66 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 22 Dec 2025 12:54:17 -0500 Subject: [PATCH 09/32] cargo fmt --- src/variant_schema.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 72012e1..cd8e619 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,4 +1,4 @@ -use arrow::{array::AsArray}; +use arrow::array::AsArray; use arrow_schema::{DataType, TimeUnit}; use datafusion::{ common::exec_err, @@ -138,10 +138,10 @@ fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { } } -/// This function is used to merge types between schemas +/// This function is used to merge types between schemas /// and coerce them into a common type when possible if types /// are different -/// +/// /// Todo: needs more work on type coercing /// - add decimal coercion rules fn merge_primitives(a: DataType, b: DataType) -> Option { From 3ee27ca62380306510d8f12ad58189fec268d56b Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Sun, 28 Dec 2025 23:34:37 -0500 Subject: [PATCH 10/32] add sqllogictests --- tests/sqllogictests.rs | 3 ++- tests/test_files/variant_schema.slt | 34 +++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 tests/test_files/variant_schema.slt diff --git a/tests/sqllogictests.rs b/tests/sqllogictests.rs index 0ae1250..93801f8 100644 --- a/tests/sqllogictests.rs +++ b/tests/sqllogictests.rs @@ -3,7 +3,7 @@ use datafusion_sqllogictest::{DataFusion, TestContext}; use datafusion_variant::{ CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetUdf, VariantListConstruct, VariantListInsert, VariantObjectConstruct, VariantObjectInsert, VariantPretty, - VariantToJsonUdf, + VariantSchemaUDF, VariantToJsonUdf, }; use indicatif::ProgressBar; use sqllogictest::strict_column_validator; @@ -54,6 +54,7 @@ async fn run_sqllogictests() -> Result<(), Box> { ctx.register_udf(ScalarUDF::new_from_impl(VariantListConstruct::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantListInsert::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectInsert::default())); + ctx.register_udf(ScalarUDF::new_from_impl(VariantSchemaUDF::default())); let pb = ProgressBar::new(24); diff --git a/tests/test_files/variant_schema.slt b/tests/test_files/variant_schema.slt new file mode 100644 index 0000000..449704e --- /dev/null +++ b/tests/test_files/variant_schema.slt @@ -0,0 +1,34 @@ +# tests the variant_schema udf +# this function takes a Scalar Variant +# or a Variant Array and extracts it's SQL schema + +# Simple example with a Scalar value +query T +SELECT variant_schema(json_to_variant('{"key": 123, "data": [4, 5]}')) +---- +OBJECT, key: Int8> + + +# Conflicting element types in array +query T +SELECT variant_schema(json_to_variant('{"data": [{"a":"a"}, 5]}')) +---- +OBJECT> + +# A typed literal +query T +SELECT variant_schema(json_to_variant(123.4)) +---- +Float64 + +# Variant Arrays +statement ok +CREATE TABLE data as VALUES +(json_to_variant('{"foo": "bar", "wing": {"ding": "dong"}}')), +(json_to_variant('{"wing": 123}')), +(json_to_variant('{"wing": 123}')); + +query T +SELECT variant_schema_agg(column1) from data; +---- +OBJECT \ No newline at end of file From 3ff9f1ab753bb55d0a9411680d36898f5a845583 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Sun, 28 Dec 2025 23:56:12 -0500 Subject: [PATCH 11/32] Splitting the function in two --- src/lib.rs | 1 + src/variant_schema.rs | 106 ++++----------------- src/variant_schema_agg.rs | 195 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 87 deletions(-) create mode 100644 src/variant_schema_agg.rs diff --git a/src/lib.rs b/src/lib.rs index 853eef0..0503ad6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ mod variant_object_construct; mod variant_object_insert; mod variant_pretty; mod variant_schema; +mod variant_schema_agg; mod variant_to_json; pub use cast_to_variant::*; diff --git a/src/variant_schema.rs b/src/variant_schema.rs index cd8e619..29e7480 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,4 +1,3 @@ -use arrow::array::AsArray; use arrow_schema::{DataType, TimeUnit}; use datafusion::{ common::exec_err, @@ -54,7 +53,7 @@ impl Default for VariantSchemaUDF { /// - A field becomes VARIANT if its values are incompatible /// #[derive(Debug, PartialEq, Eq, Clone)] -enum VariantSchema { +pub enum VariantSchema { Primitive(DataType), Array(Box), Object(BTreeMap), @@ -62,7 +61,7 @@ enum VariantSchema { } /// This function extracts the schema from a single Variant scalar -fn schema_from_variant(v: &Variant) -> VariantSchema { +pub fn schema_from_variant(v: &Variant) -> VariantSchema { match v { Variant::Object(obj) => { let fields = obj @@ -167,7 +166,7 @@ fn merge_primitives(a: DataType, b: DataType) -> Option { /// Merges two inferred Variant schemas into a common schema. /// Returns VARIANT if no common schema can be determined. -fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { +pub fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { use VariantSchema::*; match (a, b) { @@ -193,7 +192,7 @@ fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { } /// Prints schema in a presentable manner -fn print_schema(schema: &VariantSchema) -> String { +pub fn print_schema(schema: &VariantSchema) -> String { match schema { VariantSchema::Primitive(s) => format!("{s}"), @@ -213,37 +212,22 @@ fn print_schema(schema: &VariantSchema) -> String { } } -/// Final function used to retrieve schema from a single Variant or VariantArray +/// Final function used to retrieve the schema from a single Variant fn infer_variant_schema(variant: &ColumnarValue) -> Result { - match variant { - ColumnarValue::Scalar(scalar) => { - let ScalarValue::Struct(struct_array) = scalar else { - return exec_err!("Unsupported data type: {}", scalar.data_type()); - }; - let variant_array = VariantArray::try_new(struct_array.as_ref())?; - let v = variant_array.value(0); - - let schema = schema_from_variant(&v); - - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - print_schema(&schema), - )))) - } - ColumnarValue::Array(arr) => { - let variant_array = - VariantArray::try_new(arr.as_struct()).expect("Expect VariantArray"); + if let ColumnarValue::Scalar(scalar) = variant { + let ScalarValue::Struct(struct_array) = scalar else { + return exec_err!("Unsupported data type: {}", scalar.data_type()); + }; + let variant_array = VariantArray::try_new(struct_array.as_ref())?; + let v = variant_array.value(0); - let final_schema = variant_array - .iter() - .flatten() - .map(|v| schema_from_variant(&v)) - .reduce(merge_variant_schema) - .unwrap_or(VariantSchema::Variant); + let schema = schema_from_variant(&v); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - print_schema(&final_schema), - )))) - } + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + print_schema(&schema), + )))) + } else { + exec_err!("Expected a ScalarValue, got: {:?}", variant) } } @@ -261,7 +245,7 @@ impl ScalarUDFImpl for VariantSchemaUDF { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn invoke_with_args( @@ -288,10 +272,7 @@ mod tests { use parquet_variant_compute::{VariantArray, VariantType}; use std::sync::Arc; - use crate::{ - VariantSchemaUDF, - shared::{build_variant_array_from_json, build_variant_array_from_json_array}, - }; + use crate::{VariantSchemaUDF, shared::build_variant_array_from_json}; fn build_scalar_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); @@ -310,21 +291,6 @@ mod tests { } } - fn build_array_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { - let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); - let arg_field = Arc::new( - Field::new("input", DataType::Struct(Fields::empty()), true) - .with_extension_type(VariantType), - ); - ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(struct_array))], - arg_fields: vec![arg_field], - number_rows: Default::default(), - return_field, - config_options: Default::default(), - } - } - #[test] fn test_get_single_typed_null_variant_schema() { let udf = VariantSchemaUDF::default(); @@ -509,38 +475,4 @@ mod tests { }; assert_eq!(schema, "OBJECT>") } - - #[test] - fn test_get_array_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant_array = build_variant_array_from_json_array(&[ - Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), - None, - Some(serde_json::json!({"wing": {"ding": "man"}})), - ]); - let struct_array = variant_array.into_inner(); - let args = build_array_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "OBJECT>") - } - - #[test] - fn test_get_array_variant_conflicting_schema() { - let udf = VariantSchemaUDF::default(); - let variant_array = build_variant_array_from_json_array(&[ - Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), - // None, - Some(serde_json::json!({"wing": 123})), - ]); - let struct_array = variant_array.into_inner(); - let args = build_array_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "OBJECT") - } } diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs new file mode 100644 index 0000000..e834211 --- /dev/null +++ b/src/variant_schema_agg.rs @@ -0,0 +1,195 @@ +use arrow::array::AsArray; +use arrow_schema::DataType; +use datafusion::{ + error::Result, + logical_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, function::AccumulatorArgs}, + scalar::ScalarValue, +}; +use parquet_variant_compute::VariantArray; + +use crate::{VariantSchema, merge_variant_schema, print_schema, schema_from_variant}; + +#[derive(Debug, Hash, PartialEq, Eq)] +struct VariantSchemaAggUDAF { + signature: Signature, +} + +impl Default for VariantSchemaAggUDAF { + fn default() -> Self { + Self { + signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for VariantSchemaAggUDAF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "variant_schema_agg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8View) + } + + fn accumulator( + &self, + acc_args: datafusion::logical_expr::function::AccumulatorArgs, + ) -> Result> { + Ok(Box::new(VariantSchemaAccumulator::new(acc_args))) + } +} + +#[derive(Debug)] +/// An accumulator to compute and store merged VariantSchame +pub struct VariantSchemaAccumulator { + schema: VariantSchema, // This will store the current inferred schema +} + +impl VariantSchemaAccumulator { + fn new(_acc_args: AccumulatorArgs) -> Self { + // Initialize with Variant as the starting schema + Self { + schema: VariantSchema::Primitive(DataType::Null), + } + } +} + +impl Accumulator for VariantSchemaAccumulator { + fn state(&mut self) -> Result> { + // Return the current state (the inferred schema) + Ok(vec![ScalarValue::Utf8View(Some(print_schema( + &self.schema, + )))]) + } + + fn evaluate(&mut self) -> Result { + // Return the schema as a Utf8 representation + Ok(ScalarValue::Utf8View(Some(print_schema(&self.schema)))) + } + + fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> Result<()> { + // We're assuming the input is an array of variants + for value in values { + // Ensure we are dealing with VariantArray and extract the variant values + let variant_array = VariantArray::try_new(value.as_struct())?; + for variant in variant_array.iter().flatten() { + let new_schema = schema_from_variant(&variant); + // Merge the new schema with the current schema + self.schema = merge_variant_schema(self.schema.clone(), new_schema); + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> Result<()> { + // Merge schemas from other states (batches) + for state in states { + let variant_array = VariantArray::try_new(state.as_struct())?; + for variant in variant_array.iter().flatten() { + let new_schema = schema_from_variant(&variant); + self.schema = merge_variant_schema(self.schema.clone(), new_schema); + } + } + Ok(()) + } + + fn size(&self) -> usize { + // The size is essentially the number of variants processed, if needed + 1 // This could be expanded to return a more useful size + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::{ + logical_expr::{ + Accumulator, + function::AccumulatorArgs, + }, + scalar::ScalarValue, + }; + use parquet_variant_compute::VariantType; + + use crate::{ + shared::build_variant_array_from_json_array, variant_schema_agg::VariantSchemaAccumulator, + }; + + #[test] + fn test_get_agg_variant_schema() { + let b = build_variant_array_from_json_array(&[ + Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), + None, + Some(serde_json::json!({"wing": {"ding": "man"}})), + ]); + let b: ArrayRef = Arc::new(b.into_inner()); + + let schema = Schema::new(vec![Field::new("b", DataType::Struct(_), true).with_extension_type(VariantType)]); + + let acc_args = AccumulatorArgs { + return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), + schema: &Schema::empty(), + ignore_nulls: false, + order_bys:, + is_reversed: false, + name: "variant_schema_agg", + is_distinct: false, + exprs:, + }; + + let mut schema = VariantSchemaAccumulator::new(acc_args); + schema.update_batch(&[Arc::clone(&b)]).unwrap(); + let final_schema = schema.evaluate().unwrap(); + assert_eq!( + final_schema, + ScalarValue::Utf8View(Some( + "OBJECT>".to_string() + )) + ) + } + + // #[test] + // fn test_get_array_variant_schema() { + // let udaf = VariantSchemaAggUDAF::default(); + // let variant_array = build_variant_array_from_json_array(&[ + // Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), + // None, + // Some(serde_json::json!({"wing": {"ding": "man"}})), + // ]); + // let struct_array = variant_array.into_inner(); + // let args = build_array_udf_args(struct_array); + // let result = udaf.accumulator(acc_args) + // let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + // panic!() + // }; + // assert_eq!(schema, "OBJECT>") + // } + + // #[test] + // fn test_get_array_variant_conflicting_schema() { + // let udf = VariantSchemaAggUDAF::default(); + // let variant_array = build_variant_array_from_json_array(&[ + // Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), + // // None, + // Some(serde_json::json!({"wing": 123})), + // ]); + // let struct_array = variant_array.into_inner(); + // let args = build_array_udf_args(struct_array); + // let result = udf.invoke_with_args(args).unwrap(); + // let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { + // panic!() + // }; + // assert_eq!(schema, "OBJECT") + // } +} From 539fe1bada75a1b4b969cc007b480a0075918c47 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 29 Dec 2025 13:13:08 -0500 Subject: [PATCH 12/32] Split functions work + sqllogictests --- src/lib.rs | 1 + src/variant_schema.rs | 7 +- src/variant_schema_agg.rs | 119 ++++++++++++++---------- tests/sqllogictests.rs | 10 +- tests/test_files/variant_schema.slt | 62 +++++++++--- tests/test_files/variant_schema_agg.slt | 114 +++++++++++++++++++++++ 6 files changed, 243 insertions(+), 70 deletions(-) create mode 100644 tests/test_files/variant_schema_agg.slt diff --git a/src/lib.rs b/src/lib.rs index 0503ad6..840523f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,4 +25,5 @@ pub use variant_object_construct::*; pub use variant_object_insert::*; pub use variant_pretty::*; pub use variant_schema::*; +pub use variant_schema_agg::*; pub use variant_to_json::*; diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 29e7480..804d390 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -76,7 +76,7 @@ pub fn schema_from_variant(v: &Variant) -> VariantSchema { .iter() .map(|v| schema_from_variant(&v)) .reduce(merge_variant_schema) - .unwrap_or(VariantSchema::Variant); + .unwrap_or(VariantSchema::Primitive(DataType::Null)); VariantSchema::Array(Box::new(inner)) } @@ -147,9 +147,6 @@ fn merge_primitives(a: DataType, b: DataType) -> Option { use DataType::*; match (a, b) { - // null handling - (Null, x) | (x, Null) => Some(x), - // normal case (x, y) if x == y => Some(x), // numeric widening // docs.databricks.com/aws/en/sql/language-manual/sql-ref-datatype-rules#type-precedence-list @@ -172,6 +169,8 @@ pub fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema match (a, b) { (Variant, _) | (_, Variant) => Variant, + (Primitive(DataType::Null), x) | (x, Primitive(DataType::Null)) => x, + (Primitive(p1), Primitive(p2)) => { merge_primitives(p1, p2).map(Primitive).unwrap_or(Variant) } diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs index e834211..cd320fb 100644 --- a/src/variant_schema_agg.rs +++ b/src/variant_schema_agg.rs @@ -2,7 +2,10 @@ use arrow::array::AsArray; use arrow_schema::DataType; use datafusion::{ error::Result, - logical_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, function::AccumulatorArgs}, + logical_expr::{ + Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, + function::AccumulatorArgs, + }, scalar::ScalarValue, }; use parquet_variant_compute::VariantArray; @@ -10,7 +13,7 @@ use parquet_variant_compute::VariantArray; use crate::{VariantSchema, merge_variant_schema, print_schema, schema_from_variant}; #[derive(Debug, Hash, PartialEq, Eq)] -struct VariantSchemaAggUDAF { +pub struct VariantSchemaAggUDAF { signature: Signature, } @@ -48,7 +51,7 @@ impl AggregateUDFImpl for VariantSchemaAggUDAF { } #[derive(Debug)] -/// An accumulator to compute and store merged VariantSchame +/// An accumulator to compute and store merged VariantSchema pub struct VariantSchemaAccumulator { schema: VariantSchema, // This will store the current inferred schema } @@ -111,13 +114,12 @@ impl Accumulator for VariantSchemaAccumulator { mod test { use std::sync::Arc; - use arrow::array::{ArrayRef}; - use arrow_schema::{DataType, Field, Schema}; + use arrow::array::ArrayRef; + use arrow_schema::{DataType, Field, Fields, Schema}; use datafusion::{ - logical_expr::{ - Accumulator, - function::AccumulatorArgs, - }, + logical_expr::{Accumulator, function::AccumulatorArgs}, + physical_expr::PhysicalSortExpr, + physical_plan::expressions::col, scalar::ScalarValue, }; use parquet_variant_compute::VariantType; @@ -130,27 +132,36 @@ mod test { fn test_get_agg_variant_schema() { let b = build_variant_array_from_json_array(&[ Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), - None, Some(serde_json::json!({"wing": {"ding": "man"}})), ]); let b: ArrayRef = Arc::new(b.into_inner()); - let schema = Schema::new(vec![Field::new("b", DataType::Struct(_), true).with_extension_type(VariantType)]); + let schema = Schema::new(vec![ + Field::new( + "b", + DataType::Struct(Fields::from(vec![ + Field::new("metadata", DataType::Binary, true), + Field::new("value", DataType::Binary, true), + ])), + true, + ) + .with_extension_type(VariantType), + ]); let acc_args = AccumulatorArgs { return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), - schema: &Schema::empty(), + schema: &schema, ignore_nulls: false, - order_bys:, + order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], is_reversed: false, name: "variant_schema_agg", is_distinct: false, - exprs:, + exprs: &[col("b", &schema).unwrap()], }; - let mut schema = VariantSchemaAccumulator::new(acc_args); - schema.update_batch(&[Arc::clone(&b)]).unwrap(); - let final_schema = schema.evaluate().unwrap(); + let mut variant_schema = VariantSchemaAccumulator::new(acc_args); + variant_schema.update_batch(&[Arc::clone(&b)]).unwrap(); + let final_schema = variant_schema.evaluate().unwrap(); assert_eq!( final_schema, ScalarValue::Utf8View(Some( @@ -159,37 +170,45 @@ mod test { ) } - // #[test] - // fn test_get_array_variant_schema() { - // let udaf = VariantSchemaAggUDAF::default(); - // let variant_array = build_variant_array_from_json_array(&[ - // Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), - // None, - // Some(serde_json::json!({"wing": {"ding": "man"}})), - // ]); - // let struct_array = variant_array.into_inner(); - // let args = build_array_udf_args(struct_array); - // let result = udaf.accumulator(acc_args) - // let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - // panic!() - // }; - // assert_eq!(schema, "OBJECT>") - // } - - // #[test] - // fn test_get_array_variant_conflicting_schema() { - // let udf = VariantSchemaAggUDAF::default(); - // let variant_array = build_variant_array_from_json_array(&[ - // Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), - // // None, - // Some(serde_json::json!({"wing": 123})), - // ]); - // let struct_array = variant_array.into_inner(); - // let args = build_array_udf_args(struct_array); - // let result = udf.invoke_with_args(args).unwrap(); - // let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - // panic!() - // }; - // assert_eq!(schema, "OBJECT") - // } + #[test] + fn test_get_array_variant_conflicting_schema() { + let b = build_variant_array_from_json_array(&[ + Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), + Some(serde_json::json!({"wing": 123})), + ]); + let b: ArrayRef = Arc::new(b.into_inner()); + + let schema = Schema::new(vec![ + Field::new( + "b", + DataType::Struct(Fields::from(vec![ + Field::new("metadata", DataType::Binary, true), + Field::new("value", DataType::Binary, true), + ])), + true, + ) + .with_extension_type(VariantType), + ]); + + let acc_args = AccumulatorArgs { + return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], + is_reversed: false, + name: "variant_schema_agg", + is_distinct: false, + exprs: &[col("b", &schema).unwrap()], + }; + + let mut variant_schema = VariantSchemaAccumulator::new(acc_args); + variant_schema.update_batch(&[Arc::clone(&b)]).unwrap(); + let final_schema = variant_schema.evaluate().unwrap(); + assert_eq!( + final_schema, + ScalarValue::Utf8View(Some( + "OBJECT".to_string() + )) + ) + } } diff --git a/tests/sqllogictests.rs b/tests/sqllogictests.rs index 93801f8..3316b5a 100644 --- a/tests/sqllogictests.rs +++ b/tests/sqllogictests.rs @@ -1,9 +1,12 @@ -use datafusion::{logical_expr::ScalarUDF, prelude::*}; +use datafusion::{ + logical_expr::{AggregateUDF, ScalarUDF}, + prelude::*, +}; use datafusion_sqllogictest::{DataFusion, TestContext}; use datafusion_variant::{ CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetUdf, VariantListConstruct, VariantListInsert, VariantObjectConstruct, VariantObjectInsert, VariantPretty, - VariantSchemaUDF, VariantToJsonUdf, + VariantSchemaAggUDAF, VariantSchemaUDF, VariantToJsonUdf, }; use indicatif::ProgressBar; use sqllogictest::strict_column_validator; @@ -30,7 +33,7 @@ async fn run_sqllogictests() -> Result<(), Box> { test_files.sort(); for test_file in test_files { - println!("Running test file: {:?}", test_file); + println!("Running test file: {test_file:?}"); let relative_path = test_file .strip_prefix(&test_files_dir) @@ -55,6 +58,7 @@ async fn run_sqllogictests() -> Result<(), Box> { ctx.register_udf(ScalarUDF::new_from_impl(VariantListInsert::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectInsert::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantSchemaUDF::default())); + ctx.register_udaf(AggregateUDF::new_from_impl(VariantSchemaAggUDAF::default())); let pb = ProgressBar::new(24); diff --git a/tests/test_files/variant_schema.slt b/tests/test_files/variant_schema.slt index 449704e..2a9437f 100644 --- a/tests/test_files/variant_schema.slt +++ b/tests/test_files/variant_schema.slt @@ -1,34 +1,70 @@ # tests the variant_schema udf # this function takes a Scalar Variant -# or a Variant Array and extracts it's SQL schema +# and extracts it's SQL schema -# Simple example with a Scalar value +# simple example with a scalar value query T SELECT variant_schema(json_to_variant('{"key": 123, "data": [4, 5]}')) ---- OBJECT, key: Int8> -# Conflicting element types in array +# conflicting element types in array query T SELECT variant_schema(json_to_variant('{"data": [{"a":"a"}, 5]}')) ---- OBJECT> -# A typed literal +# typed literal query T SELECT variant_schema(json_to_variant(123.4)) ---- Float64 -# Variant Arrays -statement ok -CREATE TABLE data as VALUES -(json_to_variant('{"foo": "bar", "wing": {"ding": "dong"}}')), -(json_to_variant('{"wing": 123}')), -(json_to_variant('{"wing": 123}')); +# explicit null +query T +SELECT variant_schema(json_to_variant('null')) +---- +Null -query T -SELECT variant_schema_agg(column1) from data; +# json null +query T +SELECT variant_schema(json_to_variant('{"a": null}')) +---- +OBJECT + +# numeric widening +query T +SELECT variant_schema(json_to_variant('[1, 2.5, 3]')) +---- +ARRAY + +# array of objects +query T +SELECT variant_schema(json_to_variant('[{"a":1},{"a":2}]')) +---- +ARRAY> + +# empty object +query T +SELECT variant_schema(json_to_variant('{}')) +---- +OBJECT<> + +# empty array +query T +SELECT variant_schema(json_to_variant('[]')) +---- +ARRAY + +# field ordering +query T +SELECT variant_schema(json_to_variant('{"b":1,"a":2}')) +---- +OBJECT + +# last key wins? +query T +SELECT variant_schema(json_to_variant('{"a": 1, "a": {"b":2}}')) ---- -OBJECT \ No newline at end of file +OBJECT> \ No newline at end of file diff --git a/tests/test_files/variant_schema_agg.slt b/tests/test_files/variant_schema_agg.slt new file mode 100644 index 0000000..4677a97 --- /dev/null +++ b/tests/test_files/variant_schema_agg.slt @@ -0,0 +1,114 @@ +# tests the variant_schema_agg udaf +# this function takes a Variant Array +# and extracts it's SQL schema + +# same schema +statement ok +CREATE TABLE t as VALUES +(json_to_variant('{"foo": "bar", "wing": {"ding": "dong"}}')), +(json_to_variant('{"wing": {"ding": "man"}}')); + +query T +SELECT variant_schema_agg(column1) from t; +---- +OBJECT> + +# conflicting schema +statement ok +CREATE TABLE t_conflicting as VALUES +(json_to_variant('{"foo": "bar", "wing": {"ding": "dong"}}')), +(json_to_variant('{"wing": 123}')); + +query T +SELECT variant_schema_agg(column1) from t_conflicting; +---- +OBJECT + +# null row +statement ok +CREATE TABLE t_nulls AS VALUES +(json_to_variant('{"a": 1}')), +(json_to_variant('null')), +(json_to_variant('{"a": 2}')); + +query T +SELECT variant_schema_agg(column1) FROM t_nulls; +---- +OBJECT + +# numeric widening +statement ok +CREATE TABLE t_nums AS VALUES +(json_to_variant('{"a": 1}')), +(json_to_variant('{"a": 2.5}')); + +query T +SELECT variant_schema_agg(column1) FROM t_nums; +---- +OBJECT + +# field appears later +statement ok +CREATE TABLE t_sparse AS VALUES +(json_to_variant('{}')), +(json_to_variant('{"a": 1}')); + +query T +SELECT variant_schema_agg(column1) FROM t_sparse; +---- +OBJECT + +# conflicting array of objects +statement ok +CREATE TABLE t_arr_objs AS VALUES +(json_to_variant('[{"a":1}]')), +(json_to_variant('[{"a":"x"}]')); + +query T +SELECT variant_schema_agg(column1) FROM t_arr_objs; +---- +ARRAY> + +# empty aggregates +statement ok +CREATE TABLE t_empty AS VALUES +(json_to_variant('{}')), +(json_to_variant('{}')); + +query T +SELECT variant_schema_agg(column1) FROM t_empty; +---- +OBJECT<> + +# field ordering +statement ok +CREATE TABLE t_order AS VALUES +(json_to_variant('{"b":1}')), +(json_to_variant('{"a":2}')); + +query T +SELECT variant_schema_agg(column1) FROM t_order; +---- +OBJECT + +# root conflict +statement ok +CREATE TABLE t_root_conflict AS VALUES +(json_to_variant('{"a":1}')), +(json_to_variant('[1,2,3]')); + +query T +SELECT variant_schema_agg(column1) FROM t_root_conflict; +---- +VARIANT + +# mixed root +statement ok +CREATE TABLE t_mixed AS VALUES +(json_to_variant('1')), +(json_to_variant('{"a": 1}')); + +query T +SELECT variant_schema_agg(column1) FROM t_mixed; +---- +VARIANT \ No newline at end of file From 97b0ddb23936323af09d17293e571f9a83c06068 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 29 Dec 2025 13:13:44 -0500 Subject: [PATCH 13/32] cargo fmt --- src/variant_schema_agg.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs index cd320fb..1bb9540 100644 --- a/src/variant_schema_agg.rs +++ b/src/variant_schema_agg.rs @@ -206,9 +206,7 @@ mod test { let final_schema = variant_schema.evaluate().unwrap(); assert_eq!( final_schema, - ScalarValue::Utf8View(Some( - "OBJECT".to_string() - )) + ScalarValue::Utf8View(Some("OBJECT".to_string())) ) } } From a78b9a6fa75aebc9a3c2fd56c38e0e2c1d493cd1 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 23 Feb 2026 22:29:56 -0500 Subject: [PATCH 14/32] Redo variant_schema func --- src/variant_schema.rs | 126 +++++++++++++++++++++++++++++------------- 1 file changed, 87 insertions(+), 39 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 804d390..6d75607 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -1,3 +1,4 @@ +use arrow::array::{ArrayRef, StringViewArray}; use arrow_schema::{DataType, TimeUnit}; use datafusion::{ common::exec_err, @@ -8,6 +9,7 @@ use datafusion::{ use parquet_variant::Variant; use parquet_variant_compute::VariantArray; use std::collections::BTreeMap; +use std::sync::Arc; #[derive(Debug, Hash, PartialEq, Eq)] pub struct VariantSchemaUDF { @@ -17,41 +19,30 @@ pub struct VariantSchemaUDF { impl Default for VariantSchemaUDF { fn default() -> Self { Self { - signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable), + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), } } } -/// Schema inference rules for VARIANT values. +/// Infers a schema description for one VARIANT value. /// /// The inferred schema can be one of four logical forms: /// - Primitive: a concrete SQL / Arrow data type -/// - Array: ARRAY, where `inner` is the merged element schema -/// - Object: OBJECT, merged field-wise by name -/// - Variant: fallback when no common schema can be determined +/// - Array: `ARRAY`, where `inner` is merged across elements in that array value +/// - Object: `OBJECT`, merged recursively per field +/// - Variant: fallback when no common inner schema can be determined /// -/// ## Scalar input -/// When the input is a single VARIANT value: -/// - Primitive values map directly to their corresponding data type -/// - Arrays infer a common element schema across all elements -/// - Objects infer schemas per field recursively -/// - Mixed or incompatible types resolve to VARIANT +/// Execution semantics: +/// - Scalar input: infer one schema string for that value. +/// - Columnar input: infer one schema string per row (vectorized row-wise behavior). +/// - This function does not merge schemas across rows. For cross-row/group merge use +/// `variant_schema_agg`. /// -/// ## Array input -/// When the input is an array of VARIANT values: -/// - Each element is inferred independently -/// - Schemas are merged across rows -/// - If any merge step resolves to VARIANT, inference short-circuits -/// -/// ## Merge rules -/// - If outer (or inner) kinds differ, the result is VARIANT +/// Merge rules (within one VARIANT value only): +/// - If outer (or inner) kinds differ, the result is `VARIANT` /// - Primitive types are merged using widening / least-common-type rules /// - Arrays merge by merging their element schemas -/// - Objects merge field-by-field: -/// - Missing fields are allowed -/// - Field schemas are merged independently -/// - A field becomes VARIANT if its values are incompatible -/// +/// - Objects merge field-by-field; missing fields are allowed #[derive(Debug, PartialEq, Eq, Clone)] pub enum VariantSchema { Primitive(DataType), @@ -211,22 +202,32 @@ pub fn print_schema(schema: &VariantSchema) -> String { } } -/// Final function used to retrieve the schema from a single Variant +/// Retrieve schema text from a VARIANT scalar or array (row-wise for arrays). fn infer_variant_schema(variant: &ColumnarValue) -> Result { - if let ColumnarValue::Scalar(scalar) = variant { - let ScalarValue::Struct(struct_array) = scalar else { - return exec_err!("Unsupported data type: {}", scalar.data_type()); - }; - let variant_array = VariantArray::try_new(struct_array.as_ref())?; - let v = variant_array.value(0); - - let schema = schema_from_variant(&v); + match variant { + ColumnarValue::Scalar(scalar) => { + let ScalarValue::Struct(struct_array) = scalar else { + return exec_err!("Unsupported data type: {}", scalar.data_type()); + }; + + let variant_array = VariantArray::try_new(struct_array.as_ref())?; + let v = variant_array.value(0); + let schema = schema_from_variant(&v); + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + print_schema(&schema), + )))) + } + ColumnarValue::Array(array) => { + let variant_array = VariantArray::try_new(array.as_ref())?; + let out = variant_array + .iter() + .map(|v| v.map(|v| print_schema(&schema_from_variant(&v)))) + .collect::>(); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - print_schema(&schema), - )))) - } else { - exec_err!("Expected a ScalarValue, got: {:?}", variant) + let out: StringViewArray = out.into(); + Ok(ColumnarValue::Array(Arc::new(out) as ArrayRef)) + } } } @@ -260,6 +261,7 @@ impl ScalarUDFImpl for VariantSchemaUDF { #[cfg(test)] mod tests { + use arrow::array::StringViewArray; use arrow::array::StructArray; use arrow_schema::{DataType, Field, Fields}; use chrono::{DateTime, NaiveDate, NaiveTime}; @@ -271,7 +273,10 @@ mod tests { use parquet_variant_compute::{VariantArray, VariantType}; use std::sync::Arc; - use crate::{VariantSchemaUDF, shared::build_variant_array_from_json}; + use crate::{ + VariantSchemaUDF, + shared::{build_variant_array_from_json, build_variant_array_from_json_array}, + }; fn build_scalar_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); @@ -290,6 +295,21 @@ mod tests { } } + fn build_array_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { + let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); + let arg_field = Arc::new( + Field::new("input", DataType::Struct(Fields::empty()), true) + .with_extension_type(VariantType), + ); + ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(struct_array))], + arg_fields: vec![arg_field], + number_rows: Default::default(), + return_field, + config_options: Default::default(), + } + } + #[test] fn test_get_single_typed_null_variant_schema() { let udf = VariantSchemaUDF::default(); @@ -474,4 +494,32 @@ mod tests { }; assert_eq!(schema, "OBJECT>") } + + #[test] + fn test_get_columnar_variant_schema() { + let udf = VariantSchemaUDF::default(); + let variant_array = build_variant_array_from_json_array(&[ + Some(serde_json::json!({"a": 1})), + Some(serde_json::json!([1, 2, 3])), + ]); + let struct_array = variant_array.into_inner(); + let args = build_array_udf_args(struct_array); + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Array(array) = result else { + panic!("expected array output") + }; + + let strings = array + .as_any() + .downcast_ref::() + .expect("expected Utf8View array") + .iter() + .collect::>(); + + assert_eq!( + strings, + vec![Some("OBJECT"), Some("ARRAY")] + ); + } } From 6d461acbc0ebb34b2e59ef4dc61d6511ce8223c5 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 23 Feb 2026 22:30:26 -0500 Subject: [PATCH 15/32] Add variant_schema array test --- tests/test_files/variant_schema.slt | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_files/variant_schema.slt b/tests/test_files/variant_schema.slt index 2a9437f..380c826 100644 --- a/tests/test_files/variant_schema.slt +++ b/tests/test_files/variant_schema.slt @@ -1,6 +1,6 @@ # tests the variant_schema udf -# this function takes a Scalar Variant -# and extracts it's SQL schema +# this function takes a VARIANT expression +# and extracts each row's SQL schema # simple example with a scalar value query T @@ -8,6 +8,18 @@ SELECT variant_schema(json_to_variant('{"key": 123, "data": [4, 5]}')) ---- OBJECT, key: Int8> +# column input (row-wise, non-aggregate) +statement ok +CREATE TABLE t_col AS VALUES +(json_to_variant('{"a": 1}')), +(json_to_variant('[1, 2, 3]')); + +query T +SELECT variant_schema(column1) FROM t_col ORDER BY 1; +---- +ARRAY +OBJECT + # conflicting element types in array query T @@ -67,4 +79,4 @@ OBJECT query T SELECT variant_schema(json_to_variant('{"a": 1, "a": {"b":2}}')) ---- -OBJECT> \ No newline at end of file +OBJECT> From b5c3ec2a9192f200f8f0ea0b4db17c7e5bb91c13 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 23 Feb 2026 22:48:02 -0500 Subject: [PATCH 16/32] Encode state VariantSchema to bytes state for AUDF --- src/variant_schema.rs | 105 +++++++++++++++++++++++++++++++ src/variant_schema_agg.rs | 129 +++++++++++++++++++++++++++++++++----- 2 files changed, 219 insertions(+), 15 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 6d75607..332582b 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -51,6 +51,111 @@ pub enum VariantSchema { Variant, } +impl VariantSchema { + pub fn to_state_bytes(&self) -> Vec { + let mut out = Vec::new(); + encode_variant_schema(self, &mut out); + out + } + + pub fn from_state_bytes(bytes: &[u8]) -> Result { + let mut offset = 0usize; + let decoded = decode_variant_schema(bytes, &mut offset)?; + if offset != bytes.len() { + return exec_err!("invalid variant_schema_agg state: trailing bytes"); + } + Ok(decoded) + } +} + +fn encode_len_prefixed_bytes(out: &mut Vec, bytes: &[u8]) { + out.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + out.extend_from_slice(bytes); +} + +fn read_u8(input: &[u8], offset: &mut usize) -> Result { + let Some(v) = input.get(*offset) else { + return exec_err!("invalid variant_schema_agg state: missing tag"); + }; + *offset += 1; + Ok(*v) +} + +fn read_u32(input: &[u8], offset: &mut usize) -> Result { + let Some(raw) = input.get(*offset..(*offset + 4)) else { + return exec_err!("invalid variant_schema_agg state: missing u32"); + }; + *offset += 4; + Ok(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]])) +} + +fn read_len_prefixed_bytes<'a>(input: &'a [u8], offset: &mut usize) -> Result<&'a [u8]> { + let len = read_u32(input, offset)? as usize; + let Some(raw) = input.get(*offset..(*offset + len)) else { + return exec_err!("invalid variant_schema_agg state: truncated payload"); + }; + *offset += len; + Ok(raw) +} + +fn encode_variant_schema(schema: &VariantSchema, out: &mut Vec) { + match schema { + VariantSchema::Primitive(dtype) => { + out.push(0); + encode_len_prefixed_bytes(out, dtype.to_string().as_bytes()); + } + VariantSchema::Array(inner) => { + out.push(1); + encode_variant_schema(inner, out); + } + VariantSchema::Object(fields) => { + out.push(2); + out.extend_from_slice(&(fields.len() as u32).to_le_bytes()); + for (key, value) in fields { + encode_len_prefixed_bytes(out, key.as_bytes()); + encode_variant_schema(value, out); + } + } + VariantSchema::Variant => out.push(3), + } +} + +fn decode_variant_schema(input: &[u8], offset: &mut usize) -> Result { + match read_u8(input, offset)? { + 0 => { + let raw = read_len_prefixed_bytes(input, offset)?; + let dtype_str = match std::str::from_utf8(raw) { + Ok(v) => v, + Err(e) => return exec_err!("invalid variant_schema_agg state: {e}"), + }; + let dtype = match dtype_str.parse::() { + Ok(v) => v, + Err(e) => return exec_err!("invalid variant_schema_agg datatype state: {e}"), + }; + Ok(VariantSchema::Primitive(dtype)) + } + 1 => Ok(VariantSchema::Array(Box::new(decode_variant_schema( + input, offset, + )?))), + 2 => { + let count = read_u32(input, offset)? as usize; + let mut fields = BTreeMap::new(); + for _ in 0..count { + let key_raw = read_len_prefixed_bytes(input, offset)?; + let key = match std::str::from_utf8(key_raw) { + Ok(v) => v.to_string(), + Err(e) => return exec_err!("invalid variant_schema_agg field key: {e}"), + }; + let value = decode_variant_schema(input, offset)?; + fields.insert(key, value); + } + Ok(VariantSchema::Object(fields)) + } + 3 => Ok(VariantSchema::Variant), + tag => exec_err!("invalid variant_schema_agg state tag: {tag}"), + } +} + /// This function extracts the schema from a single Variant scalar pub fn schema_from_variant(v: &Variant) -> VariantSchema { match v { diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs index 1bb9540..72e581e 100644 --- a/src/variant_schema_agg.rs +++ b/src/variant_schema_agg.rs @@ -1,17 +1,33 @@ use arrow::array::AsArray; -use arrow_schema::DataType; +use arrow_schema::{DataType, Field, FieldRef}; use datafusion::{ error::Result, logical_expr::{ Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, - function::AccumulatorArgs, + function::{AccumulatorArgs, StateFieldsArgs}, + utils::format_state_name, }, scalar::ScalarValue, }; use parquet_variant_compute::VariantArray; +use std::sync::Arc; -use crate::{VariantSchema, merge_variant_schema, print_schema, schema_from_variant}; +use crate::{ + VariantSchema, merge_variant_schema, print_schema, schema_from_variant, + shared::try_parse_binary_columnar, +}; +/// Aggregate schema inference for VARIANT values across rows. +/// +/// This function infers per-row schemas using `schema_from_variant` and merges +/// them into a single schema per group. +/// +/// Semantics: +/// - Input: one VARIANT expression +/// - Output: one schema string per aggregate group +/// - Row filtering should be done via SQL `FILTER (WHERE ...)` +/// +/// Use `variant_schema` for row-wise (non-aggregate) inference. #[derive(Debug, Hash, PartialEq, Eq)] pub struct VariantSchemaAggUDAF { signature: Signature, @@ -20,7 +36,7 @@ pub struct VariantSchemaAggUDAF { impl Default for VariantSchemaAggUDAF { fn default() -> Self { Self { - signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable), + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), } } } @@ -42,6 +58,16 @@ impl AggregateUDFImpl for VariantSchemaAggUDAF { Ok(DataType::Utf8View) } + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![Arc::new(Field::new( + format_state_name(args.name, "variant_schema"), + DataType::Binary, + true, + ))]; + + Ok(fields.into_iter().chain(args.ordering_fields.to_vec()).collect()) + } + fn accumulator( &self, acc_args: datafusion::logical_expr::function::AccumulatorArgs, @@ -50,15 +76,14 @@ impl AggregateUDFImpl for VariantSchemaAggUDAF { } } +/// Accumulator state for `variant_schema_agg`. #[derive(Debug)] -/// An accumulator to compute and store merged VariantSchema pub struct VariantSchemaAccumulator { - schema: VariantSchema, // This will store the current inferred schema + schema: VariantSchema, } impl VariantSchemaAccumulator { fn new(_acc_args: AccumulatorArgs) -> Self { - // Initialize with Variant as the starting schema Self { schema: VariantSchema::Primitive(DataType::Null), } @@ -67,10 +92,7 @@ impl VariantSchemaAccumulator { impl Accumulator for VariantSchemaAccumulator { fn state(&mut self) -> Result> { - // Return the current state (the inferred schema) - Ok(vec![ScalarValue::Utf8View(Some(print_schema( - &self.schema, - )))]) + Ok(vec![ScalarValue::Binary(Some(self.schema.to_state_bytes()))]) } fn evaluate(&mut self) -> Result { @@ -93,11 +115,9 @@ impl Accumulator for VariantSchemaAccumulator { } fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> Result<()> { - // Merge schemas from other states (batches) for state in states { - let variant_array = VariantArray::try_new(state.as_struct())?; - for variant in variant_array.iter().flatten() { - let new_schema = schema_from_variant(&variant); + for encoded_state in try_parse_binary_columnar(state)?.into_iter().flatten() { + let new_schema = VariantSchema::from_state_bytes(encoded_state)?; self.schema = merge_variant_schema(self.schema.clone(), new_schema); } } @@ -209,4 +229,83 @@ mod test { ScalarValue::Utf8View(Some("OBJECT".to_string())) ) } + + #[test] + fn test_merge_batch_from_state_roundtrip() { + let schema = Schema::new(vec![ + Field::new( + "b", + DataType::Struct(Fields::from(vec![ + Field::new("metadata", DataType::Binary, true), + Field::new("value", DataType::Binary, true), + ])), + true, + ) + .with_extension_type(VariantType), + ]); + + let b1 = build_variant_array_from_json_array(&[Some(serde_json::json!({"a": 1}))]); + let b1: ArrayRef = Arc::new(b1.into_inner()); + + let b2 = build_variant_array_from_json_array(&[Some(serde_json::json!({"a": 2.5}))]); + let b2: ArrayRef = Arc::new(b2.into_inner()); + + let acc1_args = AccumulatorArgs { + return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], + is_reversed: false, + name: "variant_schema_agg", + is_distinct: false, + exprs: &[col("b", &schema).unwrap()], + }; + let acc2_args = AccumulatorArgs { + return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], + is_reversed: false, + name: "variant_schema_agg", + is_distinct: false, + exprs: &[col("b", &schema).unwrap()], + }; + let merged_args = AccumulatorArgs { + return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], + is_reversed: false, + name: "variant_schema_agg", + is_distinct: false, + exprs: &[col("b", &schema).unwrap()], + }; + + let mut acc1 = VariantSchemaAccumulator::new(acc1_args); + acc1.update_batch(&[Arc::clone(&b1)]).unwrap(); + let state_1 = acc1 + .state() + .unwrap() + .into_iter() + .map(|s| s.to_array().unwrap()) + .collect::>(); + + let mut acc2 = VariantSchemaAccumulator::new(acc2_args); + acc2.update_batch(&[Arc::clone(&b2)]).unwrap(); + let state_2 = acc2 + .state() + .unwrap() + .into_iter() + .map(|s| s.to_array().unwrap()) + .collect::>(); + + let mut merged = VariantSchemaAccumulator::new(merged_args); + merged.merge_batch(&state_1).unwrap(); + merged.merge_batch(&state_2).unwrap(); + + assert_eq!( + merged.evaluate().unwrap(), + ScalarValue::Utf8View(Some("OBJECT".to_string())) + ); + } } From bdf59b0c6975817c6a1f58642b015e4b0ebbc314 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Wed, 25 Feb 2026 11:52:03 -0500 Subject: [PATCH 17/32] agg function quick stop if schema is Variant --- src/variant_schema_agg.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs index 72e581e..b7cf0e5 100644 --- a/src/variant_schema_agg.rs +++ b/src/variant_schema_agg.rs @@ -101,6 +101,10 @@ impl Accumulator for VariantSchemaAccumulator { } fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> Result<()> { + if self.schema == VariantSchema::Variant { + return Ok(()); + } + // We're assuming the input is an array of variants for value in values { // Ensure we are dealing with VariantArray and extract the variant values @@ -109,16 +113,26 @@ impl Accumulator for VariantSchemaAccumulator { let new_schema = schema_from_variant(&variant); // Merge the new schema with the current schema self.schema = merge_variant_schema(self.schema.clone(), new_schema); + if self.schema == VariantSchema::Variant { + return Ok(()); + } } } Ok(()) } fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> Result<()> { + if self.schema == VariantSchema::Variant { + return Ok(()); + } + for state in states { for encoded_state in try_parse_binary_columnar(state)?.into_iter().flatten() { let new_schema = VariantSchema::from_state_bytes(encoded_state)?; self.schema = merge_variant_schema(self.schema.clone(), new_schema); + if self.schema == VariantSchema::Variant { + return Ok(()); + } } } Ok(()) From 30c70c4de9a4893f0afb92b4f8db8127d0f3098a Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Wed, 25 Feb 2026 12:01:02 -0500 Subject: [PATCH 18/32] early fold for Variant in list schemas --- src/variant_schema.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 332582b..6d342f8 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -171,8 +171,15 @@ pub fn schema_from_variant(v: &Variant) -> VariantSchema { let inner = list .iter() .map(|v| schema_from_variant(&v)) - .reduce(merge_variant_schema) - .unwrap_or(VariantSchema::Primitive(DataType::Null)); + .try_fold(VariantSchema::Primitive(DataType::Null), |acc, next| { + let merged = merge_variant_schema(acc, next); + if merged == VariantSchema::Variant { + Err(merged) + } else { + Ok(merged) + } + }) + .unwrap_or_else(|schema| schema); VariantSchema::Array(Box::new(inner)) } @@ -622,9 +629,6 @@ mod tests { .iter() .collect::>(); - assert_eq!( - strings, - vec![Some("OBJECT"), Some("ARRAY")] - ); + assert_eq!(strings, vec![Some("OBJECT"), Some("ARRAY")]); } } From b593aeadc1c2559c41037f4640c92a562b055e41 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Wed, 25 Feb 2026 12:01:15 -0500 Subject: [PATCH 19/32] cargo fmt --- src/variant_schema_agg.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs index b7cf0e5..6c13d66 100644 --- a/src/variant_schema_agg.rs +++ b/src/variant_schema_agg.rs @@ -65,7 +65,10 @@ impl AggregateUDFImpl for VariantSchemaAggUDAF { true, ))]; - Ok(fields.into_iter().chain(args.ordering_fields.to_vec()).collect()) + Ok(fields + .into_iter() + .chain(args.ordering_fields.to_vec()) + .collect()) } fn accumulator( @@ -92,7 +95,9 @@ impl VariantSchemaAccumulator { impl Accumulator for VariantSchemaAccumulator { fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Binary(Some(self.schema.to_state_bytes()))]) + Ok(vec![ScalarValue::Binary(Some( + self.schema.to_state_bytes(), + ))]) } fn evaluate(&mut self) -> Result { From 3db8535299e1d025e4abc3f7e8f95c39f47166c6 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Wed, 25 Feb 2026 12:27:42 -0500 Subject: [PATCH 20/32] tests cleanup --- src/variant_schema.rs | 159 +--------------------------- src/variant_schema_agg.rs | 82 -------------- tests/test_files/variant_schema.slt | 15 +++ 3 files changed, 16 insertions(+), 240 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 6d342f8..524732a 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -373,7 +373,6 @@ impl ScalarUDFImpl for VariantSchemaUDF { #[cfg(test)] mod tests { - use arrow::array::StringViewArray; use arrow::array::StructArray; use arrow_schema::{DataType, Field, Fields}; use chrono::{DateTime, NaiveDate, NaiveTime}; @@ -385,10 +384,7 @@ mod tests { use parquet_variant_compute::{VariantArray, VariantType}; use std::sync::Arc; - use crate::{ - VariantSchemaUDF, - shared::{build_variant_array_from_json, build_variant_array_from_json_array}, - }; + use crate::VariantSchemaUDF; fn build_scalar_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); @@ -407,49 +403,6 @@ mod tests { } } - fn build_array_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { - let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); - let arg_field = Arc::new( - Field::new("input", DataType::Struct(Fields::empty()), true) - .with_extension_type(VariantType), - ); - ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(struct_array))], - arg_fields: vec![arg_field], - number_rows: Default::default(), - return_field, - config_options: Default::default(), - } - } - - #[test] - fn test_get_single_typed_null_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::Null; - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Null") - } - - #[test] - fn test_get_single_typed_int32_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::from(1234i32); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Int32") - } - #[test] fn test_get_single_typed_date_variant_schema() { let udf = VariantSchemaUDF::default(); @@ -493,48 +446,6 @@ mod tests { assert_eq!(schema, "Decimal32(4, 1)") } - #[test] - fn test_get_single_typed_float_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::from(123.4f32); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Float32") - } - - #[test] - fn test_get_single_typed_double_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::from(123.4f64); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Float64") - } - - #[test] - fn test_get_single_typed_bool_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::BooleanTrue; - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Boolean") - } - #[test] fn test_get_single_typed_binary_variant_schema() { let udf = VariantSchemaUDF::default(); @@ -549,20 +460,6 @@ mod tests { assert_eq!(schema, "Binary") } - #[test] - fn test_get_single_typed_string_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::from("foo"); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Utf8") - } - #[test] fn test_get_single_typed_time_variant_schema() { let udf = VariantSchemaUDF::default(); @@ -577,58 +474,4 @@ mod tests { assert_eq!(schema, "Time64(µs)") } - #[test] - fn test_get_single_struct_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant_array = build_variant_array_from_json(&serde_json::json!({ - "key": 123, "data": [4, 5] - })); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "OBJECT, key: Int8>") - } - - #[test] - fn test_get_single_struct_variant_conflicting_schema() { - let udf = VariantSchemaUDF::default(); - let variant_array = build_variant_array_from_json(&serde_json::json!({ - "data": [{"a":"a"}, 5] - })); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "OBJECT>") - } - - #[test] - fn test_get_columnar_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant_array = build_variant_array_from_json_array(&[ - Some(serde_json::json!({"a": 1})), - Some(serde_json::json!([1, 2, 3])), - ]); - let struct_array = variant_array.into_inner(); - let args = build_array_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - - let ColumnarValue::Array(array) = result else { - panic!("expected array output") - }; - - let strings = array - .as_any() - .downcast_ref::() - .expect("expected Utf8View array") - .iter() - .collect::>(); - - assert_eq!(strings, vec![Some("OBJECT"), Some("ARRAY")]); - } } diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs index 6c13d66..22d9cb4 100644 --- a/src/variant_schema_agg.rs +++ b/src/variant_schema_agg.rs @@ -167,88 +167,6 @@ mod test { shared::build_variant_array_from_json_array, variant_schema_agg::VariantSchemaAccumulator, }; - #[test] - fn test_get_agg_variant_schema() { - let b = build_variant_array_from_json_array(&[ - Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), - Some(serde_json::json!({"wing": {"ding": "man"}})), - ]); - let b: ArrayRef = Arc::new(b.into_inner()); - - let schema = Schema::new(vec![ - Field::new( - "b", - DataType::Struct(Fields::from(vec![ - Field::new("metadata", DataType::Binary, true), - Field::new("value", DataType::Binary, true), - ])), - true, - ) - .with_extension_type(VariantType), - ]); - - let acc_args = AccumulatorArgs { - return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), - schema: &schema, - ignore_nulls: false, - order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], - is_reversed: false, - name: "variant_schema_agg", - is_distinct: false, - exprs: &[col("b", &schema).unwrap()], - }; - - let mut variant_schema = VariantSchemaAccumulator::new(acc_args); - variant_schema.update_batch(&[Arc::clone(&b)]).unwrap(); - let final_schema = variant_schema.evaluate().unwrap(); - assert_eq!( - final_schema, - ScalarValue::Utf8View(Some( - "OBJECT>".to_string() - )) - ) - } - - #[test] - fn test_get_array_variant_conflicting_schema() { - let b = build_variant_array_from_json_array(&[ - Some(serde_json::json!({"foo": "bar", "wing": {"ding": "dong"}})), - Some(serde_json::json!({"wing": 123})), - ]); - let b: ArrayRef = Arc::new(b.into_inner()); - - let schema = Schema::new(vec![ - Field::new( - "b", - DataType::Struct(Fields::from(vec![ - Field::new("metadata", DataType::Binary, true), - Field::new("value", DataType::Binary, true), - ])), - true, - ) - .with_extension_type(VariantType), - ]); - - let acc_args = AccumulatorArgs { - return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), - schema: &schema, - ignore_nulls: false, - order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], - is_reversed: false, - name: "variant_schema_agg", - is_distinct: false, - exprs: &[col("b", &schema).unwrap()], - }; - - let mut variant_schema = VariantSchemaAccumulator::new(acc_args); - variant_schema.update_batch(&[Arc::clone(&b)]).unwrap(); - let final_schema = variant_schema.evaluate().unwrap(); - assert_eq!( - final_schema, - ScalarValue::Utf8View(Some("OBJECT".to_string())) - ) - } - #[test] fn test_merge_batch_from_state_roundtrip() { let schema = Schema::new(vec![ diff --git a/tests/test_files/variant_schema.slt b/tests/test_files/variant_schema.slt index 380c826..0eafdcb 100644 --- a/tests/test_files/variant_schema.slt +++ b/tests/test_files/variant_schema.slt @@ -33,6 +33,21 @@ SELECT variant_schema(json_to_variant(123.4)) ---- Float64 +# explicit string primitive +query T +SELECT variant_schema(json_to_variant('"foo"')) +---- +Utf8 + +# explicit boolean primitive +query T +SELECT variant_schema(json_to_variant('true')) +---- +Boolean + +# TODO: add DATE/TIME/TIMESTAMP/DECIMAL slt coverage once cast_to_variant +# scalar/columnar return-field nullability mismatch is resolved for typed casts. + # explicit null query T SELECT variant_schema(json_to_variant('null')) From a041c8f50db6253d3a6f3aaaf29cfae7abc41762 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Thu, 26 Feb 2026 11:25:52 -0500 Subject: [PATCH 21/32] update dependency --- Cargo.lock | 2 -- Cargo.toml | 2 +- src/variant_schema_agg.rs | 30 ++++++++++++++++++++++++------ 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 28958e3..fbb0a97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -601,9 +601,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", - "js-sys", "num-traits", - "wasm-bindgen", "windows-link", ] diff --git a/Cargo.toml b/Cargo.toml index dbe2fce..64fc55d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ parquet-variant = "=57.2.0" [dev-dependencies] anyhow = "1.0.100" arrow-cast = "=57.2.0" +chrono = { version = "0.4", default-features = false, features = ["clock"] } serde_json = "1.0.145" flate2 = "1.0" tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] } @@ -24,4 +25,3 @@ env_logger = "0.11" insta = "1.43.2" rustyline = { version = "14.0", features = ["derive"] } - diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs index 22d9cb4..4fe8841 100644 --- a/src/variant_schema_agg.rs +++ b/src/variant_schema_agg.rs @@ -187,35 +187,53 @@ mod test { let b2 = build_variant_array_from_json_array(&[Some(serde_json::json!({"a": 2.5}))]); let b2: ArrayRef = Arc::new(b2.into_inner()); + let expr = col("b", &schema).unwrap(); + let order_bys = vec![PhysicalSortExpr::new_default(Arc::clone(&expr))]; + let exprs = vec![expr]; + let expr_fields = vec![Arc::new( + Field::new( + "b", + DataType::Struct(Fields::from(vec![ + Field::new("metadata", DataType::Binary, true), + Field::new("value", DataType::Binary, true), + ])), + true, + ) + .with_extension_type(VariantType), + )]; + let acc1_args = AccumulatorArgs { return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), schema: &schema, ignore_nulls: false, - order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], + order_bys: &order_bys, is_reversed: false, name: "variant_schema_agg", is_distinct: false, - exprs: &[col("b", &schema).unwrap()], + exprs: &exprs, + expr_fields: &expr_fields, }; let acc2_args = AccumulatorArgs { return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), schema: &schema, ignore_nulls: false, - order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], + order_bys: &order_bys, is_reversed: false, name: "variant_schema_agg", is_distinct: false, - exprs: &[col("b", &schema).unwrap()], + exprs: &exprs, + expr_fields: &expr_fields, }; let merged_args = AccumulatorArgs { return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), schema: &schema, ignore_nulls: false, - order_bys: &[PhysicalSortExpr::new_default(col("b", &schema).unwrap())], + order_bys: &order_bys, is_reversed: false, name: "variant_schema_agg", is_distinct: false, - exprs: &[col("b", &schema).unwrap()], + exprs: &exprs, + expr_fields: &expr_fields, }; let mut acc1 = VariantSchemaAccumulator::new(acc1_args); From f395e406ae754bc0ff3c5ff8d09e334e03bbb7fb Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Thu, 26 Feb 2026 12:23:11 -0500 Subject: [PATCH 22/32] Move date tests to slt --- Cargo.lock | 1 - Cargo.toml | 2 - src/variant_schema.rs | 105 ---------------------------- tests/test_files/variant_schema.slt | 34 +++++++-- 4 files changed, 30 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fbb0a97..c3b90e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1620,7 +1620,6 @@ dependencies = [ "arrow", "arrow-cast", "arrow-schema", - "chrono", "datafusion", "datafusion-sqllogictest", "env_logger", diff --git a/Cargo.toml b/Cargo.toml index 64fc55d..afd412c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ parquet-variant = "=57.2.0" [dev-dependencies] anyhow = "1.0.100" arrow-cast = "=57.2.0" -chrono = { version = "0.4", default-features = false, features = ["clock"] } serde_json = "1.0.145" flate2 = "1.0" tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] } @@ -24,4 +23,3 @@ indicatif = "0.18" env_logger = "0.11" insta = "1.43.2" rustyline = { version = "14.0", features = ["derive"] } - diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 524732a..37cc17a 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -370,108 +370,3 @@ impl ScalarUDFImpl for VariantSchemaUDF { infer_variant_schema(arg) } } - -#[cfg(test)] -mod tests { - use arrow::array::StructArray; - use arrow_schema::{DataType, Field, Fields}; - use chrono::{DateTime, NaiveDate, NaiveTime}; - use datafusion::{ - logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}, - scalar::ScalarValue, - }; - use parquet_variant::{Variant, VariantDecimal4}; - use parquet_variant_compute::{VariantArray, VariantType}; - use std::sync::Arc; - - use crate::VariantSchemaUDF; - - fn build_scalar_udf_args(struct_array: StructArray) -> ScalarFunctionArgs { - let return_field = Arc::new(Field::new("result", DataType::Utf8View, true)); - let arg_field = Arc::new( - Field::new("input", DataType::Struct(Fields::empty()), true) - .with_extension_type(VariantType), - ); - ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::Struct(Arc::new( - struct_array, - )))], - arg_fields: vec![arg_field], - number_rows: Default::default(), - return_field, - config_options: Default::default(), - } - } - - #[test] - fn test_get_single_typed_date_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::from(NaiveDate::from_ymd_opt(1990, 1, 1).expect("Expect NaiveDate")); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Date32") - } - - #[test] - fn test_get_single_typed_timestamp_micro_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = - Variant::from(DateTime::from_timestamp(1431648000, 0).expect("Expect TimeStamp")); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Timestamp(µs, \"utc\")") - } - - #[test] - fn test_get_single_typed_decimal_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::Decimal4(VariantDecimal4::try_new(1234, 1).expect("Expect decimal")); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Decimal32(4, 1)") - } - - #[test] - fn test_get_single_typed_binary_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::Binary(&[1u8, 2, 3]); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Binary") - } - - #[test] - fn test_get_single_typed_time_variant_schema() { - let udf = VariantSchemaUDF::default(); - let variant = Variant::from(NaiveTime::from_hms_opt(0, 0, 0).expect("Expect NaiveTime")); - let variant_array = VariantArray::from_iter(vec![variant]); - let struct_array = variant_array.into_inner(); - let args = build_scalar_udf_args(struct_array); - let result = udf.invoke_with_args(args).unwrap(); - let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(schema))) = result else { - panic!() - }; - assert_eq!(schema, "Time64(µs)") - } - -} diff --git a/tests/test_files/variant_schema.slt b/tests/test_files/variant_schema.slt index 0eafdcb..4eeb064 100644 --- a/tests/test_files/variant_schema.slt +++ b/tests/test_files/variant_schema.slt @@ -1,4 +1,4 @@ -# tests the variant_schema udf +# tests the variant_schema udf # this function takes a VARIANT expression # and extracts each row's SQL schema @@ -28,7 +28,7 @@ SELECT variant_schema(json_to_variant('{"data": [{"a":"a"}, 5]}')) OBJECT> # typed literal -query T +query T SELECT variant_schema(json_to_variant(123.4)) ---- Float64 @@ -45,8 +45,34 @@ SELECT variant_schema(json_to_variant('true')) ---- Boolean -# TODO: add DATE/TIME/TIMESTAMP/DECIMAL slt coverage once cast_to_variant -# scalar/columnar return-field nullability mismatch is resolved for typed casts. +# cast_to_variant typed DATE/TIME/TIMESTAMP/DECIMAL from columns +statement ok +CREATE TABLE t_typed AS +SELECT + CAST('1990-01-01' AS DATE) AS d, + CAST('00:00:00' AS TIME) AS t, + CAST('2015-05-14 00:00:00' AS TIMESTAMP) AS ts, + CAST(123.4 AS DECIMAL(4, 1)) AS decv; + +query T +SELECT variant_schema(cast_to_variant(d)) FROM t_typed; +---- +Date32 + +query T +SELECT variant_schema(cast_to_variant(t)) FROM t_typed; +---- +Time64(µs) + +query T +SELECT variant_schema(cast_to_variant(ts)) FROM t_typed; +---- +Timestamp(µs) + +query T +SELECT variant_schema(cast_to_variant(decv)) FROM t_typed; +---- +Decimal128(4, 1) # explicit null query T From 786fb9d1fb012b79a07e970411bffa91587c585b Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 2 Mar 2026 13:46:44 -0500 Subject: [PATCH 23/32] variant_schema_udf: remove aggregate file and agg tests --- src/variant_schema.rs | 71 +++++-- src/variant_schema_agg.rs | 266 ------------------------ tests/test_files/variant_schema_agg.slt | 114 ---------- 3 files changed, 55 insertions(+), 396 deletions(-) delete mode 100644 src/variant_schema_agg.rs delete mode 100644 tests/test_files/variant_schema_agg.slt diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 37cc17a..3ccd955 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -8,6 +8,7 @@ use datafusion::{ }; use parquet_variant::Variant; use parquet_variant_compute::VariantArray; +use std::collections::btree_map::Entry; use std::collections::BTreeMap; use std::sync::Arc; @@ -267,32 +268,70 @@ fn merge_primitives(a: DataType, b: DataType) -> Option { /// Merges two inferred Variant schemas into a common schema. /// Returns VARIANT if no common schema can be determined. pub fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { - use VariantSchema::*; + let mut merged = a; + merge_variant_schema_from(&mut merged, &b); + merged +} - match (a, b) { - (Variant, _) | (_, Variant) => Variant, +pub fn merge_variant_schema_from(target: &mut VariantSchema, incoming: &VariantSchema) { + use VariantSchema::*; - (Primitive(DataType::Null), x) | (x, Primitive(DataType::Null)) => x, + if matches!(target, Variant) || matches!(incoming, Variant) { + *target = Variant; + return; + } - (Primitive(p1), Primitive(p2)) => { - merge_primitives(p1, p2).map(Primitive).unwrap_or(Variant) - } + if matches!(incoming, Primitive(DataType::Null)) { + return; + } - (Array(a), Array(b)) => Array(Box::new(merge_variant_schema(*a, *b))), + if matches!(target, Primitive(DataType::Null)) { + *target = incoming.clone(); + return; + } - (Object(mut a), Object(b)) => { - for (k, v_b) in b { - a.entry(k) - .and_modify(|v_a| *v_a = merge_variant_schema(v_a.clone(), v_b.clone())) - .or_insert(v_b); + match incoming { + Primitive(p2) => { + if let Primitive(p1) = target { + let merged = merge_primitives(p1.clone(), p2.clone()) + .map(Primitive) + .unwrap_or(Variant); + *target = merged; + } else { + *target = Variant; } - Object(a) } - - _ => Variant, + Array(b) => { + if let Array(a) = target { + merge_variant_schema_from(a.as_mut(), b.as_ref()); + } else { + *target = Variant; + } + } + Object(b) => { + if let Object(a) = target { + for (k, v_b) in b { + match a.entry(k.clone()) { + Entry::Occupied(mut occ) => merge_variant_schema_from(occ.get_mut(), v_b), + Entry::Vacant(vac) => { + vac.insert(v_b.clone()); + } + } + } + } else { + *target = Variant; + } + } + Variant => { + *target = Variant; + } } } +pub fn merge_variant_schema_into(target: &mut VariantSchema, incoming: VariantSchema) { + merge_variant_schema_from(target, &incoming); +} + /// Prints schema in a presentable manner pub fn print_schema(schema: &VariantSchema) -> String { match schema { diff --git a/src/variant_schema_agg.rs b/src/variant_schema_agg.rs deleted file mode 100644 index 4fe8841..0000000 --- a/src/variant_schema_agg.rs +++ /dev/null @@ -1,266 +0,0 @@ -use arrow::array::AsArray; -use arrow_schema::{DataType, Field, FieldRef}; -use datafusion::{ - error::Result, - logical_expr::{ - Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, - function::{AccumulatorArgs, StateFieldsArgs}, - utils::format_state_name, - }, - scalar::ScalarValue, -}; -use parquet_variant_compute::VariantArray; -use std::sync::Arc; - -use crate::{ - VariantSchema, merge_variant_schema, print_schema, schema_from_variant, - shared::try_parse_binary_columnar, -}; - -/// Aggregate schema inference for VARIANT values across rows. -/// -/// This function infers per-row schemas using `schema_from_variant` and merges -/// them into a single schema per group. -/// -/// Semantics: -/// - Input: one VARIANT expression -/// - Output: one schema string per aggregate group -/// - Row filtering should be done via SQL `FILTER (WHERE ...)` -/// -/// Use `variant_schema` for row-wise (non-aggregate) inference. -#[derive(Debug, Hash, PartialEq, Eq)] -pub struct VariantSchemaAggUDAF { - signature: Signature, -} - -impl Default for VariantSchemaAggUDAF { - fn default() -> Self { - Self { - signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), - } - } -} - -impl AggregateUDFImpl for VariantSchemaAggUDAF { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "variant_schema_agg" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8View) - } - - fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let fields = vec![Arc::new(Field::new( - format_state_name(args.name, "variant_schema"), - DataType::Binary, - true, - ))]; - - Ok(fields - .into_iter() - .chain(args.ordering_fields.to_vec()) - .collect()) - } - - fn accumulator( - &self, - acc_args: datafusion::logical_expr::function::AccumulatorArgs, - ) -> Result> { - Ok(Box::new(VariantSchemaAccumulator::new(acc_args))) - } -} - -/// Accumulator state for `variant_schema_agg`. -#[derive(Debug)] -pub struct VariantSchemaAccumulator { - schema: VariantSchema, -} - -impl VariantSchemaAccumulator { - fn new(_acc_args: AccumulatorArgs) -> Self { - Self { - schema: VariantSchema::Primitive(DataType::Null), - } - } -} - -impl Accumulator for VariantSchemaAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Binary(Some( - self.schema.to_state_bytes(), - ))]) - } - - fn evaluate(&mut self) -> Result { - // Return the schema as a Utf8 representation - Ok(ScalarValue::Utf8View(Some(print_schema(&self.schema)))) - } - - fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> Result<()> { - if self.schema == VariantSchema::Variant { - return Ok(()); - } - - // We're assuming the input is an array of variants - for value in values { - // Ensure we are dealing with VariantArray and extract the variant values - let variant_array = VariantArray::try_new(value.as_struct())?; - for variant in variant_array.iter().flatten() { - let new_schema = schema_from_variant(&variant); - // Merge the new schema with the current schema - self.schema = merge_variant_schema(self.schema.clone(), new_schema); - if self.schema == VariantSchema::Variant { - return Ok(()); - } - } - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> Result<()> { - if self.schema == VariantSchema::Variant { - return Ok(()); - } - - for state in states { - for encoded_state in try_parse_binary_columnar(state)?.into_iter().flatten() { - let new_schema = VariantSchema::from_state_bytes(encoded_state)?; - self.schema = merge_variant_schema(self.schema.clone(), new_schema); - if self.schema == VariantSchema::Variant { - return Ok(()); - } - } - } - Ok(()) - } - - fn size(&self) -> usize { - // The size is essentially the number of variants processed, if needed - 1 // This could be expanded to return a more useful size - } -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::ArrayRef; - use arrow_schema::{DataType, Field, Fields, Schema}; - use datafusion::{ - logical_expr::{Accumulator, function::AccumulatorArgs}, - physical_expr::PhysicalSortExpr, - physical_plan::expressions::col, - scalar::ScalarValue, - }; - use parquet_variant_compute::VariantType; - - use crate::{ - shared::build_variant_array_from_json_array, variant_schema_agg::VariantSchemaAccumulator, - }; - - #[test] - fn test_merge_batch_from_state_roundtrip() { - let schema = Schema::new(vec![ - Field::new( - "b", - DataType::Struct(Fields::from(vec![ - Field::new("metadata", DataType::Binary, true), - Field::new("value", DataType::Binary, true), - ])), - true, - ) - .with_extension_type(VariantType), - ]); - - let b1 = build_variant_array_from_json_array(&[Some(serde_json::json!({"a": 1}))]); - let b1: ArrayRef = Arc::new(b1.into_inner()); - - let b2 = build_variant_array_from_json_array(&[Some(serde_json::json!({"a": 2.5}))]); - let b2: ArrayRef = Arc::new(b2.into_inner()); - - let expr = col("b", &schema).unwrap(); - let order_bys = vec![PhysicalSortExpr::new_default(Arc::clone(&expr))]; - let exprs = vec![expr]; - let expr_fields = vec![Arc::new( - Field::new( - "b", - DataType::Struct(Fields::from(vec![ - Field::new("metadata", DataType::Binary, true), - Field::new("value", DataType::Binary, true), - ])), - true, - ) - .with_extension_type(VariantType), - )]; - - let acc1_args = AccumulatorArgs { - return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), - schema: &schema, - ignore_nulls: false, - order_bys: &order_bys, - is_reversed: false, - name: "variant_schema_agg", - is_distinct: false, - exprs: &exprs, - expr_fields: &expr_fields, - }; - let acc2_args = AccumulatorArgs { - return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), - schema: &schema, - ignore_nulls: false, - order_bys: &order_bys, - is_reversed: false, - name: "variant_schema_agg", - is_distinct: false, - exprs: &exprs, - expr_fields: &expr_fields, - }; - let merged_args = AccumulatorArgs { - return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), - schema: &schema, - ignore_nulls: false, - order_bys: &order_bys, - is_reversed: false, - name: "variant_schema_agg", - is_distinct: false, - exprs: &exprs, - expr_fields: &expr_fields, - }; - - let mut acc1 = VariantSchemaAccumulator::new(acc1_args); - acc1.update_batch(&[Arc::clone(&b1)]).unwrap(); - let state_1 = acc1 - .state() - .unwrap() - .into_iter() - .map(|s| s.to_array().unwrap()) - .collect::>(); - - let mut acc2 = VariantSchemaAccumulator::new(acc2_args); - acc2.update_batch(&[Arc::clone(&b2)]).unwrap(); - let state_2 = acc2 - .state() - .unwrap() - .into_iter() - .map(|s| s.to_array().unwrap()) - .collect::>(); - - let mut merged = VariantSchemaAccumulator::new(merged_args); - merged.merge_batch(&state_1).unwrap(); - merged.merge_batch(&state_2).unwrap(); - - assert_eq!( - merged.evaluate().unwrap(), - ScalarValue::Utf8View(Some("OBJECT".to_string())) - ); - } -} diff --git a/tests/test_files/variant_schema_agg.slt b/tests/test_files/variant_schema_agg.slt deleted file mode 100644 index 4677a97..0000000 --- a/tests/test_files/variant_schema_agg.slt +++ /dev/null @@ -1,114 +0,0 @@ -# tests the variant_schema_agg udaf -# this function takes a Variant Array -# and extracts it's SQL schema - -# same schema -statement ok -CREATE TABLE t as VALUES -(json_to_variant('{"foo": "bar", "wing": {"ding": "dong"}}')), -(json_to_variant('{"wing": {"ding": "man"}}')); - -query T -SELECT variant_schema_agg(column1) from t; ----- -OBJECT> - -# conflicting schema -statement ok -CREATE TABLE t_conflicting as VALUES -(json_to_variant('{"foo": "bar", "wing": {"ding": "dong"}}')), -(json_to_variant('{"wing": 123}')); - -query T -SELECT variant_schema_agg(column1) from t_conflicting; ----- -OBJECT - -# null row -statement ok -CREATE TABLE t_nulls AS VALUES -(json_to_variant('{"a": 1}')), -(json_to_variant('null')), -(json_to_variant('{"a": 2}')); - -query T -SELECT variant_schema_agg(column1) FROM t_nulls; ----- -OBJECT - -# numeric widening -statement ok -CREATE TABLE t_nums AS VALUES -(json_to_variant('{"a": 1}')), -(json_to_variant('{"a": 2.5}')); - -query T -SELECT variant_schema_agg(column1) FROM t_nums; ----- -OBJECT - -# field appears later -statement ok -CREATE TABLE t_sparse AS VALUES -(json_to_variant('{}')), -(json_to_variant('{"a": 1}')); - -query T -SELECT variant_schema_agg(column1) FROM t_sparse; ----- -OBJECT - -# conflicting array of objects -statement ok -CREATE TABLE t_arr_objs AS VALUES -(json_to_variant('[{"a":1}]')), -(json_to_variant('[{"a":"x"}]')); - -query T -SELECT variant_schema_agg(column1) FROM t_arr_objs; ----- -ARRAY> - -# empty aggregates -statement ok -CREATE TABLE t_empty AS VALUES -(json_to_variant('{}')), -(json_to_variant('{}')); - -query T -SELECT variant_schema_agg(column1) FROM t_empty; ----- -OBJECT<> - -# field ordering -statement ok -CREATE TABLE t_order AS VALUES -(json_to_variant('{"b":1}')), -(json_to_variant('{"a":2}')); - -query T -SELECT variant_schema_agg(column1) FROM t_order; ----- -OBJECT - -# root conflict -statement ok -CREATE TABLE t_root_conflict AS VALUES -(json_to_variant('{"a":1}')), -(json_to_variant('[1,2,3]')); - -query T -SELECT variant_schema_agg(column1) FROM t_root_conflict; ----- -VARIANT - -# mixed root -statement ok -CREATE TABLE t_mixed AS VALUES -(json_to_variant('1')), -(json_to_variant('{"a": 1}')); - -query T -SELECT variant_schema_agg(column1) FROM t_mixed; ----- -VARIANT \ No newline at end of file From cb1c92d1879464b2300e811502f8d0f36a8e8ea2 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 2 Mar 2026 13:49:05 -0500 Subject: [PATCH 24/32] remove udaf from slt --- tests/sqllogictests.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/sqllogictests.rs b/tests/sqllogictests.rs index 02af811..9fa0f59 100644 --- a/tests/sqllogictests.rs +++ b/tests/sqllogictests.rs @@ -60,7 +60,6 @@ async fn run_sqllogictests() -> Result<(), Box> { ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectInsert::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectDelete::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantSchemaUDF::default())); - ctx.register_udaf(AggregateUDF::new_from_impl(VariantSchemaAggUDAF::default())); let pb = ProgressBar::new(24); From 1b901102ede38252b3cd1c04657962ceac5cac92 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 2 Mar 2026 13:53:30 -0500 Subject: [PATCH 25/32] remove other mentions of udaf --- src/lib.rs | 2 -- src/variant_schema.rs | 19 +++++++++---------- tests/sqllogictests.rs | 7 ++----- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index cbc3787..f24650f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,6 @@ mod variant_object_delete; mod variant_object_insert; mod variant_pretty; mod variant_schema; -mod variant_schema_agg; mod variant_to_json; pub use cast_to_variant::*; @@ -29,5 +28,4 @@ pub use variant_object_delete::*; pub use variant_object_insert::*; pub use variant_pretty::*; pub use variant_schema::*; -pub use variant_schema_agg::*; pub use variant_to_json::*; diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 3ccd955..06e4ba1 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -36,8 +36,7 @@ impl Default for VariantSchemaUDF { /// Execution semantics: /// - Scalar input: infer one schema string for that value. /// - Columnar input: infer one schema string per row (vectorized row-wise behavior). -/// - This function does not merge schemas across rows. For cross-row/group merge use -/// `variant_schema_agg`. +/// - This function does not merge schemas across rows. /// /// Merge rules (within one VARIANT value only): /// - If outer (or inner) kinds differ, the result is `VARIANT` @@ -63,7 +62,7 @@ impl VariantSchema { let mut offset = 0usize; let decoded = decode_variant_schema(bytes, &mut offset)?; if offset != bytes.len() { - return exec_err!("invalid variant_schema_agg state: trailing bytes"); + return exec_err!("invalid variant_schema state: trailing bytes"); } Ok(decoded) } @@ -76,7 +75,7 @@ fn encode_len_prefixed_bytes(out: &mut Vec, bytes: &[u8]) { fn read_u8(input: &[u8], offset: &mut usize) -> Result { let Some(v) = input.get(*offset) else { - return exec_err!("invalid variant_schema_agg state: missing tag"); + return exec_err!("invalid variant_schema state: missing tag"); }; *offset += 1; Ok(*v) @@ -84,7 +83,7 @@ fn read_u8(input: &[u8], offset: &mut usize) -> Result { fn read_u32(input: &[u8], offset: &mut usize) -> Result { let Some(raw) = input.get(*offset..(*offset + 4)) else { - return exec_err!("invalid variant_schema_agg state: missing u32"); + return exec_err!("invalid variant_schema state: missing u32"); }; *offset += 4; Ok(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]])) @@ -93,7 +92,7 @@ fn read_u32(input: &[u8], offset: &mut usize) -> Result { fn read_len_prefixed_bytes<'a>(input: &'a [u8], offset: &mut usize) -> Result<&'a [u8]> { let len = read_u32(input, offset)? as usize; let Some(raw) = input.get(*offset..(*offset + len)) else { - return exec_err!("invalid variant_schema_agg state: truncated payload"); + return exec_err!("invalid variant_schema state: truncated payload"); }; *offset += len; Ok(raw) @@ -127,11 +126,11 @@ fn decode_variant_schema(input: &[u8], offset: &mut usize) -> Result v, - Err(e) => return exec_err!("invalid variant_schema_agg state: {e}"), + Err(e) => return exec_err!("invalid variant_schema state: {e}"), }; let dtype = match dtype_str.parse::() { Ok(v) => v, - Err(e) => return exec_err!("invalid variant_schema_agg datatype state: {e}"), + Err(e) => return exec_err!("invalid variant_schema datatype state: {e}"), }; Ok(VariantSchema::Primitive(dtype)) } @@ -145,7 +144,7 @@ fn decode_variant_schema(input: &[u8], offset: &mut usize) -> Result v.to_string(), - Err(e) => return exec_err!("invalid variant_schema_agg field key: {e}"), + Err(e) => return exec_err!("invalid variant_schema field key: {e}"), }; let value = decode_variant_schema(input, offset)?; fields.insert(key, value); @@ -153,7 +152,7 @@ fn decode_variant_schema(input: &[u8], offset: &mut usize) -> Result Ok(VariantSchema::Variant), - tag => exec_err!("invalid variant_schema_agg state tag: {tag}"), + tag => exec_err!("invalid variant_schema state tag: {tag}"), } } diff --git a/tests/sqllogictests.rs b/tests/sqllogictests.rs index 9fa0f59..7faecff 100644 --- a/tests/sqllogictests.rs +++ b/tests/sqllogictests.rs @@ -1,12 +1,9 @@ -use datafusion::{ - logical_expr::{AggregateUDF, ScalarUDF}, - prelude::*, -}; +use datafusion::{logical_expr::ScalarUDF, prelude::*}; use datafusion_sqllogictest::{DataFusion, TestContext}; use datafusion_variant::{ CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetUdf, VariantListConstruct, VariantListDelete, VariantListInsert, VariantObjectConstruct, VariantObjectDelete, - VariantObjectInsert, VariantPretty, VariantSchemaAggUDAF, VariantSchemaUDF, VariantToJsonUdf, + VariantObjectInsert, VariantPretty, VariantSchemaUDF, VariantToJsonUdf, }; use indicatif::ProgressBar; use sqllogictest::strict_column_validator; From 7d6237f821420b67119634d37e86dca4cd21f20a Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 2 Mar 2026 14:13:10 -0500 Subject: [PATCH 26/32] fmt --- src/variant_schema.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 06e4ba1..b1d3381 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -8,8 +8,8 @@ use datafusion::{ }; use parquet_variant::Variant; use parquet_variant_compute::VariantArray; -use std::collections::btree_map::Entry; use std::collections::BTreeMap; +use std::collections::btree_map::Entry; use std::sync::Arc; #[derive(Debug, Hash, PartialEq, Eq)] From f95f4eba64079f04c9eea2f123b5c99d0101a8e2 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Mon, 2 Mar 2026 19:17:59 -0500 Subject: [PATCH 27/32] use json instead of custom encoding --- Cargo.toml | 2 +- src/variant_schema.rs | 179 +++++++++++++++++++++++++----------------- 2 files changed, 110 insertions(+), 71 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index afd412c..de56a16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,11 +10,11 @@ arrow-schema = "=57.2.0" parquet-variant-compute = "=57.2.0" parquet-variant-json = "=57.2.0" parquet-variant = "=57.2.0" +serde_json = "1.0.145" [dev-dependencies] anyhow = "1.0.100" arrow-cast = "=57.2.0" -serde_json = "1.0.145" flate2 = "1.0" tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] } datafusion-sqllogictest = "52.1.0" diff --git a/src/variant_schema.rs b/src/variant_schema.rs index b1d3381..0dd75d1 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -8,6 +8,7 @@ use datafusion::{ }; use parquet_variant::Variant; use parquet_variant_compute::VariantArray; +use serde_json::{Map, Value}; use std::collections::BTreeMap; use std::collections::btree_map::Entry; use std::sync::Arc; @@ -53,106 +54,117 @@ pub enum VariantSchema { impl VariantSchema { pub fn to_state_bytes(&self) -> Vec { - let mut out = Vec::new(); - encode_variant_schema(self, &mut out); - out + self.to_state_string().into_bytes() } - pub fn from_state_bytes(bytes: &[u8]) -> Result { - let mut offset = 0usize; - let decoded = decode_variant_schema(bytes, &mut offset)?; - if offset != bytes.len() { - return exec_err!("invalid variant_schema state: trailing bytes"); - } - Ok(decoded) + pub fn to_state_string(&self) -> String { + schema_to_json(self).to_string() } -} - -fn encode_len_prefixed_bytes(out: &mut Vec, bytes: &[u8]) { - out.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); - out.extend_from_slice(bytes); -} - -fn read_u8(input: &[u8], offset: &mut usize) -> Result { - let Some(v) = input.get(*offset) else { - return exec_err!("invalid variant_schema state: missing tag"); - }; - *offset += 1; - Ok(*v) -} -fn read_u32(input: &[u8], offset: &mut usize) -> Result { - let Some(raw) = input.get(*offset..(*offset + 4)) else { - return exec_err!("invalid variant_schema state: missing u32"); - }; - *offset += 4; - Ok(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]])) -} + pub fn from_state_bytes(bytes: &[u8]) -> Result { + let state = match std::str::from_utf8(bytes) { + Ok(v) => v, + Err(e) => return exec_err!("invalid variant_schema utf8 state: {e}"), + }; + Self::from_state_str(state) + } -fn read_len_prefixed_bytes<'a>(input: &'a [u8], offset: &mut usize) -> Result<&'a [u8]> { - let len = read_u32(input, offset)? as usize; - let Some(raw) = input.get(*offset..(*offset + len)) else { - return exec_err!("invalid variant_schema state: truncated payload"); - }; - *offset += len; - Ok(raw) + pub fn from_state_str(state: &str) -> Result { + let value = match serde_json::from_str::(state) { + Ok(v) => v, + Err(e) => return exec_err!("invalid variant_schema json state: {e}"), + }; + schema_from_json(&value) + } } -fn encode_variant_schema(schema: &VariantSchema, out: &mut Vec) { +fn schema_to_json(schema: &VariantSchema) -> Value { match schema { VariantSchema::Primitive(dtype) => { - out.push(0); - encode_len_prefixed_bytes(out, dtype.to_string().as_bytes()); + let mut node = Map::new(); + node.insert("kind".to_string(), Value::String("primitive".to_string())); + node.insert("dtype".to_string(), Value::String(dtype.to_string())); + Value::Object(node) } VariantSchema::Array(inner) => { - out.push(1); - encode_variant_schema(inner, out); + let mut node = Map::new(); + node.insert("kind".to_string(), Value::String("array".to_string())); + node.insert("inner".to_string(), schema_to_json(inner)); + Value::Object(node) } VariantSchema::Object(fields) => { - out.push(2); - out.extend_from_slice(&(fields.len() as u32).to_le_bytes()); + let mut field_map = Map::new(); for (key, value) in fields { - encode_len_prefixed_bytes(out, key.as_bytes()); - encode_variant_schema(value, out); + field_map.insert(key.clone(), schema_to_json(value)); } + + let mut node = Map::new(); + node.insert("kind".to_string(), Value::String("object".to_string())); + node.insert("fields".to_string(), Value::Object(field_map)); + Value::Object(node) + } + VariantSchema::Variant => { + let mut node = Map::new(); + node.insert("kind".to_string(), Value::String("variant".to_string())); + Value::Object(node) } - VariantSchema::Variant => out.push(3), } } -fn decode_variant_schema(input: &[u8], offset: &mut usize) -> Result { - match read_u8(input, offset)? { - 0 => { - let raw = read_len_prefixed_bytes(input, offset)?; - let dtype_str = match std::str::from_utf8(raw) { - Ok(v) => v, - Err(e) => return exec_err!("invalid variant_schema state: {e}"), +fn schema_from_json(value: &Value) -> Result { + let obj = match value { + Value::Object(obj) => obj, + _ => return exec_err!("invalid variant_schema state: expected object"), + }; + + let kind = match obj.get("kind") { + Some(Value::String(v)) => v.as_str(), + _ => return exec_err!("invalid variant_schema state: missing or invalid `kind`"), + }; + + match kind { + "primitive" => { + let dtype_str = match obj.get("dtype") { + Some(Value::String(v)) => v, + _ => { + return exec_err!( + "invalid variant_schema primitive state: missing or invalid `dtype`" + ); + } }; + let dtype = match dtype_str.parse::() { Ok(v) => v, Err(e) => return exec_err!("invalid variant_schema datatype state: {e}"), }; Ok(VariantSchema::Primitive(dtype)) } - 1 => Ok(VariantSchema::Array(Box::new(decode_variant_schema( - input, offset, - )?))), - 2 => { - let count = read_u32(input, offset)? as usize; + "array" => { + let inner = match obj.get("inner") { + Some(v) => v, + None => return exec_err!("invalid variant_schema array state: missing `inner`"), + }; + Ok(VariantSchema::Array(Box::new(schema_from_json(inner)?))) + } + "object" => { + let fields_obj = match obj.get("fields") { + Some(Value::Object(v)) => v, + _ => { + return exec_err!( + "invalid variant_schema object state: missing or invalid `fields`" + ); + } + }; + let mut fields = BTreeMap::new(); - for _ in 0..count { - let key_raw = read_len_prefixed_bytes(input, offset)?; - let key = match std::str::from_utf8(key_raw) { - Ok(v) => v.to_string(), - Err(e) => return exec_err!("invalid variant_schema field key: {e}"), - }; - let value = decode_variant_schema(input, offset)?; - fields.insert(key, value); + for (field_name, field_value) in fields_obj { + fields.insert(field_name.clone(), schema_from_json(field_value)?); } + Ok(VariantSchema::Object(fields)) } - 3 => Ok(VariantSchema::Variant), - tag => exec_err!("invalid variant_schema state tag: {tag}"), + "variant" => Ok(VariantSchema::Variant), + other => exec_err!("invalid variant_schema state kind: {other}"), } } @@ -408,3 +420,30 @@ impl ScalarUDFImpl for VariantSchemaUDF { infer_variant_schema(arg) } } + +#[cfg(test)] +mod tests { + use super::{VariantSchema, print_schema}; + use arrow_schema::DataType; + use std::collections::BTreeMap; + + #[test] + fn state_round_trip_uses_utf8_json() { + let mut fields = BTreeMap::new(); + fields.insert( + "a:key,with>delims".to_string(), + VariantSchema::Array(Box::new(VariantSchema::Primitive(DataType::Int64))), + ); + + let schema = VariantSchema::Object(fields); + let bytes = schema.to_state_bytes(); + + let text = std::str::from_utf8(&bytes).expect("state should be utf8"); + assert!(text.contains("\"kind\":\"object\"")); + assert!(text.contains("\"fields\"")); + + let decoded = VariantSchema::from_state_bytes(&bytes).expect("round-trip decode"); + assert_eq!(decoded, schema); + assert_eq!(print_schema(&decoded), print_schema(&schema)); + } +} From c3dc5e4a30511254322ef8be4031806ea686d1af Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Tue, 3 Mar 2026 15:25:05 -0500 Subject: [PATCH 28/32] first round feedback changes --- Cargo.toml | 1 + src/variant_schema.rs | 223 +++++++++++++++++++++--------------------- 2 files changed, 111 insertions(+), 113 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 312b281..2a51c5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ arrow-schema = "57" parquet-variant-compute = "57" parquet-variant-json = "57" parquet-variant = "57" +serde_json = "1.0.145" [dev-dependencies] anyhow = "1.0.100" diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 0dd75d1..68f4c10 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -8,7 +8,7 @@ use datafusion::{ }; use parquet_variant::Variant; use parquet_variant_compute::VariantArray; -use serde_json::{Map, Value}; +use serde_json::Value; use std::collections::BTreeMap; use std::collections::btree_map::Entry; use std::sync::Arc; @@ -62,140 +62,137 @@ impl VariantSchema { } pub fn from_state_bytes(bytes: &[u8]) -> Result { - let state = match std::str::from_utf8(bytes) { - Ok(v) => v, - Err(e) => return exec_err!("invalid variant_schema utf8 state: {e}"), - }; + let state = std::str::from_utf8(bytes).map_err(|e| { + DataFusionError::Execution(format!("invalid variant_schema utf8 state: {e}")) + })?; Self::from_state_str(state) } pub fn from_state_str(state: &str) -> Result { - let value = match serde_json::from_str::(state) { - Ok(v) => v, - Err(e) => return exec_err!("invalid variant_schema json state: {e}"), - }; - schema_from_json(&value) + let value = serde_json::from_str::(state).map_err(|e| { + DataFusionError::Execution(format!("invalid variant_schema json state: {e}")) + })?; + Self::try_from(&value) } } -fn schema_to_json(schema: &VariantSchema) -> Value { - match schema { - VariantSchema::Primitive(dtype) => { - let mut node = Map::new(); - node.insert("kind".to_string(), Value::String("primitive".to_string())); - node.insert("dtype".to_string(), Value::String(dtype.to_string())); - Value::Object(node) - } - VariantSchema::Array(inner) => { - let mut node = Map::new(); - node.insert("kind".to_string(), Value::String("array".to_string())); - node.insert("inner".to_string(), schema_to_json(inner)); - Value::Object(node) - } - VariantSchema::Object(fields) => { - let mut field_map = Map::new(); - for (key, value) in fields { - field_map.insert(key.clone(), schema_to_json(value)); +impl From<&Variant<'_, '_>> for VariantSchema { + fn from(value: &Variant) -> Self { + match value { + Variant::Object(obj) => { + let fields = obj + .iter() + .map(|(k, v)| (k.to_string(), Self::from(&v))) + .collect(); + + VariantSchema::Object(fields) } + Variant::List(list) => { + let inner = list + .iter() + .map(|v| Self::from(&v)) + .try_fold(VariantSchema::Primitive(DataType::Null), |acc, next| { + let merged = merge_variant_schema(acc, next); + if merged == VariantSchema::Variant { + Err(merged) + } else { + Ok(merged) + } + }) + .unwrap_or_else(|schema| schema); - let mut node = Map::new(); - node.insert("kind".to_string(), Value::String("object".to_string())); - node.insert("fields".to_string(), Value::Object(field_map)); - Value::Object(node) - } - VariantSchema::Variant => { - let mut node = Map::new(); - node.insert("kind".to_string(), Value::String("variant".to_string())); - Value::Object(node) + VariantSchema::Array(Box::new(inner)) + } + _ => VariantSchema::Primitive(primitive_from_variant(value)), } } } -fn schema_from_json(value: &Value) -> Result { - let obj = match value { - Value::Object(obj) => obj, - _ => return exec_err!("invalid variant_schema state: expected object"), - }; - - let kind = match obj.get("kind") { - Some(Value::String(v)) => v.as_str(), - _ => return exec_err!("invalid variant_schema state: missing or invalid `kind`"), - }; - - match kind { - "primitive" => { - let dtype_str = match obj.get("dtype") { - Some(Value::String(v)) => v, - _ => { - return exec_err!( - "invalid variant_schema primitive state: missing or invalid `dtype`" - ); - } - }; +impl TryFrom<&Value> for VariantSchema { + type Error = DataFusionError; - let dtype = match dtype_str.parse::() { - Ok(v) => v, - Err(e) => return exec_err!("invalid variant_schema datatype state: {e}"), - }; - Ok(VariantSchema::Primitive(dtype)) - } - "array" => { - let inner = match obj.get("inner") { - Some(v) => v, - None => return exec_err!("invalid variant_schema array state: missing `inner`"), - }; - Ok(VariantSchema::Array(Box::new(schema_from_json(inner)?))) - } - "object" => { - let fields_obj = match obj.get("fields") { - Some(Value::Object(v)) => v, - _ => { - return exec_err!( - "invalid variant_schema object state: missing or invalid `fields`" - ); - } - }; + fn try_from(value: &Value) -> std::result::Result { + let obj = match value { + Value::Object(obj) => obj, + _ => return exec_err!("invalid variant_schema state: expected object"), + }; + + let kind = match obj.get("kind") { + Some(Value::String(v)) => v.as_str(), + _ => return exec_err!("invalid variant_schema state: missing or invalid `kind`"), + }; + + match kind { + "primitive" => { + let dtype_str = match obj.get("dtype") { + Some(Value::String(v)) => v, + _ => { + return exec_err!( + "invalid variant_schema primitive state: missing or invalid `dtype`" + ); + } + }; - let mut fields = BTreeMap::new(); - for (field_name, field_value) in fields_obj { - fields.insert(field_name.clone(), schema_from_json(field_value)?); + let dtype = match dtype_str.parse::() { + Ok(v) => v, + Err(e) => return exec_err!("invalid variant_schema datatype state: {e}"), + }; + Ok(VariantSchema::Primitive(dtype)) } + "array" => { + let inner = match obj.get("inner") { + Some(v) => v, + None => { + return exec_err!("invalid variant_schema array state: missing `inner`"); + } + }; + Ok(VariantSchema::Array(Box::new(Self::try_from(inner)?))) + } + "object" => { + let fields_obj = match obj.get("fields") { + Some(Value::Object(v)) => v, + _ => { + return exec_err!( + "invalid variant_schema object state: missing or invalid `fields`" + ); + } + }; + + let mut fields = BTreeMap::new(); + for (field_name, field_value) in fields_obj { + fields.insert(field_name.clone(), Self::try_from(field_value)?); + } - Ok(VariantSchema::Object(fields)) + Ok(VariantSchema::Object(fields)) + } + "variant" => Ok(VariantSchema::Variant), + other => exec_err!("invalid variant_schema state kind: {other}"), } - "variant" => Ok(VariantSchema::Variant), - other => exec_err!("invalid variant_schema state kind: {other}"), } } -/// This function extracts the schema from a single Variant scalar -pub fn schema_from_variant(v: &Variant) -> VariantSchema { - match v { - Variant::Object(obj) => { - let fields = obj - .iter() - .map(|(k, v)| (k.to_string(), schema_from_variant(&v))) - .collect(); - - VariantSchema::Object(fields) - } - Variant::List(list) => { - let inner = list +fn schema_to_json(schema: &VariantSchema) -> Value { + match schema { + VariantSchema::Primitive(dtype) => serde_json::json!({ + "kind": "primitive", + "dtype": dtype.to_string() + }), + VariantSchema::Array(inner) => serde_json::json!({ + "kind": "array", + "inner": schema_to_json(inner) + }), + VariantSchema::Object(fields) => { + let fields_json = fields .iter() - .map(|v| schema_from_variant(&v)) - .try_fold(VariantSchema::Primitive(DataType::Null), |acc, next| { - let merged = merge_variant_schema(acc, next); - if merged == VariantSchema::Variant { - Err(merged) - } else { - Ok(merged) - } - }) - .unwrap_or_else(|schema| schema); + .map(|(k, v)| (k.clone(), schema_to_json(v))) + .collect::>(); - VariantSchema::Array(Box::new(inner)) + serde_json::json!({ + "kind": "object", + "fields": fields_json + }) } - _ => VariantSchema::Primitive(primitive_from_variant(v)), + VariantSchema::Variant => serde_json::json!({ "kind": "variant" }), } } @@ -374,7 +371,7 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { let variant_array = VariantArray::try_new(struct_array.as_ref())?; let v = variant_array.value(0); - let schema = schema_from_variant(&v); + let schema = VariantSchema::from(&v); Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( print_schema(&schema), @@ -384,7 +381,7 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { let variant_array = VariantArray::try_new(array.as_ref())?; let out = variant_array .iter() - .map(|v| v.map(|v| print_schema(&schema_from_variant(&v)))) + .map(|v| v.map(|v| print_schema(&VariantSchema::from(&v)))) .collect::>(); let out: StringViewArray = out.into(); From c782d4c9efecbf55d4d6a5b12fa27bed1dc56ecd Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Tue, 3 Mar 2026 15:40:01 -0500 Subject: [PATCH 29/32] second round of addressing review --- src/variant_schema.rs | 67 ++++++++++++----------------- tests/test_files/variant_schema.slt | 6 --- 2 files changed, 27 insertions(+), 46 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 68f4c10..65498b3 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -26,6 +26,32 @@ impl Default for VariantSchemaUDF { } } +impl ScalarUDFImpl for VariantSchemaUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "variant_schema" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8View) + } + + fn invoke_with_args( + &self, + args: datafusion::logical_expr::ScalarFunctionArgs, + ) -> Result { + let arg = &args.args[0]; + infer_variant_schema(arg) + } +} + /// Infers a schema description for one VARIANT value. /// /// The inferred schema can be one of four logical forms: @@ -254,21 +280,10 @@ fn primitive_from_variant<'m, 'v>(v: &Variant<'m, 'v>) -> DataType { /// are different /// /// Todo: needs more work on type coercing -/// - add decimal coercion rules +/// docs.databricks.com/aws/en/sql/language-manual/sql-ref-datatype-rules#type-precedence-list fn merge_primitives(a: DataType, b: DataType) -> Option { - use DataType::*; - match (a, b) { (x, y) if x == y => Some(x), - // numeric widening - // docs.databricks.com/aws/en/sql/language-manual/sql-ref-datatype-rules#type-precedence-list - // For least common type resolution FLOAT is skipped to avoid loss of precision. - (Int8 | Int16 | Int32 | Int64 | Float32, Float64) - | (Float64, Int8 | Int16 | Int32 | Int64 | Float32) => Some(Float64), - (Int8 | Int16 | Int32, Int64) | (Int64, Int8 | Int16 | Int32) => Some(Int64), - (Int8 | Int16, Int32) | (Int32, Int8 | Int16) => Some(Int32), - (Date32, Timestamp(tu, tz)) | (Timestamp(tu, tz), Date32) => Some(Timestamp(tu, tz)), - _ => None, } } @@ -390,34 +405,6 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { } } -impl ScalarUDFImpl for VariantSchemaUDF { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "variant_schema" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8View) - } - - fn invoke_with_args( - &self, - args: datafusion::logical_expr::ScalarFunctionArgs, - ) -> Result { - let arg = args.args.first().ok_or_else(|| { - DataFusionError::Execution("empty argument, expected 1 argument".to_string()) - })?; - infer_variant_schema(arg) - } -} - #[cfg(test)] mod tests { use super::{VariantSchema, print_schema}; diff --git a/tests/test_files/variant_schema.slt b/tests/test_files/variant_schema.slt index 4eeb064..5659bd8 100644 --- a/tests/test_files/variant_schema.slt +++ b/tests/test_files/variant_schema.slt @@ -86,12 +86,6 @@ SELECT variant_schema(json_to_variant('{"a": null}')) ---- OBJECT -# numeric widening -query T -SELECT variant_schema(json_to_variant('[1, 2.5, 3]')) ----- -ARRAY - # array of objects query T SELECT variant_schema(json_to_variant('[{"a":1},{"a":2}]')) From afc807005fdbc932f265e99c72bacfbc816d7629 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Tue, 3 Mar 2026 19:05:10 -0500 Subject: [PATCH 30/32] remove redundant function --- src/variant_schema.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index 65498b3..ee973fd 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -351,10 +351,6 @@ pub fn merge_variant_schema_from(target: &mut VariantSchema, incoming: &VariantS } } -pub fn merge_variant_schema_into(target: &mut VariantSchema, incoming: VariantSchema) { - merge_variant_schema_from(target, &incoming); -} - /// Prints schema in a presentable manner pub fn print_schema(schema: &VariantSchema) -> String { match schema { From c012a467675b083848342bc89c5d25d521f27997 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Tue, 3 Mar 2026 19:37:05 -0500 Subject: [PATCH 31/32] impl std::fmt::Display --- src/variant_schema.rs | 98 ++++++++++++++++++++++++------------------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index ee973fd..c32e414 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -78,27 +78,38 @@ pub enum VariantSchema { Variant, } -impl VariantSchema { - pub fn to_state_bytes(&self) -> Vec { - self.to_state_string().into_bytes() +impl std::fmt::Display for VariantSchema { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt_schema(self, f) } +} - pub fn to_state_string(&self) -> String { - schema_to_json(self).to_string() - } +/// Prints schema in a presentable manner +pub fn fmt_schema(schema: &VariantSchema, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match schema { + VariantSchema::Primitive(dt) => write!(f, "{dt}"), - pub fn from_state_bytes(bytes: &[u8]) -> Result { - let state = std::str::from_utf8(bytes).map_err(|e| { - DataFusionError::Execution(format!("invalid variant_schema utf8 state: {e}")) - })?; - Self::from_state_str(state) - } + VariantSchema::Variant => f.write_str("VARIANT"), - pub fn from_state_str(state: &str) -> Result { - let value = serde_json::from_str::(state).map_err(|e| { - DataFusionError::Execution(format!("invalid variant_schema json state: {e}")) - })?; - Self::try_from(&value) + VariantSchema::Array(inner) => { + f.write_str("ARRAY<")?; + fmt_schema(inner, f)?; + f.write_str(">") + } + + VariantSchema::Object(fields) => { + f.write_str("OBJECT<")?; + let mut first = true; + for (k, v) in fields { + if !first { + f.write_str(", ")?; + } + first = false; + write!(f, "{k}: ")?; + fmt_schema(v, f)?; + } + f.write_str(">") + } } } @@ -197,6 +208,30 @@ impl TryFrom<&Value> for VariantSchema { } } +impl VariantSchema { + pub fn to_state_bytes(&self) -> Vec { + self.to_state_string().into_bytes() + } + + pub fn to_state_string(&self) -> String { + schema_to_json(self).to_string() + } + + pub fn from_state_bytes(bytes: &[u8]) -> Result { + let state = std::str::from_utf8(bytes).map_err(|e| { + DataFusionError::Execution(format!("invalid variant_schema utf8 state: {e}")) + })?; + Self::from_state_str(state) + } + + pub fn from_state_str(state: &str) -> Result { + let value = serde_json::from_str::(state).map_err(|e| { + DataFusionError::Execution(format!("invalid variant_schema json state: {e}")) + })?; + Self::try_from(&value) + } +} + fn schema_to_json(schema: &VariantSchema) -> Value { match schema { VariantSchema::Primitive(dtype) => serde_json::json!({ @@ -351,27 +386,6 @@ pub fn merge_variant_schema_from(target: &mut VariantSchema, incoming: &VariantS } } -/// Prints schema in a presentable manner -pub fn print_schema(schema: &VariantSchema) -> String { - match schema { - VariantSchema::Primitive(s) => format!("{s}"), - - VariantSchema::Variant => "VARIANT".to_string(), - - VariantSchema::Array(inner) => { - format!("ARRAY<{}>", print_schema(inner)) - } - - VariantSchema::Object(fields) => { - let parts: Vec = fields - .iter() - .map(|(k, v)| format!("{k}: {}", print_schema(v))) - .collect(); - format!("OBJECT<{}>", parts.join(", ")) - } - } -} - /// Retrieve schema text from a VARIANT scalar or array (row-wise for arrays). fn infer_variant_schema(variant: &ColumnarValue) -> Result { match variant { @@ -385,14 +399,14 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { let schema = VariantSchema::from(&v); Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - print_schema(&schema), + schema.to_string(), )))) } ColumnarValue::Array(array) => { let variant_array = VariantArray::try_new(array.as_ref())?; let out = variant_array .iter() - .map(|v| v.map(|v| print_schema(&VariantSchema::from(&v)))) + .map(|v| v.map(|v| VariantSchema::from(&v).to_string())) .collect::>(); let out: StringViewArray = out.into(); @@ -403,7 +417,7 @@ fn infer_variant_schema(variant: &ColumnarValue) -> Result { #[cfg(test)] mod tests { - use super::{VariantSchema, print_schema}; + use super::VariantSchema; use arrow_schema::DataType; use std::collections::BTreeMap; @@ -424,6 +438,6 @@ mod tests { let decoded = VariantSchema::from_state_bytes(&bytes).expect("round-trip decode"); assert_eq!(decoded, schema); - assert_eq!(print_schema(&decoded), print_schema(&schema)); + assert_eq!(decoded.to_string(), schema.to_string()); } } From c555d302fcdf69d0f448fe3df23fafdd77eb8274 Mon Sep 17 00:00:00 2001 From: sdf-jkl Date: Tue, 3 Mar 2026 20:27:22 -0500 Subject: [PATCH 32/32] move merge functions to methods --- src/variant_schema.rs | 135 +++++++++++++++++++++--------------------- 1 file changed, 69 insertions(+), 66 deletions(-) diff --git a/src/variant_schema.rs b/src/variant_schema.rs index c32e414..b25b1c2 100644 --- a/src/variant_schema.rs +++ b/src/variant_schema.rs @@ -78,6 +78,12 @@ pub enum VariantSchema { Variant, } +impl Default for VariantSchema { + fn default() -> Self { + Self::Primitive(DataType::Null) + } +} + impl std::fmt::Display for VariantSchema { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fmt_schema(self, f) @@ -85,7 +91,7 @@ impl std::fmt::Display for VariantSchema { } /// Prints schema in a presentable manner -pub fn fmt_schema(schema: &VariantSchema, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +fn fmt_schema(schema: &VariantSchema, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match schema { VariantSchema::Primitive(dt) => write!(f, "{dt}"), @@ -128,8 +134,8 @@ impl From<&Variant<'_, '_>> for VariantSchema { let inner = list .iter() .map(|v| Self::from(&v)) - .try_fold(VariantSchema::Primitive(DataType::Null), |acc, next| { - let merged = merge_variant_schema(acc, next); + .try_fold(VariantSchema::default(), |acc, next| { + let merged = acc.merged(next); if merged == VariantSchema::Variant { Err(merged) } else { @@ -230,6 +236,66 @@ impl VariantSchema { })?; Self::try_from(&value) } + + pub fn merged(mut self, incoming: Self) -> Self { + self.merge_from(&incoming); + self + } + + pub fn merge_from(&mut self, incoming: &Self) { + use VariantSchema::*; + + if matches!(self, Variant) || matches!(incoming, Variant) { + *self = Variant; + return; + } + + if matches!(incoming, Primitive(DataType::Null)) { + return; + } + + if matches!(self, Primitive(DataType::Null)) { + *self = incoming.clone(); + return; + } + + match incoming { + Primitive(p2) => { + if let Primitive(p1) = self { + let merged = merge_primitives(p1.clone(), p2.clone()) + .map(Primitive) + .unwrap_or(Variant); + *self = merged; + } else { + *self = Variant; + } + } + Array(b) => { + if let Array(a) = self { + a.as_mut().merge_from(b.as_ref()); + } else { + *self = Variant; + } + } + Object(b) => { + if let Object(a) = self { + for (k, v_b) in b { + match a.entry(k.clone()) { + Entry::Occupied(mut occ) => occ.get_mut().merge_from(v_b), + Entry::Vacant(vac) => { + vac.insert(v_b.clone()); + } + } + } + } else { + *self = Variant; + } + } + Variant => { + *self = Variant; + } + } + } } fn schema_to_json(schema: &VariantSchema) -> Value { @@ -323,69 +389,6 @@ fn merge_primitives(a: DataType, b: DataType) -> Option { } } -/// Merges two inferred Variant schemas into a common schema. -/// Returns VARIANT if no common schema can be determined. -pub fn merge_variant_schema(a: VariantSchema, b: VariantSchema) -> VariantSchema { - let mut merged = a; - merge_variant_schema_from(&mut merged, &b); - merged -} - -pub fn merge_variant_schema_from(target: &mut VariantSchema, incoming: &VariantSchema) { - use VariantSchema::*; - - if matches!(target, Variant) || matches!(incoming, Variant) { - *target = Variant; - return; - } - - if matches!(incoming, Primitive(DataType::Null)) { - return; - } - - if matches!(target, Primitive(DataType::Null)) { - *target = incoming.clone(); - return; - } - - match incoming { - Primitive(p2) => { - if let Primitive(p1) = target { - let merged = merge_primitives(p1.clone(), p2.clone()) - .map(Primitive) - .unwrap_or(Variant); - *target = merged; - } else { - *target = Variant; - } - } - Array(b) => { - if let Array(a) = target { - merge_variant_schema_from(a.as_mut(), b.as_ref()); - } else { - *target = Variant; - } - } - Object(b) => { - if let Object(a) = target { - for (k, v_b) in b { - match a.entry(k.clone()) { - Entry::Occupied(mut occ) => merge_variant_schema_from(occ.get_mut(), v_b), - Entry::Vacant(vac) => { - vac.insert(v_b.clone()); - } - } - } - } else { - *target = Variant; - } - } - Variant => { - *target = Variant; - } - } -} - /// Retrieve schema text from a VARIANT scalar or array (row-wise for arrays). fn infer_variant_schema(variant: &ColumnarValue) -> Result { match variant {