Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::str::Utf8Error;
use std::sync::Arc;

Expand All @@ -16,6 +17,13 @@ use crate::common_union::{
is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
};

/// Field metadata that marks a `Utf8` column/field as containing raw JSON.
/// Downstream consumers (e.g. the rewrite layer, other UDFs) use this to
/// recognize JSON-bearing string columns.
pub fn is_json_metadata() -> HashMap<String, String> {
HashMap::from_iter(vec![("is_json".to_string(), "true".to_string())])
Comment thread
adriangb marked this conversation as resolved.
Outdated
}

/// General implementation of `ScalarUDFImpl::return_type`.
///
/// # Arguments
Expand Down
9 changes: 4 additions & 5 deletions src/common_union.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::sync::{Arc, LazyLock, OnceLock};

use datafusion::arrow::array::{
Expand All @@ -9,6 +8,8 @@ use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
use datafusion::arrow::error::ArrowError;
use datafusion::common::ScalarValue;

use crate::common::is_json_metadata;

pub fn is_json_union(data_type: &DataType) -> bool {
match data_type {
DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(),
Expand Down Expand Up @@ -161,8 +162,6 @@ fn union_fields() -> UnionFields {
static FIELDS: OnceLock<UnionFields> = OnceLock::new();
FIELDS
.get_or_init(|| {
let json_metadata: HashMap<String, String> =
HashMap::from_iter(vec![("is_json".to_string(), "true".to_string())]);
UnionFields::from_iter([
(TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Null, true))),
(TYPE_ID_BOOL, Arc::new(Field::new("bool", DataType::Boolean, false))),
Expand All @@ -171,11 +170,11 @@ fn union_fields() -> UnionFields {
(TYPE_ID_STR, Arc::new(Field::new("str", DataType::Utf8, false))),
(
TYPE_ID_ARRAY,
Arc::new(Field::new("array", DataType::Utf8, false).with_metadata(json_metadata.clone())),
Arc::new(Field::new("array", DataType::Utf8, false).with_metadata(is_json_metadata())),
),
(
TYPE_ID_OBJECT,
Arc::new(Field::new("object", DataType::Utf8, false).with_metadata(json_metadata.clone())),
Arc::new(Field::new("object", DataType::Utf8, false).with_metadata(is_json_metadata())),
),
])
})
Expand Down
24 changes: 11 additions & 13 deletions src/json_get_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@ use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, ListBuilder, StringBuilder};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath};
use crate::common::{
get_err, invoke, is_json_metadata, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath,
};
use crate::common_macros::make_udf_function;

fn list_item_field() -> Field {
Field::new("item", DataType::Utf8, true).with_metadata(is_json_metadata())
}

make_udf_function!(
JsonGetArray,
json_get_array,
Expand Down Expand Up @@ -46,15 +52,7 @@ impl ScalarUDFImpl for JsonGetArray {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
return_type_check(
arg_types,
self.name(),
DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
"item",
DataType::Utf8,
true,
))),
)
return_type_check(arg_types, self.name(), DataType::List(Arc::new(list_item_field())))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
Expand Down Expand Up @@ -96,7 +94,7 @@ impl InvokeResult for BuildArrayList {

fn builder(capacity: usize) -> Self::Builder {
let values_builder = StringBuilder::new();
ListBuilder::with_capacity(values_builder, capacity)
ListBuilder::with_capacity(values_builder, capacity).with_field(list_item_field())
}

fn append_value(builder: &mut Self::Builder, value: Option<Self::Item>) {
Expand All @@ -108,7 +106,7 @@ impl InvokeResult for BuildArrayList {
}

fn scalar(value: Option<Self::Item>) -> ScalarValue {
let mut builder = ListBuilder::new(StringBuilder::new());
let mut builder = ListBuilder::new(StringBuilder::new()).with_field(list_item_field());

if let Some(array_items) = value {
for item in array_items {
Expand Down
17 changes: 14 additions & 3 deletions src/json_get_json.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::StringArray;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
use datafusion::common::Result as DataFusionResult;
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use datafusion::logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
use crate::common::{get_err, invoke, is_json_metadata, jiter_json_find, return_type_check, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -47,6 +50,14 @@ impl ScalarUDFImpl for JsonGetJson {
return_type_check(arg_types, self.name(), DataType::Utf8)
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> DataFusionResult<FieldRef> {
let arg_types: Vec<DataType> = args.arg_fields.iter().map(|f| f.data_type().clone()).collect();
let return_type = self.return_type(&arg_types)?;
Ok(Arc::new(
Field::new(self.name(), return_type, true).with_metadata(is_json_metadata()),
))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
invoke::<StringArray>(&args.args, jiter_json_get_json)
}
Expand Down
34 changes: 34 additions & 0 deletions tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,31 @@ async fn test_json_get_array_with_path() {
assert_eq!(value_repr, "[1, 2, 3]");
}

#[tokio::test]
async fn test_json_get_array_inner_field_is_json_metadata() {
let sql = r#"select json_get_array('[{"a": 1}, {"b": 2}]') as v"#;
let batches = run_query(sql).await.unwrap();
let schema = batches[0].schema();
let field = schema.field(0);
let DataType::List(inner_field) = field.data_type() else {
panic!("expected List, got {:?}", field.data_type());
};
assert_eq!(inner_field.metadata().get("is_json").map(String::as_str), Some("true"));

let array_field = batches[0]
.column(0)
.as_any()
.downcast_ref::<datafusion::arrow::array::ListArray>()
.unwrap();
let DataType::List(produced_inner) = array_field.data_type() else {
panic!("expected List in produced array");
};
assert_eq!(
produced_inner.metadata().get("is_json").map(String::as_str),
Some("true")
);
}

#[tokio::test]
async fn test_json_get_equals() {
let e = run_query(r"select name, json_get(json_data, 'foo')='abc' from test")
Expand Down Expand Up @@ -411,6 +436,15 @@ async fn test_json_get_json_float() {
assert_eq!(display_val(batches).await, (DataType::Utf8, "4.2e-1".to_string()));
}

#[tokio::test]
async fn test_json_get_json_is_json_metadata() {
let sql = r#"select json_get_json('{"x": [1, 2]}', 'x') as v"#;
let batches = run_query(sql).await.unwrap();
let schema = batches[0].schema();
let field = schema.field(0);
assert_eq!(field.metadata().get("is_json").map(String::as_str), Some("true"));
}

#[tokio::test]
async fn test_json_length_array() {
let sql = "select json_length('[1, 2, 3]')";
Expand Down
Loading