diff --git a/src/lib.rs b/src/lib.rs index 41e5595..db82652 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ mod cast_to_variant; mod is_variant_null; mod json_to_variant; mod variant_get; +mod variant_get_int; mod variant_get_str; mod variant_list_construct; mod variant_list_delete; @@ -21,6 +22,7 @@ pub use cast_to_variant::*; pub use is_variant_null::*; pub use json_to_variant::*; pub use variant_get::*; +pub use variant_get_int::*; pub use variant_get_str::*; pub use variant_list_construct::*; pub use variant_list_delete::*; diff --git a/src/shared.rs b/src/shared.rs index 8ebfb43..21858fc 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -1,11 +1,17 @@ use std::sync::Arc; -use arrow::array::{Array, cast::AsArray}; +#[cfg(test)] +use arrow::array::StructArray; +use arrow::array::{Array, ArrayRef, cast::AsArray}; +#[cfg(test)] +use arrow_schema::Fields; use arrow_schema::extension::ExtensionType; use arrow_schema::{DataType, Field}; use datafusion::common::exec_datafusion_err; use datafusion::error::Result; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion::{common::exec_err, scalar::ScalarValue}; +use parquet_variant::{Variant, VariantPath}; use parquet_variant_compute::{VariantArray, VariantType}; #[cfg(test)] @@ -118,6 +124,129 @@ pub fn try_parse_string_columnar(array: &Arc) -> Result( + variant_array: &VariantArray, + index: usize, + path: &str, + extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result>, +) -> Result> { + let Some(variant) = variant_array.iter().nth(index).flatten() else { + return Ok(None); + }; + + let variant_path = VariantPath::from(path); + let Some(value) = variant.get_path(&variant_path) else { + return Ok(None); + }; + + extract(value) +} + +pub fn variant_get_array_values( + variant_array: &VariantArray, + path: &str, + extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result>, +) -> Result>> { + let variant_path = VariantPath::from(path); + + variant_array + .iter() + .map(|maybe_variant| { + let Some(variant) = maybe_variant else { + return Ok(None); + }; + + let Some(value) = variant.get_path(&variant_path) else { + return Ok(None); + }; + + extract(value) + }) + .collect() +} + +pub fn invoke_variant_get_typed( + args: ScalarFunctionArgs, + scalar_from_option: fn(Option) -> ScalarValue, + array_from_values: fn(Vec>) -> ArrayRef, + extract: for<'m, 'v> fn(Variant<'m, 'v>) -> Result>, +) -> Result { + let (variant_arg, path_arg) = match args.args.as_slice() { + [variant_arg, path_arg] => (variant_arg, path_arg), + _ => return exec_err!("expected 2 arguments"), + }; + + let variant_field = args + .arg_fields + .first() + .ok_or_else(|| exec_datafusion_err!("expected argument field"))?; + + try_field_as_variant_array(variant_field.as_ref())?; + + let out = match (variant_arg, path_arg) { + (ColumnarValue::Array(variant_array), ColumnarValue::Scalar(path_scalar)) => { + let path = try_parse_string_scalar(path_scalar)? + .map(|s| s.as_str()) + .unwrap_or_default(); + + let variant_array = VariantArray::try_new(variant_array.as_ref())?; + let values = variant_get_array_values(&variant_array, path, extract)?; + ColumnarValue::Array(array_from_values(values)) + } + (ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(path_scalar)) => { + let ScalarValue::Struct(variant_array) = scalar_variant else { + return exec_err!("expected struct array"); + }; + + let path = try_parse_string_scalar(path_scalar)? + .map(|s| s.as_str()) + .unwrap_or_default(); + + let variant_array = VariantArray::try_new(variant_array.as_ref())?; + let value = variant_get_single_value(&variant_array, 0, path, extract)?; + + ColumnarValue::Scalar(scalar_from_option(value)) + } + (ColumnarValue::Array(variant_array), ColumnarValue::Array(paths)) => { + if variant_array.len() != paths.len() { + return exec_err!("expected variant array and paths to be of same length"); + } + + let paths = try_parse_string_columnar(paths)?; + let variant_array = VariantArray::try_new(variant_array.as_ref())?; + + let values: Vec> = (0..variant_array.len()) + .map(|i| { + let path = paths[i].unwrap_or_default(); + variant_get_single_value(&variant_array, i, path, extract) + }) + .collect::>()?; + + ColumnarValue::Array(array_from_values(values)) + } + (ColumnarValue::Scalar(scalar_variant), ColumnarValue::Array(paths)) => { + let ScalarValue::Struct(variant_array) = scalar_variant else { + return exec_err!("expected struct array"); + }; + + let variant_array = VariantArray::try_new(variant_array.as_ref())?; + let paths = try_parse_string_columnar(paths)?; + + let values: Vec> = paths + .iter() + .map(|path| { + let path = path.unwrap_or_default(); + variant_get_single_value(&variant_array, 0, path, extract) + }) + .collect::>()?; + + ColumnarValue::Array(array_from_values(values)) + } + }; + + Ok(out) +} + /// This is similar to anyhow's ensure! macro /// If the `pred` fails, it will return a DataFusionError pub fn ensure(pred: bool, err_msg: &str) -> Result<()> { @@ -139,6 +268,50 @@ pub fn build_variant_array_from_json(value: &serde_json::Value) -> VariantArray builder.build() } +#[cfg(test)] +pub fn variant_scalar_from_json(json: serde_json::Value) -> ScalarValue { + let mut builder = VariantArrayBuilder::new(1); + builder.append_json(json.to_string().as_str()).unwrap(); + ScalarValue::Struct(Arc::new(builder.build().into())) +} + +#[cfg(test)] +pub fn variant_array_from_json_rows(json_rows: &[serde_json::Value]) -> ArrayRef { + let mut builder = VariantArrayBuilder::new(json_rows.len()); + for value in json_rows { + builder.append_json(value.to_string().as_str()).unwrap(); + } + let variant_array: StructArray = builder.build().into(); + Arc::new(variant_array) as ArrayRef +} + +#[cfg(test)] +pub fn standard_variant_get_arg_fields() -> Vec> { + vec![ + Arc::new( + Field::new("input", DataType::Struct(Fields::empty()), true) + .with_extension_type(VariantType), + ), + Arc::new(Field::new("path", DataType::Utf8, true)), + ] +} + +#[cfg(test)] +pub fn build_variant_get_args( + variant_input: ColumnarValue, + path: ColumnarValue, + return_data_type: DataType, + arg_fields: Vec>, +) -> ScalarFunctionArgs { + ScalarFunctionArgs { + args: vec![variant_input, path], + return_field: Arc::new(Field::new("result", return_data_type, true)), + arg_fields, + number_rows: Default::default(), + config_options: Default::default(), + } +} + #[cfg(test)] #[allow(unused)] pub fn build_variant_array_from_json_array(jsons: &[Option]) -> VariantArray { diff --git a/src/variant_get.rs b/src/variant_get.rs index ebfa1a5..dd5e728 100644 --- a/src/variant_get.rs +++ b/src/variant_get.rs @@ -316,38 +316,17 @@ impl ScalarUDFImpl for VariantGetFieldUdf { #[cfg(test)] mod tests { + use super::*; + use crate::shared::{ + standard_variant_get_arg_fields, variant_array_from_json_rows, variant_scalar_from_json, + }; use arrow::array::{Array, BinaryViewArray, Int64Array}; - use arrow_schema::{Field, Fields}; + use arrow_schema::Field; use datafusion::logical_expr::{ReturnFieldArgs, ScalarFunctionArgs}; use parquet_variant::Variant; - use parquet_variant_compute::{VariantArrayBuilder, VariantType}; - use parquet_variant_json::JsonToVariant; - - use super::*; - - fn variant_scalar_from_json(json: serde_json::Value) -> ScalarValue { - let mut builder = VariantArrayBuilder::new(1); - builder.append_json(json.to_string().as_str()).unwrap(); - ScalarValue::Struct(Arc::new(builder.build().into())) - } - - fn variant_array_from_json_rows(json_rows: &[serde_json::Value]) -> ArrayRef { - let mut builder = VariantArrayBuilder::new(json_rows.len()); - for value in json_rows { - builder.append_json(value.to_string().as_str()).unwrap(); - } - let variant_array: StructArray = builder.build().into(); - Arc::new(variant_array) as ArrayRef - } fn standard_arg_fields(with_type_hint: bool) -> Vec { - let mut fields = vec![ - Arc::new( - Field::new("input", DataType::Struct(Fields::empty()), true) - .with_extension_type(VariantType), - ), - Arc::new(Field::new("path", DataType::Utf8, true)), - ]; + let mut fields = standard_variant_get_arg_fields(); if with_type_hint { fields.push(Arc::new(Field::new("type", DataType::Utf8, true))); } diff --git a/src/variant_get_int.rs b/src/variant_get_int.rs new file mode 100644 index 0000000..903c1e2 --- /dev/null +++ b/src/variant_get_int.rs @@ -0,0 +1,258 @@ +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array}; +use arrow_schema::DataType; +use datafusion::{ + error::Result, + logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, + }, + scalar::ScalarValue, +}; +use parquet_variant::Variant; + +use crate::shared::invoke_variant_get_typed; + +/// Extracts an integer value from a Variant by path. +/// +/// `variant_get_int(variant, path)` returns the value at `path` as an `INT64`. +/// - Integer values are returned as-is (with widening when needed) +/// - Non-integer values return NULL +/// - Returns NULL if the path does not exist +#[derive(Debug, Hash, PartialEq, Eq)] +pub struct VariantGetIntUdf { + signature: Signature, +} + +impl Default for VariantGetIntUdf { + fn default() -> Self { + Self { + signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable), + } + } +} + +fn scalar_from_int(value: Option) -> ScalarValue { + ScalarValue::Int64(value) +} + +fn int_array_from_values(values: Vec>) -> ArrayRef { + Arc::new(values.into_iter().collect::()) +} + +fn extract_int(value: Variant<'_, '_>) -> Result> { + Ok(value.as_int64()) +} + +impl ScalarUDFImpl for VariantGetIntUdf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "variant_get_int" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + invoke_variant_get_typed(args, scalar_from_int, int_array_from_values, extract_int) + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, ArrayRef, Int64Array, StringViewArray}; + use datafusion::logical_expr::ColumnarValue; + use datafusion::scalar::ScalarValue; + + use crate::shared::{ + build_variant_get_args, standard_variant_get_arg_fields, variant_array_from_json_rows, + variant_scalar_from_json, + }; + + use super::*; + + #[test] + fn test_scalar_integer_value() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "name": "norm", + "age": 50 + })); + + let udf = VariantGetIntUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("age".to_string()))), + DataType::Int64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) = result else { + panic!("expected Int64 scalar"); + }; + + assert_eq!(v, 50); + } + + #[test] + fn test_scalar_non_integer_value_returns_null() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "name": "norm", + "age": 50.5 + })); + + let udf = VariantGetIntUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("name".to_string()))), + DataType::Int64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Int64(None)) = result else { + panic!("expected NULL Int64 scalar"); + }; + } + + #[test] + fn test_scalar_float_value_returns_null() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "price": 10.5 + })); + + let udf = VariantGetIntUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("price".to_string()))), + DataType::Int64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Int64(None)) = result else { + panic!("expected NULL Int64 scalar"); + }; + } + + #[test] + fn test_scalar_missing_path() { + let variant_input = variant_scalar_from_json(serde_json::json!({"name": "norm"})); + + let udf = VariantGetIntUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("missing".to_string()))), + DataType::Int64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Int64(None)) = result else { + panic!("expected NULL Int64 scalar"); + }; + } + + #[test] + fn test_array_variant_scalar_path() { + let json_rows = vec![ + serde_json::json!({"name": "alice", "age": 30}), + serde_json::json!({"name": "bob", "age": 40}), + serde_json::json!({"name": "charlie"}), + ]; + + let variant_array = variant_array_from_json_rows(&json_rows); + + let udf = VariantGetIntUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Array(variant_array), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("age".to_string()))), + DataType::Int64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Array(arr) = result else { + panic!("expected array output"); + }; + + let int_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(int_arr.len(), 3); + assert_eq!(int_arr.value(0), 30); + assert_eq!(int_arr.value(1), 40); + assert!(int_arr.is_null(2)); + } + + #[test] + fn test_array_variant_array_paths() { + let json_rows = vec![ + serde_json::json!({"name": "alice", "age": 30}), + serde_json::json!({"name": "bob", "age": 40}), + ]; + + let variant_array = variant_array_from_json_rows(&json_rows); + let path_array: ArrayRef = Arc::new(StringViewArray::from(vec!["age", "name"])); + + let udf = VariantGetIntUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Array(variant_array), + ColumnarValue::Array(path_array), + DataType::Int64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Array(arr) = result else { + panic!("expected array output"); + }; + + let int_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(int_arr.len(), 2); + assert_eq!(int_arr.value(0), 30); + assert!(int_arr.is_null(1)); + } + + #[test] + fn test_scalar_variant_array_paths() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "name": "alice", + "age": 30 + })); + + let path_array: ArrayRef = Arc::new(StringViewArray::from(vec!["age", "name", "missing"])); + + let udf = VariantGetIntUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Array(path_array), + DataType::Int64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Array(arr) = result else { + panic!("expected array output"); + }; + + let int_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(int_arr.len(), 3); + assert_eq!(int_arr.value(0), 30); + assert!(int_arr.is_null(1)); + assert!(int_arr.is_null(2)); + } +} diff --git a/src/variant_get_str.rs b/src/variant_get_str.rs index a72db84..bdc11d8 100644 --- a/src/variant_get_str.rs +++ b/src/variant_get_str.rs @@ -1,22 +1,18 @@ use std::sync::Arc; -use arrow::array::StringViewArray; +use arrow::array::{ArrayRef, StringViewArray}; use arrow_schema::DataType; use datafusion::{ - common::{exec_datafusion_err, exec_err}, - error::{DataFusionError, Result}, + error::Result, logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }, scalar::ScalarValue, }; -use parquet_variant::VariantPath; -use parquet_variant_compute::VariantArray; +use parquet_variant::Variant; use parquet_variant_json::VariantToJson; -use crate::shared::{ - try_field_as_variant_array, try_parse_string_columnar, try_parse_string_scalar, -}; +use crate::shared::invoke_variant_get_typed; /// Extracts a string value from a Variant by path. /// @@ -39,6 +35,24 @@ impl Default for VariantGetStrUdf { } } +fn scalar_from_string(value: Option) -> ScalarValue { + ScalarValue::Utf8View(value) +} + +fn string_array_from_values(values: Vec>) -> ArrayRef { + let out: StringViewArray = values.into_iter().collect(); + Arc::new(out) +} + +fn extract_string(value: Variant<'_, '_>) -> Result> { + if let Some(s) = value.as_string() { + Ok(Some(s.to_string())) + } else { + // If the path resolves to a non-string variant, return its JSON string. + Ok(Some(value.to_json_string()?)) + } +} + impl ScalarUDFImpl for VariantGetStrUdf { fn as_any(&self) -> &dyn std::any::Any { self @@ -57,180 +71,27 @@ impl ScalarUDFImpl for VariantGetStrUdf { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let (variant_arg, path_arg) = match args.args.as_slice() { - [variant_arg, path_arg] => (variant_arg, path_arg), - _ => return exec_err!("expected 2 arguments"), - }; - - let variant_field = args - .arg_fields - .first() - .ok_or_else(|| exec_datafusion_err!("expected argument field"))?; - - try_field_as_variant_array(variant_field.as_ref())?; - - let out = match (variant_arg, path_arg) { - (ColumnarValue::Array(variant_array), ColumnarValue::Scalar(path_scalar)) => { - let path = try_parse_string_scalar(path_scalar)? - .map(|s| s.as_str()) - .unwrap_or_default(); - - let variant_array = VariantArray::try_new(variant_array.as_ref())?; - let out = variant_array_get_str(&variant_array, path)?; - - ColumnarValue::Array(Arc::new(out)) - } - (ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(path_scalar)) => { - let ScalarValue::Struct(variant_array) = scalar_variant else { - return exec_err!("expected struct array"); - }; - - let path = try_parse_string_scalar(path_scalar)? - .map(|s| s.as_str()) - .unwrap_or_default(); - - let variant_array = VariantArray::try_new(variant_array.as_ref())?; - let result = variant_get_str_single(&variant_array, 0, path)?; - - ColumnarValue::Scalar(ScalarValue::Utf8View(result)) - } - (ColumnarValue::Array(variant_array), ColumnarValue::Array(paths)) => { - if variant_array.len() != paths.len() { - return exec_err!("expected variant array and paths to be of same length"); - } - - let paths = try_parse_string_columnar(paths)?; - let variant_array = VariantArray::try_new(variant_array.as_ref())?; - - let results: Vec> = (0..variant_array.len()) - .map(|i| { - let path = paths[i].unwrap_or_default(); - variant_get_str_single(&variant_array, i, path) - }) - .collect::>()?; - - let out: StringViewArray = results.into_iter().collect(); - ColumnarValue::Array(Arc::new(out)) - } - (ColumnarValue::Scalar(scalar_variant), ColumnarValue::Array(paths)) => { - let ScalarValue::Struct(variant_array) = scalar_variant else { - return exec_err!("expected struct array"); - }; - - let variant_array = VariantArray::try_new(variant_array.as_ref())?; - let paths = try_parse_string_columnar(paths)?; - - let results: Vec> = paths - .iter() - .map(|path| { - let path = path.unwrap_or_default(); - variant_get_str_single(&variant_array, 0, path) - }) - .collect::>()?; - - let out: StringViewArray = results.into_iter().collect(); - ColumnarValue::Array(Arc::new(out)) - } - }; - - Ok(out) - } -} - -fn variant_get_str_single( - variant_array: &VariantArray, - index: usize, - path: &str, -) -> Result> { - let Some(variant) = variant_array.iter().nth(index).flatten() else { - return Ok(None); - }; - - let variant_path = VariantPath::from(path); - let Some(value) = variant.get_path(&variant_path) else { - return Ok(None); - }; - - if let Some(s) = value.as_string() { - Ok(Some(s.to_string())) - } else { - // if the path resolves to a non-string variant, return its JSON string - Ok(Some(value.to_json_string()?)) + invoke_variant_get_typed( + args, + scalar_from_string, + string_array_from_values, + extract_string, + ) } } -fn variant_array_get_str(variant_array: &VariantArray, path: &str) -> Result { - let variant_path = VariantPath::from(path); - - let results: Vec> = variant_array - .iter() - .map(|maybe_variant| { - let Some(variant) = maybe_variant else { - return Ok(None); - }; - - let Some(value) = variant.get_path(&variant_path) else { - return Ok(None); - }; - - if let Some(s) = value.as_string() { - Ok(Some(s.to_string())) - } else { - Ok(Some(value.to_json_string()?)) - } - }) - .collect::>()?; - - Ok(results.into_iter().collect()) -} - #[cfg(test)] mod tests { - use arrow::array::{Array, ArrayRef, StructArray}; - use arrow_schema::{Field, Fields}; - use parquet_variant_compute::{VariantArrayBuilder, VariantType}; - use parquet_variant_json::JsonToVariant; + use arrow::array::{Array, ArrayRef, StringViewArray}; + use datafusion::logical_expr::ColumnarValue; + use datafusion::scalar::ScalarValue; - use super::*; - - fn variant_scalar_from_json(json: serde_json::Value) -> ScalarValue { - let mut builder = VariantArrayBuilder::new(1); - builder.append_json(json.to_string().as_str()).unwrap(); - ScalarValue::Struct(Arc::new(builder.build().into())) - } - - fn variant_array_from_json_rows(json_rows: &[serde_json::Value]) -> ArrayRef { - let mut builder = VariantArrayBuilder::new(json_rows.len()); - for value in json_rows { - builder.append_json(value.to_string().as_str()).unwrap(); - } - let variant_array: StructArray = builder.build().into(); - Arc::new(variant_array) as ArrayRef - } - - fn standard_arg_fields() -> Vec> { - vec![ - Arc::new( - Field::new("input", DataType::Struct(Fields::empty()), true) - .with_extension_type(VariantType), - ), - Arc::new(Field::new("path", DataType::Utf8, true)), - ] - } + use crate::shared::{ + build_variant_get_args, standard_variant_get_arg_fields, variant_array_from_json_rows, + variant_scalar_from_json, + }; - fn build_args( - variant_input: ColumnarValue, - path: ColumnarValue, - arg_fields: Vec>, - ) -> ScalarFunctionArgs { - ScalarFunctionArgs { - args: vec![variant_input, path], - return_field: Arc::new(Field::new("result", DataType::Utf8View, true)), - arg_fields, - number_rows: Default::default(), - config_options: Default::default(), - } - } + use super::*; #[test] fn test_scalar_string_value() { @@ -240,10 +101,11 @@ mod tests { })); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Scalar(variant_input), ColumnarValue::Scalar(ScalarValue::Utf8(Some("name".to_string()))), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); @@ -263,10 +125,11 @@ mod tests { })); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Scalar(variant_input), ColumnarValue::Scalar(ScalarValue::Utf8(Some("age".to_string()))), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); @@ -283,10 +146,11 @@ mod tests { let variant_input = variant_scalar_from_json(serde_json::json!({"name": "norm"})); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Scalar(variant_input), ColumnarValue::Scalar(ScalarValue::Utf8(Some("missing".to_string()))), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); @@ -303,10 +167,11 @@ mod tests { })); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Scalar(variant_input), ColumnarValue::Scalar(ScalarValue::Utf8(Some("obj".to_string()))), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); @@ -324,10 +189,11 @@ mod tests { let variant_input = variant_scalar_from_json(serde_json::json!({"flag": true})); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Scalar(variant_input), ColumnarValue::Scalar(ScalarValue::Utf8(Some("flag".to_string()))), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); @@ -344,10 +210,11 @@ mod tests { let variant_input = variant_scalar_from_json(serde_json::json!({"key": null})); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Scalar(variant_input), ColumnarValue::Scalar(ScalarValue::Utf8(Some("key".to_string()))), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); @@ -370,10 +237,11 @@ mod tests { let variant_array = variant_array_from_json_rows(&json_rows); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Array(variant_array), ColumnarValue::Scalar(ScalarValue::Utf8(Some("name".to_string()))), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); @@ -397,14 +265,14 @@ mod tests { ]; let variant_array = variant_array_from_json_rows(&json_rows); - let path_array: ArrayRef = Arc::new(StringViewArray::from(vec!["name", "age"])); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Array(variant_array), ColumnarValue::Array(path_array), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); @@ -426,10 +294,11 @@ mod tests { })); let udf = VariantGetStrUdf::default(); - let args = build_args( + let args = build_variant_get_args( ColumnarValue::Scalar(variant_input), ColumnarValue::Scalar(ScalarValue::Utf8(Some("list".to_string()))), - standard_arg_fields(), + DataType::Utf8View, + standard_variant_get_arg_fields(), ); let result = udf.invoke_with_args(args).unwrap(); diff --git a/tests/sqllogictests.rs b/tests/sqllogictests.rs index cdbf6ce..61ce4cb 100644 --- a/tests/sqllogictests.rs +++ b/tests/sqllogictests.rs @@ -1,8 +1,8 @@ use datafusion::{logical_expr::ScalarUDF, prelude::*}; use datafusion_sqllogictest::{DataFusion, TestContext}; use datafusion_variant::{ - CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetFieldUdf, VariantGetStrUdf, - VariantGetUdf, VariantListConstruct, VariantListDelete, VariantListInsert, + CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetFieldUdf, VariantGetIntUdf, + VariantGetStrUdf, VariantGetUdf, VariantListConstruct, VariantListDelete, VariantListInsert, VariantObjectConstruct, VariantObjectDelete, VariantObjectInsert, VariantObjectKeys, VariantPretty, VariantToJsonUdf, }; @@ -51,6 +51,7 @@ async fn run_sqllogictests() -> Result<(), Box> { ctx.register_udf(ScalarUDF::new_from_impl(IsVariantNullUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantGetUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantGetStrUdf::default())); + ctx.register_udf(ScalarUDF::new_from_impl(VariantGetIntUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantGetFieldUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantPretty::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectConstruct::default())); diff --git a/tests/test_files/variant_get_int.slt b/tests/test_files/variant_get_int.slt new file mode 100644 index 0000000..b29c114 --- /dev/null +++ b/tests/test_files/variant_get_int.slt @@ -0,0 +1,79 @@ +statement ok +CREATE TABLE json_data (id INT, json_str TEXT) AS VALUES + (1, '{"name": "Alice", "age": 30}'), + (2, '{"name": "Bob", "age": 25}'), + (3, '{"items": [1, 2, 3], "count": 3}'), + (4, 'null'), + (5, '"simple string"'), + (6, '123'), + (7, 'true'), + (8, '{"pi": 3.14}'); + +# Integer values are returned as Int64 +query I +select variant_get_int(json_to_variant(json_str), 'age') from json_data; +---- +30 +25 +NULL +NULL +NULL +NULL +NULL +NULL + +# Missing paths return NULL +query I +select variant_get_int(json_to_variant(json_str), 'nonexistent') from json_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +# Non-integer values return NULL +query I +select variant_get_int(json_to_variant(json_str), 'name') from json_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query I +select variant_get_int(json_to_variant(json_str), 'pi') from json_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +# Scalar variant with integer value +query I +select variant_get_int(json_to_variant('{"count": 42}'), 'count'); +---- +42 + +# Scalar variant with string value returns NULL +query I +select variant_get_int(json_to_variant('{"greeting": "hello world"}'), 'greeting'); +---- +NULL + +# Nested integer path +query I +select variant_get_int(json_to_variant('{"obj": {"a": 1}}'), 'obj.a'); +---- +1