diff --git a/core/src/util/supported_functions.rs b/core/src/util/supported_functions.rs
index c6338cd6..cec63873 100644
--- a/core/src/util/supported_functions.rs
+++ b/core/src/util/supported_functions.rs
@@ -28,20 +28,27 @@ pub fn contains_unsupported_functions(
     plan: &LogicalPlan,
     sup: &FunctionSupport,
 ) -> Result<bool> {
-    plan.exists(|plan| {
-        Ok(plan.expressions().into_iter().any(|expr| {
-            let mut found_unsupported = false;
-            let _ = expr.apply(|expr| {
+    let mut found_unsupported = false;
+    // Walk with `apply_with_subqueries` so plans nested inside subquery expressions
+    // (IN / EXISTS / scalar subqueries) are inspected as well as the outer plan.
+    plan.apply_with_subqueries(|plan| {
+        for expr in plan.expressions() {
+            expr.apply(|expr| {
                 if sup.supports(expr) {
                     Ok(TreeNodeRecursion::Continue)
                 } else {
                     found_unsupported = true;
                     Ok(TreeNodeRecursion::Stop)
                 }
-            });
-            found_unsupported
-        }))
-    })
+            })?;
+            if found_unsupported {
+                // One hit is enough: abort the whole plan traversal.
+                return Ok(TreeNodeRecursion::Stop);
+            }
+        }
+        Ok(TreeNodeRecursion::Continue)
+    })?;
+    Ok(found_unsupported)
 }
 
 #[derive(Clone, Debug)]
@@ -163,3 +170,197 @@ impl FunctionRestriction {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::logical_expr::builder::LogicalTableSource;
+    use datafusion::logical_expr::expr::ScalarFunction;
+    use datafusion::logical_expr::{create_udf, ColumnarValue, LogicalPlanBuilder, Subquery};
+    use datafusion::prelude::col;
+    use std::sync::Arc;
+
+    /// Builds a one-argument `Utf8 -> Utf8` scalar UDF called `name` that echoes its input.
+    fn stub_udf(name: &str) -> Arc<datafusion::logical_expr::ScalarUDF> {
+        Arc::new(create_udf(
+            name,
+            vec![DataType::Utf8],
+            DataType::Utf8,
+            datafusion::logical_expr::Volatility::Immutable,
+            Arc::new(|args: &[ColumnarValue]| Ok(args[0].clone())),
+        ))
+    }
+
+    /// `FunctionSupport` with a `Deny(names)` restriction as first argument, `None` elsewhere.
+    fn deny_support(names: &[&str]) -> FunctionSupport {
+        FunctionSupport::new(
+            Some(FunctionRestriction::Deny(
+                names.iter().map(|s| s.to_string()).collect(),
+            )),
+            None,
+            None,
+        )
+    }
+
+    /// Builds a bare scan of `table` with columns `id` (non-null `Int32`) and
+    /// `val` (nullable `Utf8`).
+    fn scan_plan(table: &str) -> LogicalPlan {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, true),
+        ]));
+        let source = Arc::new(LogicalTableSource::new(schema))
+            as Arc<dyn datafusion::logical_expr::TableSource>;
+        LogicalPlanBuilder::scan(table, source, None)
+            .expect("scan")
+            .build()
+            .expect("build")
+    }
+
+    #[test]
+    fn detects_denied_function_in_top_level_projection() {
+        let udf = stub_udf("denied_fn");
+        let plan = LogicalPlanBuilder::from(scan_plan("t"))
+            .project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
+                udf,
+                vec![col("val")],
+            ))])
+            .expect("project")
+            .build()
+            .expect("build");
+
+        let sup = deny_support(&["denied_fn"]);
+        assert!(
+            contains_unsupported_functions(&plan, &sup).expect("check"),
+            "should detect denied function in top-level projection"
+        );
+    }
+
+    #[test]
+    fn allows_plan_without_denied_functions() {
+        let udf = stub_udf("allowed_fn");
+        let plan = LogicalPlanBuilder::from(scan_plan("t"))
+            .project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
+                udf,
+                vec![col("val")],
+            ))])
+            .expect("project")
+            .build()
+            .expect("build");
+
+        let sup = deny_support(&["denied_fn"]);
+        assert!(
+            !contains_unsupported_functions(&plan, &sup).expect("check"),
+            "should allow plan with only non-denied functions"
+        );
+    }
+
+    #[test]
+    fn detects_denied_function_inside_in_subquery() {
+        let udf = stub_udf("denied_fn");
+
+        // Build subquery: SELECT denied_fn(val) FROM inner_t
+        let subquery_plan = LogicalPlanBuilder::from(scan_plan("inner_t"))
+            .project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
+                udf,
+                vec![col("val")],
+            ))
+            .alias("result")])
+            .expect("project")
+            .build()
+            .expect("build");
+
+        // Build outer: SELECT id FROM t WHERE id IN (subquery)
+        let outer = LogicalPlanBuilder::from(scan_plan("t"))
+            .filter(Expr::InSubquery(
+                datafusion::logical_expr::expr::InSubquery::new(
+                    Box::new(col("id")),
+                    Subquery {
+                        subquery: Arc::new(subquery_plan),
+                        outer_ref_columns: vec![],
+                        spans: Default::default(),
+                    },
+                    false,
+                ),
+            ))
+            .expect("filter")
+            .build()
+            .expect("build");
+
+        let sup = deny_support(&["denied_fn"]);
+        assert!(
+            contains_unsupported_functions(&outer, &sup).expect("check"),
+            "should detect denied function inside IN subquery"
+        );
+    }
+
+    #[test]
+    fn detects_denied_function_inside_scalar_subquery() {
+        let udf = stub_udf("denied_fn");
+
+        // Build scalar subquery: SELECT denied_fn(val) FROM inner_t
+        let subquery_plan = LogicalPlanBuilder::from(scan_plan("inner_t"))
+            .project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
+                udf,
+                vec![col("val")],
+            ))
+            .alias("result")])
+            .expect("project")
+            .build()
+            .expect("build");
+
+        // Build outer: SELECT id FROM t WHERE id = (scalar subquery)
+        let outer = LogicalPlanBuilder::from(scan_plan("t"))
+            .filter(col("id").eq(Expr::ScalarSubquery(Subquery {
+                subquery: Arc::new(subquery_plan),
+                outer_ref_columns: vec![],
+                spans: Default::default(),
+            })))
+            .expect("filter")
+            .build()
+            .expect("build");
+
+        let sup = deny_support(&["denied_fn"]);
+        assert!(
+            contains_unsupported_functions(&outer, &sup).expect("check"),
+            "should detect denied function inside scalar subquery"
+        );
+    }
+
+    #[test]
+    fn detects_denied_function_inside_exists_subquery() {
+        let udf = stub_udf("denied_fn");
+
+        // Build subquery: SELECT denied_fn(val) FROM inner_t
+        let subquery_plan = LogicalPlanBuilder::from(scan_plan("inner_t"))
+            .project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
+                udf,
+                vec![col("val")],
+            ))
+            .alias("result")])
+            .expect("project")
+            .build()
+            .expect("build");
+
+        // Build outer: SELECT id FROM t WHERE EXISTS (subquery)
+        let outer = LogicalPlanBuilder::from(scan_plan("t"))
+            .filter(Expr::Exists(datafusion::logical_expr::expr::Exists::new(
+                Subquery {
+                    subquery: Arc::new(subquery_plan),
+                    outer_ref_columns: vec![],
+                    spans: Default::default(),
+                },
+                false,
+            )))
+            .expect("filter")
+            .build()
+            .expect("build");
+
+        let sup = deny_support(&["denied_fn"]);
+        assert!(
+            contains_unsupported_functions(&outer, &sup).expect("check"),
+            "should detect denied function inside EXISTS subquery"
+        );
+    }
+}