diff --git a/src/lib.rs b/src/lib.rs index db82652..33de1cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ mod cast_to_variant; mod is_variant_null; mod json_to_variant; mod variant_get; +mod variant_get_float; mod variant_get_int; mod variant_get_str; mod variant_list_construct; @@ -22,6 +23,7 @@ pub use cast_to_variant::*; pub use is_variant_null::*; pub use json_to_variant::*; pub use variant_get::*; +pub use variant_get_float::*; pub use variant_get_int::*; pub use variant_get_str::*; pub use variant_list_construct::*; diff --git a/src/variant_get_float.rs b/src/variant_get_float.rs new file mode 100644 index 0000000..ed85021 --- /dev/null +++ b/src/variant_get_float.rs @@ -0,0 +1,297 @@ +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array}; +use arrow_schema::DataType; +use datafusion::{ + error::Result, + logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, + }, + scalar::ScalarValue, +}; +use parquet_variant::Variant; + +use crate::shared::invoke_variant_get_typed; + +/// Extracts a floating-point value from a Variant by path. +/// +/// `variant_get_float(variant, path)` returns the value at `path` as a `FLOAT64`. +/// - Float values are returned as-is +/// - Integer values are returned as `FLOAT64` (large values may lose precision) +/// - Non-numeric values return NULL +/// - Returns NULL if the path does not exist +#[derive(Debug, Hash, PartialEq, Eq)] +pub struct VariantGetFloatUdf { + signature: Signature, +} + +impl Default for VariantGetFloatUdf { + fn default() -> Self { + Self { + signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable), + } + } +} + +fn scalar_from_float(value: Option) -> ScalarValue { + ScalarValue::Float64(value) +} + +fn float_array_from_values(values: Vec>) -> ArrayRef { + Arc::new(values.into_iter().collect::()) +} + +fn extract_float(value: Variant<'_, '_>) -> Result> { + Ok(value + .as_f64() + .or_else(|| value.as_int64().map(|int| int as f64))) +} + +impl ScalarUDFImpl for VariantGetFloatUdf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "variant_get_float" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + invoke_variant_get_typed( + args, + scalar_from_float, + float_array_from_values, + extract_float, + ) + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, ArrayRef, Float64Array, StringViewArray}; + use datafusion::logical_expr::ColumnarValue; + use datafusion::scalar::ScalarValue; + + use crate::shared::{ + build_variant_get_args, standard_variant_get_arg_fields, variant_array_from_json_rows, + variant_scalar_from_json, + }; + + use super::*; + + #[test] + fn test_scalar_float_value() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "name": "norm", + "price": 50.5 + })); + + let udf = VariantGetFloatUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("price".to_string()))), + DataType::Float64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) = result else { + panic!("expected Float64 scalar"); + }; + + assert_eq!(v, 50.5); + } + + #[test] + fn test_scalar_integer_value() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "name": "norm", + "age": 50 + })); + + let udf = VariantGetFloatUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("age".to_string()))), + DataType::Float64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) = result else { + panic!("expected Float64 scalar"); + }; + + assert_eq!(v, 50.0); + } + + #[test] + fn test_scalar_large_integer_value() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "n": 9007199254740993_i64 + })); + + let udf = VariantGetFloatUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("n".to_string()))), + DataType::Float64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) = result else { + panic!("expected Float64 scalar"); + }; + + // `f64` cannot exactly represent all i64 values; this mirrors json_get_float behavior. + assert_eq!(v, 9_007_199_254_740_992.0); + } + + #[test] + fn test_scalar_non_numeric_value_returns_null() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "name": "norm", + "age": 50.5 + })); + + let udf = VariantGetFloatUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("name".to_string()))), + DataType::Float64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Float64(None)) = result else { + panic!("expected NULL Float64 scalar"); + }; + } + + #[test] + fn test_scalar_missing_path() { + let variant_input = variant_scalar_from_json(serde_json::json!({"name": "norm"})); + + let udf = VariantGetFloatUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("missing".to_string()))), + DataType::Float64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Scalar(ScalarValue::Float64(None)) = result else { + panic!("expected NULL Float64 scalar"); + }; + } + + #[test] + fn test_array_variant_scalar_path() { + let json_rows = vec![ + serde_json::json!({"name": "alice", "price": 30.25}), + serde_json::json!({"name": "bob", "price": 40}), + serde_json::json!({"name": "charlie"}), + ]; + + let variant_array = variant_array_from_json_rows(&json_rows); + + let udf = VariantGetFloatUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Array(variant_array), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("price".to_string()))), + DataType::Float64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Array(arr) = result else { + panic!("expected array output"); + }; + + let float_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(float_arr.len(), 3); + assert_eq!(float_arr.value(0), 30.25); + assert_eq!(float_arr.value(1), 40.0); + assert!(float_arr.is_null(2)); + } + + #[test] + fn test_array_variant_array_paths() { + let json_rows = vec![ + serde_json::json!({"name": "alice", "price": 30.25}), + serde_json::json!({"name": "bob", "price": 40}), + ]; + + let variant_array = variant_array_from_json_rows(&json_rows); + let path_array: ArrayRef = Arc::new(StringViewArray::from(vec!["price", "name"])); + + let udf = VariantGetFloatUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Array(variant_array), + ColumnarValue::Array(path_array), + DataType::Float64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Array(arr) = result else { + panic!("expected array output"); + }; + + let float_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(float_arr.len(), 2); + assert_eq!(float_arr.value(0), 30.25); + assert!(float_arr.is_null(1)); + } + + #[test] + fn test_scalar_variant_array_paths() { + let variant_input = variant_scalar_from_json(serde_json::json!({ + "name": "alice", + "price": 30.25, + "count": 3 + })); + + let path_array: ArrayRef = Arc::new(StringViewArray::from(vec![ + "price", "count", "name", "missing", + ])); + + let udf = VariantGetFloatUdf::default(); + let args = build_variant_get_args( + ColumnarValue::Scalar(variant_input), + ColumnarValue::Array(path_array), + DataType::Float64, + standard_variant_get_arg_fields(), + ); + + let result = udf.invoke_with_args(args).unwrap(); + + let ColumnarValue::Array(arr) = result else { + panic!("expected array output"); + }; + + let float_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(float_arr.len(), 4); + assert_eq!(float_arr.value(0), 30.25); + assert_eq!(float_arr.value(1), 3.0); + assert!(float_arr.is_null(2)); + assert!(float_arr.is_null(3)); + } +} diff --git a/tests/sqllogictests.rs b/tests/sqllogictests.rs index 61ce4cb..e48431f 100644 --- a/tests/sqllogictests.rs +++ b/tests/sqllogictests.rs @@ -1,10 +1,10 @@ use datafusion::{logical_expr::ScalarUDF, prelude::*}; use datafusion_sqllogictest::{DataFusion, TestContext}; use datafusion_variant::{ - CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetFieldUdf, VariantGetIntUdf, - VariantGetStrUdf, VariantGetUdf, VariantListConstruct, VariantListDelete, VariantListInsert, - VariantObjectConstruct, VariantObjectDelete, VariantObjectInsert, VariantObjectKeys, - VariantPretty, VariantToJsonUdf, + CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetFieldUdf, VariantGetFloatUdf, + VariantGetIntUdf, VariantGetStrUdf, VariantGetUdf, VariantListConstruct, VariantListDelete, + VariantListInsert, VariantObjectConstruct, VariantObjectDelete, VariantObjectInsert, + VariantObjectKeys, VariantPretty, VariantToJsonUdf, }; use indicatif::ProgressBar; use sqllogictest::strict_column_validator; @@ -51,6 +51,7 @@ async fn run_sqllogictests() -> Result<(), Box> { ctx.register_udf(ScalarUDF::new_from_impl(IsVariantNullUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantGetUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantGetStrUdf::default())); + ctx.register_udf(ScalarUDF::new_from_impl(VariantGetFloatUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantGetIntUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantGetFieldUdf::default())); ctx.register_udf(ScalarUDF::new_from_impl(VariantPretty::default())); diff --git a/tests/test_files/variant_get_float.slt b/tests/test_files/variant_get_float.slt new file mode 100644 index 0000000..b76d4eb --- /dev/null +++ b/tests/test_files/variant_get_float.slt @@ -0,0 +1,86 @@ +statement ok +CREATE TABLE json_data (id INT, json_str TEXT) AS VALUES + (1, '{"name": "Alice", "age": 30}'), + (2, '{"name": "Bob", "age": 25}'), + (3, '{"items": [1, 2, 3], "count": 3}'), + (4, 'null'), + (5, '"simple string"'), + (6, '123'), + (7, 'true'), + (8, '{"pi": 3.14}'); + +# Numeric values are returned as Float64 +query R +select variant_get_float(json_to_variant(json_str), 'age') from json_data; +---- +30 +25 +NULL +NULL +NULL +NULL +NULL +NULL + +# Missing paths return NULL +query R +select variant_get_float(json_to_variant(json_str), 'nonexistent') from json_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +# Non-numeric values return NULL +query R +select variant_get_float(json_to_variant(json_str), 'name') from json_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +# Float values are returned as Float64 +query R +select variant_get_float(json_to_variant(json_str), 'pi') from json_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +3.14 + +# Scalar variant with integer value +query R +select variant_get_float(json_to_variant('{"count": 42}'), 'count'); +---- +42 + +# Scalar variant with string value returns NULL +query R +select variant_get_float(json_to_variant('{"greeting": "hello world"}'), 'greeting'); +---- +NULL + +# Nested numeric path +query R +select variant_get_float(json_to_variant('{"obj": {"a": 1}}'), 'obj.a'); +---- +1 + +# Large integers are coerced to Float64 (with potential precision loss) +query R +select variant_get_float(json_to_variant('{"n": 9007199254740993}'), 'n'); +---- +9007199254740992