diff --git a/Cargo.toml b/Cargo.toml index 7e051ee..0cb7d79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,3 +35,6 @@ pedantic = { level = "deny", priority = -1 } [[bench]] name = "main" harness = false + +[patch.crates-io] +datafusion = { git = "https://github.com/pydantic/datafusion.git", branch = "pydantic-main" } diff --git a/src/json_as_text.rs b/src/json_as_text.rs index bfb4cb1..5599d0c 100644 --- a/src/json_as_text.rs +++ b/src/json_as_text.rs @@ -56,6 +56,24 @@ impl ScalarUDFImpl for JsonAsText { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } impl InvokeResult for StringArray { diff --git a/src/json_contains.rs b/src/json_contains.rs index 01b1120..82216a2 100644 --- a/src/json_contains.rs +++ b/src/json_contains.rs @@ -60,6 +60,24 @@ impl ScalarUDFImpl for JsonContains { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } impl InvokeResult for BooleanArray { diff --git a/src/json_get.rs b/src/json_get.rs index 4c7f9bc..d887d82 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -62,6 +62,24 @@ impl ScalarUDFImpl for JsonGet { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } impl InvokeResult for JsonUnion { diff --git a/src/json_get_array.rs b/src/json_get_array.rs index 88c680d..5b74a2f 100644 --- a/src/json_get_array.rs +++ b/src/json_get_array.rs @@ -64,6 +64,24 @@ impl ScalarUDFImpl for JsonGetArray { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } #[derive(Debug)] diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs index 17be9b0..2a24420 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -55,6 +55,24 @@ impl ScalarUDFImpl for JsonGetBool { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } fn jiter_json_get_bool(json_data: Option<&str>, path: &[JsonPath]) -> Result { diff --git a/src/json_get_float.rs b/src/json_get_float.rs index aff252b..05feabb 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -56,6 +56,24 @@ impl ScalarUDFImpl for JsonGetFloat { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } impl InvokeResult for Float64Array { diff --git a/src/json_get_int.rs b/src/json_get_int.rs index 5788957..1fa2b04 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -56,6 +56,24 @@ impl ScalarUDFImpl for JsonGetInt { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } impl InvokeResult for Int64Array { diff --git a/src/json_get_json.rs b/src/json_get_json.rs index 5907b2b..4b3b678 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -54,6 +54,24 @@ impl ScalarUDFImpl for JsonGetJson { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result { diff --git a/src/json_get_str.rs b/src/json_get_str.rs index 658f1e3..89c54fa 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -55,6 +55,24 @@ impl ScalarUDFImpl for JsonGetStr { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } fn jiter_json_get_str(json_data: Option<&str>, path: &[JsonPath]) -> Result { diff --git a/src/json_length.rs b/src/json_length.rs index 8bd657d..a2357ca 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -56,6 +56,24 @@ impl ScalarUDFImpl for JsonLength { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } impl InvokeResult for UInt64Array { diff --git a/src/json_object_keys.rs b/src/json_object_keys.rs index 8ea040d..9e3f209 100644 --- a/src/json_object_keys.rs +++ b/src/json_object_keys.rs @@ -60,6 +60,24 @@ impl ScalarUDFImpl for JsonObjectKeys { fn aliases(&self) -> &[String] { &self.aliases } + + fn placement( + &self, + args: &[datafusion::logical_expr::ExpressionPlacement], + ) -> datafusion::logical_expr::ExpressionPlacement { + // If the first argument is a column and the remaining arguments are literals (a path) + // then we can push this UDF down to the leaf nodes. + if args.len() >= 2 + && matches!(args[0], datafusion::logical_expr::ExpressionPlacement::Column) + && args[1..] + .iter() + .all(|arg| matches!(arg, datafusion::logical_expr::ExpressionPlacement::Literal)) + { + datafusion::logical_expr::ExpressionPlacement::MoveTowardsLeafNodes + } else { + datafusion::logical_expr::ExpressionPlacement::KeepInPlace + } + } } /// Struct used to build a `ListArray` from the result of `jiter_json_object_keys`. diff --git a/src/rewrite.rs b/src/rewrite.rs index a0e87d5..6aa51d3 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -38,7 +38,7 @@ fn optimise_json_get_cast(cast: &Cast) -> Option> { if scalar_func.func.name() != "json_get" { return None; } - let func = match &cast.data_type { + let func = match cast.field.data_type() { DataType::Boolean => crate::json_get_bool::json_get_bool_udf(), DataType::Float64 | DataType::Float32 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { crate::json_get_float::json_get_float_udf() diff --git a/tests/main.rs b/tests/main.rs index c340fcd..46e9b73 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -163,14 +163,27 @@ async fn test_json_get_array_with_path() { #[tokio::test] async fn test_json_get_equals() { - let e = run_query(r"select name, json_get(json_data, 'foo')='abc' from test") + // union comparison now works thanks to the union coercions upport in datafusion + // (previously failed with "Cannot infer common argument type for comparison operation Union") + // see https://github.com/apache/datafusion/issues/10180 + let batches = run_query(r"select name, json_get(json_data, 'foo')='abc' from test") .await - .unwrap_err(); + .unwrap(); - // see https://github.com/apache/datafusion/issues/10180 - assert!(e - .to_string() - .starts_with("Error during planning: Cannot infer common argument type for comparison operation Union")); + let expected = [ + "+------------------+----------------------------------------------------+", + r#"| name | json_get(test.json_data,Utf8("foo")) = Utf8("abc") |"#, + "+------------------+----------------------------------------------------+", + "| object_foo | true |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+----------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); } #[tokio::test] @@ -762,9 +775,9 @@ async fn test_plan_json_get_cte() { select name, json_get(j, 0) v from t "; let expected = [ - "Projection: t.name, json_get(t.j, Int64(0)) AS v", + "Projection: t.name, __datafusion_extracted_1 AS v", " SubqueryAlias: t", - " Projection: test.name, json_get(test.json_data, Utf8(\"foo\")) AS j", + " Projection: test.name, json_get(json_get(test.json_data, Utf8(\"foo\")), Int64(0)) AS __datafusion_extracted_1", " TableScan: test projection=[name, json_data]", ]; @@ -1255,7 +1268,7 @@ async fn test_plan_double_arrow_double_nested_cast() { // NB: json_as_text(..)::int is NOT the same as `json_get_int(..)`, hence the cast is not rewritten let expected = [ - "Projection: CAST(json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS json_data ->> 'foo' ->> 0 AS Int32)", + "Projection: CAST(json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS Int32) AS json_data ->> 'foo' ->> 0", " TableScan: test projection=[json_data]", ];