Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 257 additions & 32 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ use datafusion::functions_aggregate::min_max::max_udaf;
use datafusion::functions_aggregate::min_max::min_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion::physical_plan::windows::BoundedWindowAggExec;
use datafusion::physical_plan::InputOrderMode;
use datafusion::physical_plan::windows::WindowAggExec;
use datafusion::{
arrow::{compute::SortOptions, datatypes::SchemaRef},
common::DataFusionError,
Expand Down Expand Up @@ -1857,16 +1856,76 @@ impl PhysicalPlanner {
})
.collect();

let window_agg = Arc::new(BoundedWindowAggExec::try_new(
window_expr?,
let window_expr = window_expr?;

// Always use the non-streaming `WindowAggExec`. `BoundedWindowAggExec`
// (DataFusion's streaming variant) invokes `retract_batch` on the UDAF
// for sliding frames like `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`,
// and Comet's Spark-compatible aggregates (`SumDecimal`, `SumInteger`,
// `AvgDecimal`, `Avg`) don't implement retract — they'd fail at runtime
// with "Aggregate can not be used as a sliding accumulator". It also
// sidesteps the "Can not execute X in a streaming fashion" error for
// PERCENT_RANK / CUME_DIST / NTILE which report !uses_bounded_memory().
// This matches Spark's non-streaming `WindowExec` semantics as well.
let window_agg: Arc<dyn ExecutionPlan> = Arc::new(WindowAggExec::try_new(
window_expr,
Arc::clone(&child.native_plan),
InputOrderMode::Sorted,
!partition_exprs.is_empty(),
)?);

// DataFusion's window functions don't always return the same Arrow
// type that Spark expects (e.g. `row_number` returns UInt64 while
// Spark expects Int32). If any window expression carries a
// `result_type` that differs from the actual output type, wrap the
// aggregate in a projection that casts the mismatched columns.
let final_plan: Arc<dyn ExecutionPlan> = {
let agg_schema = window_agg.schema();
let input_field_count = input_schema.fields().len();
let needs_cast = wnd.window_expr.iter().enumerate().any(|(i, w)| {
w.result_type
.as_ref()
.map(|t| {
let expected = to_arrow_datatype(t);
let actual = agg_schema.field(input_field_count + i).data_type();
&expected != actual
})
.unwrap_or(false)
});

if needs_cast {
let mut proj_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
Vec::with_capacity(agg_schema.fields().len());
for (idx, field) in agg_schema.fields().iter().enumerate() {
let col: Arc<dyn PhysicalExpr> =
Arc::new(Column::new(field.name(), idx));
let expr: Arc<dyn PhysicalExpr> = if idx >= input_field_count {
let w = &wnd.window_expr[idx - input_field_count];
match &w.result_type {
Some(t) => {
let expected = to_arrow_datatype(t);
if &expected != field.data_type() {
Arc::new(CastExpr::new(col, expected, None))
} else {
col
}
}
None => col,
}
} else {
col
};
proj_exprs.push((expr, field.name().to_string()));
}
Arc::new(ProjectionExec::try_new(proj_exprs, window_agg)?)
} else {
window_agg
}
};

Ok((
scans,
shuffle_scans,
Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])),
Arc::new(SparkPlan::new(spark_plan.plan_id, final_plan, vec![child])),
))
}
OpStruct::ShuffleScan(scan) => {
Expand Down Expand Up @@ -2393,6 +2452,7 @@ impl PhysicalPlanner {
partition_by: &[Arc<dyn PhysicalExpr>],
sort_exprs: &[PhysicalSortExpr],
) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
let window_func: WindowFunctionDefinition;
let window_func_name: String;
let window_args: Vec<Arc<dyn PhysicalExpr>>;
if let Some(func) = &spark_expr.built_in_window_function {
Expand All @@ -2404,6 +2464,13 @@ impl PhysicalPlanner {
.iter()
.map(|expr| self.create_expr(expr, Arc::clone(&input_schema)))
.collect::<Result<Vec<_>, ExecutionError>>()?;
window_func =
self.find_df_window_function(&window_func_name)
.ok_or_else(|| {
GeneralError(format!(
"{window_func_name} not supported for window function"
))
})?;
}
other => {
return Err(GeneralError(format!(
Expand All @@ -2412,24 +2479,32 @@ impl PhysicalPlanner {
}
};
} else if let Some(agg_func) = &spark_expr.agg_func {
let result = self.process_agg_func(agg_func, Arc::clone(&input_schema))?;
window_func_name = result.0;
window_args = result.1;
// Is the frame ever-expanding (start = UnboundedPreceding)? When it is,
// DataFusion uses `PlainAggregateWindowExpr` which does not call
// `retract_batch`, so we can safely use Comet's Spark-compatible
// UDAFs (SumDecimal/SumInteger/AvgDecimal/Avg). Otherwise it uses
// `SlidingAggregateWindowExpr` which requires retract — Comet's UDAFs
// don't implement it, so the caller must fall back to DataFusion's
// built-ins (which do).
let is_ever_expanding = spark_expr
.spec
.as_ref()
.and_then(|s| s.frame_specification.as_ref())
.and_then(|f| f.lower_bound.as_ref())
.and_then(|lb| lb.lower_frame_bound_struct.as_ref())
.map(|inner| matches!(inner, LowerFrameBoundStruct::UnboundedPreceding(_)))
.unwrap_or(true);
let (func, args) =
self.process_agg_func(agg_func, Arc::clone(&input_schema), is_ever_expanding)?;
window_func_name = func.name().to_string();
window_args = args;
window_func = func;
} else {
return Err(GeneralError(
"Both func and agg_func are not set".to_string(),
));
}

let window_func = match self.find_df_window_function(&window_func_name) {
Some(f) => f,
_ => {
return Err(GeneralError(format!(
"{window_func_name} not supported for window function"
)))
}
};

let spark_window_frame = match spark_expr
.spec
.as_ref()
Expand Down Expand Up @@ -2474,7 +2549,11 @@ impl PhysicalPlanner {
Some(offset_value as u64),
)),
WindowFrameUnits::Range => {
WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value)))
let scalar = match offset.range_offset.as_ref() {
Some(lit) => numeric_literal_to_scalar(lit)?,
None => ScalarValue::Int64(Some(offset_value)),
};
WindowFrameBound::Preceding(scalar)
}
WindowFrameUnits::Groups => {
return Err(GeneralError(
Expand Down Expand Up @@ -2520,7 +2599,11 @@ impl PhysicalPlanner {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
WindowFrameUnits::Range => {
WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset)))
let scalar = match offset.range_offset.as_ref() {
Some(lit) => numeric_literal_to_scalar(lit)?,
None => ScalarValue::Int64(Some(offset.offset)),
};
WindowFrameBound::Following(scalar)
}
WindowFrameUnits::Groups => {
return Err(GeneralError(
Expand Down Expand Up @@ -2564,35 +2647,121 @@ impl PhysicalPlanner {
&self,
agg_func: &AggExpr,
schema: SchemaRef,
) -> Result<(String, Vec<Arc<dyn PhysicalExpr>>), ExecutionError> {
is_ever_expanding: bool,
) -> Result<(WindowFunctionDefinition, Vec<Arc<dyn PhysicalExpr>>), ExecutionError> {
// Wrap a freshly-constructed AggregateUDF impl as a WindowFunctionDefinition.
fn udaf<U: datafusion::logical_expr::AggregateUDFImpl + 'static>(
udaf: U,
) -> WindowFunctionDefinition {
WindowFunctionDefinition::AggregateUDF(Arc::new(AggregateUDF::new_from_impl(udaf)))
}

// Resolve a window-capable function by name via the session registry, returning
// a clean "X not supported for window function" error if missing.
let by_name = |name: &str| -> Result<WindowFunctionDefinition, ExecutionError> {
self.find_df_window_function(name)
.ok_or_else(|| GeneralError(format!("{name} not supported for window function")))
};

match &agg_func.expr_struct {
Some(AggExprStruct::Count(expr)) => {
let children = expr
.children
.iter()
.map(|child| self.create_expr(child, Arc::clone(&schema)))
.collect::<Result<Vec<_>, _>>()?;
Ok(("count".to_string(), children))
Ok((by_name("count")?, children))
}
Some(AggExprStruct::Min(expr)) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
Ok(("min".to_string(), vec![child]))
Ok((by_name("min")?, vec![child]))
}
Some(AggExprStruct::Max(expr)) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
Ok(("max".to_string(), vec![child]))
Ok((by_name("max")?, vec![child]))
}
Some(AggExprStruct::Sum(expr)) => {
// For ever-expanding frames, use Comet's Spark-compatible Sum UDAFs
// (SumDecimal / SumInteger) which enforce Spark overflow semantics.
// For sliding frames, those UDAFs can't be used (no retract_batch),
// so delegate to DataFusion's built-in `sum`, which supports retract
// but doesn't enforce Spark's decimal precision overflow-to-NULL.
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let arrow_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let datatype = child.data_type(&schema)?;

let child = if datatype != arrow_type {
Arc::new(CastExpr::new(child, arrow_type.clone(), None))
} else {
child
};
Ok(("sum".to_string(), vec![child]))
match arrow_type {
DataType::Decimal128(_, _) if is_ever_expanding => {
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
let func = SumDecimal::try_new(
arrow_type,
eval_mode,
agg_func.expr_id,
Arc::clone(&self.query_context_registry),
)?;
Ok((udaf(func), vec![child]))
}
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
if is_ever_expanding =>
{
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
let func = SumInteger::try_new(arrow_type, eval_mode)?;
Ok((udaf(func), vec![child]))
}
_ => {
let actual = child.data_type(&schema)?;
let child: Arc<dyn PhysicalExpr> = if actual != arrow_type {
Arc::new(CastExpr::new(child, arrow_type, None))
} else {
child
};
Ok((by_name("sum")?, vec![child]))
}
}
}
Some(AggExprStruct::Avg(expr)) => {
// Same rule as Sum: Comet's Avg/AvgDecimal for ever-expanding frames,
// DataFusion's `avg` for sliding (retract-capable).
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
match datatype {
DataType::Decimal128(_, _) if is_ever_expanding => {
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
let func = AvgDecimal::new(
datatype,
input_datatype,
eval_mode,
agg_func.expr_id,
Arc::clone(&self.query_context_registry),
);
Ok((udaf(func), vec![child]))
}
_ if is_ever_expanding => {
let child: Arc<dyn PhysicalExpr> =
Arc::new(CastExpr::new(child, DataType::Float64, None));
let func = Avg::new("avg", DataType::Float64);
Ok((udaf(func), vec![child]))
}
_ => {
// Sliding frame — DataFusion's built-in `avg` handles retract.
// Cast non-decimal input to Float64 to match Spark's Avg result type.
let child: Arc<dyn PhysicalExpr> = match datatype {
DataType::Decimal128(_, _) => child,
_ => Arc::new(CastExpr::new(child, DataType::Float64, None)),
};
Ok((by_name("avg")?, vec![child]))
}
}
}
Some(AggExprStruct::First(expr)) => {
// Spark's FIRST_VALUE → DataFusion's `first_value` UDAF. The UDAF honors
// ignore-nulls via the WindowExpr-level `ignore_nulls` flag.
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
Ok((by_name("first_value")?, vec![child]))
}
Some(AggExprStruct::Last(expr)) => {
// Spark's LAST_VALUE → DataFusion's `last_value` UDAF.
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
Ok((by_name("last_value")?, vec![child]))
}
other => Err(GeneralError(format!(
"{other:?} not supported for window function"
Expand Down Expand Up @@ -2910,6 +3079,62 @@ fn expr_to_columns(
Ok((left_field_indices, right_field_indices))
}

/// Convert a Spark numeric Literal proto into a `ScalarValue` whose data type
/// matches the literal's declared type. Used for RANGE window frame offsets,
/// where the offset's type must match the ORDER BY column's type. Only numeric
/// types are supported; the Scala side rejects non-numeric RANGE offsets before
/// reaching here.
/// Convert a Spark numeric `Literal` proto into a `ScalarValue` whose data
/// type matches the literal's declared type. Used for RANGE window frame
/// offsets, where the offset's type must agree with the ORDER BY column's
/// type. Only numeric variants are handled; the Scala side rejects
/// non-numeric RANGE offsets before reaching here.
fn numeric_literal_to_scalar(
    lit: &spark_expression::Literal,
) -> Result<ScalarValue, ExecutionError> {
    // The declared datatype is mandatory — it drives the Decimal128 branch below.
    let declared_type = match lit.datatype.as_ref() {
        Some(dt) => to_arrow_datatype(dt),
        None => {
            return Err(GeneralError(
                "RANGE frame offset literal is missing datatype".to_string(),
            ))
        }
    };

    // A NULL offset is never a valid RANGE frame bound.
    if lit.is_null {
        return Err(GeneralError(
            "RANGE frame offset must not be null".to_string(),
        ));
    }

    let raw = lit
        .value
        .as_ref()
        .ok_or_else(|| GeneralError("RANGE frame offset literal has no value".to_string()))?;

    match raw {
        Value::ByteVal(v) => Ok(ScalarValue::Int8(Some(*v as i8))),
        Value::ShortVal(v) => Ok(ScalarValue::Int16(Some(*v as i16))),
        Value::IntVal(v) => Ok(ScalarValue::Int32(Some(*v))),
        Value::LongVal(v) => Ok(ScalarValue::Int64(Some(*v))),
        Value::FloatVal(v) => Ok(ScalarValue::Float32(Some(*v))),
        Value::DoubleVal(v) => Ok(ScalarValue::Float64(Some(*v))),
        Value::DecimalVal(bytes) => {
            // Decimal literals arrive as big-endian two's-complement bytes of
            // the unscaled value; fold them into an i128 for Decimal128.
            let unscaled = BigInt::from_signed_bytes_be(bytes);
            let as_i128 = unscaled.to_i128().ok_or_else(|| {
                GeneralError(format!(
                    "Cannot parse {unscaled:?} as i128 for Decimal RANGE frame offset"
                ))
            })?;
            match declared_type {
                DataType::Decimal128(p, s) => Ok(ScalarValue::Decimal128(Some(as_i128), p, s)),
                ref dt => Err(GeneralError(format!(
                    "Decimal RANGE frame offset has non-Decimal128 datatype: {dt:?}"
                ))),
            }
        }
        other => Err(GeneralError(format!(
            "Unsupported value variant for RANGE frame offset: {other:?}"
        ))),
    }
}

/// A physical join filter rewriter which rewrites the column indices in the expression
/// to use the new column indices. See `rewrite_physical_expr`.
struct JoinFilterRewriter<'a> {
Expand Down
Loading
Loading