Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 202 additions & 8 deletions core/src/util/supported_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,24 @@ pub fn contains_unsupported_functions(
plan: &LogicalPlan,
sup: &FunctionSupport,
) -> Result<bool, DataFusionError> {
plan.exists(|plan| {
Ok(plan.expressions().into_iter().any(|expr| {
let mut found_unsupported = false;
let _ = expr.apply(|expr| {
let mut found_unsupported = false;
plan.apply_with_subqueries(|plan| {
for expr in plan.expressions() {
expr.apply(|expr| {
if sup.supports(expr) {
Ok(TreeNodeRecursion::Continue)
} else {
found_unsupported = true;
Ok(TreeNodeRecursion::Stop)
}
});
found_unsupported
}))
})
})?;
if found_unsupported {
return Ok(TreeNodeRecursion::Stop);
}
}
Ok(TreeNodeRecursion::Continue)
})?;
Ok(found_unsupported)
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -163,3 +167,193 @@ impl FunctionRestriction {
}
}
}

#[cfg(test)]
mod tests {
    //! Tests for `contains_unsupported_functions`, covering functions that
    //! appear directly in a plan as well as inside subquery expressions.

    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::logical_expr::builder::LogicalTableSource;
    use datafusion::logical_expr::expr::{Exists, InSubquery, ScalarFunction};
    use datafusion::logical_expr::{
        create_udf, ColumnarValue, LogicalPlanBuilder, Subquery, TableSource, Volatility,
    };
    use datafusion::prelude::col;
    use std::sync::Arc;

    /// Creates a scalar UDF called `name` that takes one `Utf8` argument,
    /// declares a `Utf8` return type and simply returns its first argument.
    fn stub_udf(name: &str) -> Arc<ScalarUDF> {
        Arc::new(create_udf(
            name,
            vec![DataType::Utf8],
            DataType::Utf8,
            Volatility::Immutable,
            Arc::new(|args: &[ColumnarValue]| Ok(args[0].clone())),
        ))
    }

    /// Builds a `FunctionSupport` whose first restriction is a deny-list of
    /// `names`; the two remaining constructor arguments are left as `None`.
    fn deny_support(names: &[&str]) -> FunctionSupport {
        FunctionSupport::new(
            Some(FunctionRestriction::Deny(
                names.iter().map(|s| s.to_string()).collect(),
            )),
            None,
            None,
        )
    }

    /// Builds a table scan over `table` with the schema
    /// `(id: Int32 NOT NULL, val: Utf8 NULL)`.
    fn scan_plan(table: &str) -> LogicalPlan {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Utf8, true),
        ]));
        let source = Arc::new(LogicalTableSource::new(schema)) as Arc<dyn TableSource>;
        LogicalPlanBuilder::scan(table, source, None)
            .expect("scan")
            .build()
            .expect("build")
    }

    /// Builds the expression `<fn_name>(val)` using a stub UDF.
    fn call_on_val(fn_name: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction::new_udf(
            stub_udf(fn_name),
            vec![col("val")],
        ))
    }

    /// Builds a projection of `<fn_name>(val)` over a scan of table `t`.
    fn projection_calling(fn_name: &str) -> LogicalPlan {
        LogicalPlanBuilder::from(scan_plan("t"))
            .project(vec![call_on_val(fn_name)])
            .expect("project")
            .build()
            .expect("build")
    }

    /// Builds a subquery equivalent to
    /// `SELECT denied_fn(val) AS result FROM inner_t`.
    fn denied_subquery() -> Subquery {
        let subquery_plan = LogicalPlanBuilder::from(scan_plan("inner_t"))
            .project(vec![call_on_val("denied_fn").alias("result")])
            .expect("project")
            .build()
            .expect("build");
        Subquery {
            subquery: Arc::new(subquery_plan),
            outer_ref_columns: vec![],
            spans: Default::default(),
        }
    }

    /// Builds a filter with `predicate` over a scan of table `t`.
    fn filtered_outer(predicate: Expr) -> LogicalPlan {
        LogicalPlanBuilder::from(scan_plan("t"))
            .filter(predicate)
            .expect("filter")
            .build()
            .expect("build")
    }

    #[test]
    fn detects_denied_function_in_top_level_projection() {
        let plan = projection_calling("denied_fn");

        let sup = deny_support(&["denied_fn"]);
        assert!(
            contains_unsupported_functions(&plan, &sup).expect("check"),
            "should detect denied function in top-level projection"
        );
    }

    #[test]
    fn allows_plan_without_denied_functions() {
        let plan = projection_calling("allowed_fn");

        let sup = deny_support(&["denied_fn"]);
        assert!(
            !contains_unsupported_functions(&plan, &sup).expect("check"),
            "should allow plan with only non-denied functions"
        );
    }

    #[test]
    fn detects_denied_function_inside_in_subquery() {
        // Outer predicate: id IN (subquery)
        let outer = filtered_outer(Expr::InSubquery(InSubquery::new(
            Box::new(col("id")),
            denied_subquery(),
            false,
        )));

        let sup = deny_support(&["denied_fn"]);
        assert!(
            contains_unsupported_functions(&outer, &sup).expect("check"),
            "should detect denied function inside IN subquery"
        );
    }

    #[test]
    fn detects_denied_function_inside_scalar_subquery() {
        // Outer predicate: id = (scalar subquery)
        let outer = filtered_outer(col("id").eq(Expr::ScalarSubquery(denied_subquery())));

        let sup = deny_support(&["denied_fn"]);
        assert!(
            contains_unsupported_functions(&outer, &sup).expect("check"),
            "should detect denied function inside scalar subquery"
        );
    }

    #[test]
    fn detects_denied_function_inside_exists_subquery() {
        // Outer predicate: EXISTS (subquery)
        let outer = filtered_outer(Expr::Exists(Exists::new(denied_subquery(), false)));

        let sup = deny_support(&["denied_fn"]);
        assert!(
            contains_unsupported_functions(&outer, &sup).expect("check"),
            "should detect denied function inside EXISTS subquery"
        );
    }
}
Loading