Skip to content

Commit

Permalink
fix: Support Substrait's compound names also for window functions (ap…
Browse files Browse the repository at this point in the history
…ache#11163)

* simplify and deduplicate scalar/aggregate/window function handling

* simplify window function handling and error out faster if function is not found

* fix window function name

* simplify scalar function handling

* fix

* add a test and fix consuming bound types
  • Loading branch information
Blizzara authored and findepi committed Jul 16, 2024
1 parent 0c311e2 commit 1172975
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 157 deletions.
239 changes: 85 additions & 154 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use datafusion::arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err,
substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};
use substrait::proto::expression::literal::IntervalDayToSecond;
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
Expand All @@ -30,8 +31,7 @@ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF,
Values,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values,
};

use datafusion::logical_expr::{
Expand All @@ -57,7 +57,7 @@ use substrait::proto::{
reference_segment::ReferenceType::StructField,
window_function::bound as SubstraitBound,
window_function::bound::Kind as BoundKind, window_function::Bound,
MaskExpression, RexType,
window_function::BoundsType, MaskExpression, RexType,
},
extensions::simple_extension_declaration::MappingType,
function_argument::ArgType,
Expand All @@ -71,7 +71,6 @@ use substrait::proto::{
use substrait::proto::{FunctionArgument, SortField};

use datafusion::arrow::array::GenericListArray;
use datafusion::common::plan_err;
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
Expand All @@ -89,12 +88,6 @@ use crate::variation_const::{
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
};

/// Classifies what a Substrait scalar function name resolved to on the DataFusion side.
///
/// Values are produced by `scalar_function_type_from_str`, which tries a registered
/// scalar UDF first, then `name_to_op`, then `BuiltinExprBuilder::try_from_name`.
enum ScalarFunctionType {
/// The name mapped to an [`Operator`] via `name_to_op`.
Op(Operator),
/// The name was recognized by `BuiltinExprBuilder::try_from_name`.
Expr(BuiltinExprBuilder),
/// The name matched a scalar function registered in the session state.
Udf(Arc<ScalarUDF>),
}

pub fn name_to_op(name: &str) -> Result<Operator> {
match name {
"equal" => Ok(Operator::Eq),
Expand Down Expand Up @@ -128,28 +121,6 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
}
}

/// Resolves a Substrait scalar function name to a [`ScalarFunctionType`].
///
/// The name is first normalized with `substrait_fun_name`, then looked up in
/// priority order:
/// 1. scalar functions registered in the session state,
/// 2. operators known to `name_to_op`,
/// 3. expressions known to `BuiltinExprBuilder::try_from_name`.
///
/// Returns a "not implemented" error when none of the lookups succeeds.
fn scalar_function_type_from_str(
    ctx: &SessionContext,
    name: &str,
) -> Result<ScalarFunctionType> {
    let name = substrait_fun_name(name);
    let state = ctx.state();

    if let Some(udf) = state.scalar_functions().get(name) {
        // Registered UDFs take precedence over everything built in.
        Ok(ScalarFunctionType::Udf(udf.to_owned()))
    } else if let Ok(operator) = name_to_op(name) {
        Ok(ScalarFunctionType::Op(operator))
    } else if let Some(expr_builder) = BuiltinExprBuilder::try_from_name(name) {
        Ok(ScalarFunctionType::Expr(expr_builder))
    } else {
        not_impl_err!("Unsupported function name: {name:?}")
    }
}

pub fn substrait_fun_name(name: &str) -> &str {
let name = match name.rsplit_once(':') {
// Since 0.32.0, Substrait requires the function names to be in a compound format
Expand Down Expand Up @@ -972,7 +943,7 @@ pub async fn from_substrait_rex_vec(
}

/// Convert Substrait FunctionArguments to DataFusion Exprs
pub async fn from_substriat_func_args(
pub async fn from_substrait_func_args(
ctx: &SessionContext,
arguments: &Vec<FunctionArgument>,
input_schema: &DFSchema,
Expand All @@ -984,9 +955,7 @@ pub async fn from_substriat_func_args(
Some(ArgType::Value(e)) => {
from_substrait_rex(ctx, e, input_schema, extensions).await
}
_ => {
not_impl_err!("Aggregated function argument non-Value type not supported")
}
_ => not_impl_err!("Function argument non-Value type not supported"),
};
args.push(arg_expr?.as_ref().clone());
}
Expand All @@ -1003,33 +972,25 @@ pub async fn from_substrait_agg_func(
order_by: Option<Vec<Expr>>,
distinct: bool,
) -> Result<Arc<Expr>> {
let mut args: Vec<Expr> = vec![];
for arg in &f.arguments {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
from_substrait_rex(ctx, e, input_schema, extensions).await
}
_ => {
not_impl_err!("Aggregated function argument non-Value type not supported")
}
};
args.push(arg_expr?.as_ref().clone());
}
let args =
from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?;

let Some(function_name) = extensions.get(&f.function_reference) else {
return plan_err!(
"Aggregate function not registered: function anchor = {:?}",
f.function_reference
);
};
// function_name.split(':').next().unwrap_or(function_name);

let function_name = substrait_fun_name((**function_name).as_str());
// try udaf first, then built-in aggr fn.
if let Ok(fun) = ctx.udaf(function_name) {
// deal with situation that count(*) got no arguments
if fun.name() == "count" && args.is_empty() {
args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
}
let args = if fun.name() == "count" && args.is_empty() {
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
} else {
args
};

Ok(Arc::new(Expr::AggregateFunction(
expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None),
Expand All @@ -1041,7 +1002,7 @@ pub async fn from_substrait_agg_func(
)))
} else {
not_impl_err!(
"Aggregated function {} is not supported: function anchor = {:?}",
"Aggregate function {} is not supported: function anchor = {:?}",
function_name,
f.function_reference
)
Expand Down Expand Up @@ -1145,84 +1106,40 @@ pub async fn from_substrait_rex(
})))
}
Some(RexType::ScalarFunction(f)) => {
let fn_name = extensions.get(&f.function_reference).ok_or_else(|| {
DataFusionError::NotImplemented(format!(
"Aggregated function not found: function reference = {:?}",
let Some(fn_name) = extensions.get(&f.function_reference) else {
return plan_err!(
"Scalar function not found: function reference = {:?}",
f.function_reference
))
})?;

// Convert function arguments from Substrait to DataFusion
async fn decode_arguments(
ctx: &SessionContext,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
function_args: &[FunctionArgument],
) -> Result<Vec<Expr>> {
let mut args = Vec::with_capacity(function_args.len());
for arg in function_args {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
from_substrait_rex(ctx, e, input_schema, extensions).await
}
_ => not_impl_err!(
"Aggregated function argument non-Value type not supported"
),
}?;
args.push(arg_expr.as_ref().clone());
}
Ok(args)
}
);
};
let fn_name = substrait_fun_name(fn_name);

let fn_type = scalar_function_type_from_str(ctx, fn_name)?;
match fn_type {
ScalarFunctionType::Udf(fun) => {
let args = decode_arguments(
ctx,
input_schema,
extensions,
f.arguments.as_slice(),
)
let args =
from_substrait_func_args(ctx, &f.arguments, input_schema, extensions)
.await?;
Ok(Arc::new(Expr::ScalarFunction(
expr::ScalarFunction::new_udf(fun, args),
)))
}
ScalarFunctionType::Op(op) => {
if f.arguments.len() != 2 {
return not_impl_err!(
"Expect two arguments for binary operator {op:?}"
);
}
let lhs = &f.arguments[0].arg_type;
let rhs = &f.arguments[1].arg_type;

match (lhs, rhs) {
(Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(
from_substrait_rex(ctx, l, input_schema, extensions)
.await?
.as_ref()
.clone(),
),
op,
right: Box::new(
from_substrait_rex(ctx, r, input_schema, extensions)
.await?
.as_ref()
.clone(),
),
})))
}
(l, r) => not_impl_err!(
"Invalid arguments for binary expression: {l:?} and {r:?}"
),
}
}
ScalarFunctionType::Expr(builder) => {
builder.build(ctx, f, input_schema, extensions).await

// First try to match the requested function against the registered UDFs, then the
// built-in operators, and finally the built-in expressions.
if let Some(func) = ctx.state().scalar_functions().get(fn_name) {
Ok(Arc::new(Expr::ScalarFunction(
expr::ScalarFunction::new_udf(func.to_owned(), args),
)))
} else if let Ok(op) = name_to_op(fn_name) {
if args.len() != 2 {
return not_impl_err!(
"Expect two arguments for binary operator {op:?}"
);
}

Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(args[0].to_owned()),
op,
right: Box::new(args[1].to_owned()),
})))
} else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
builder.build(ctx, f, input_schema, extensions).await
} else {
not_impl_err!("Unsupported function name: {fn_name:?}")
}
}
Some(RexType::Literal(lit)) => {
Expand All @@ -1247,36 +1164,50 @@ pub async fn from_substrait_rex(
None => substrait_err!("Cast expression without output type is not allowed"),
},
Some(RexType::WindowFunction(window)) => {
let fun = match extensions.get(&window.function_reference) {
Some(function_name) => {
// check udaf
match ctx.udaf(function_name) {
Ok(udaf) => {
Ok(Some(WindowFunctionDefinition::AggregateUDF(udaf)))
}
Err(_) => Ok(find_df_window_func(function_name)),
}
}
None => not_impl_err!(
"Window function not found: function anchor = {:?}",
&window.function_reference
),
let Some(fn_name) = extensions.get(&window.function_reference) else {
return plan_err!(
"Window function not found: function reference = {:?}",
window.function_reference
);
};
let fn_name = substrait_fun_name(fn_name);

// check udaf first, then built-in functions
let fun = match ctx.udaf(fn_name) {
Ok(udaf) => Ok(WindowFunctionDefinition::AggregateUDF(udaf)),
Err(_) => find_df_window_func(fn_name).ok_or_else(|| {
not_impl_datafusion_err!(
"Window function {} is not supported: function anchor = {:?}",
fn_name,
window.function_reference
)
}),
}?;

let order_by =
from_substrait_sorts(ctx, &window.sorts, input_schema, extensions)
.await?;
// Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units
// If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary
// If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row
// TODO: Consider the cases where window frame is specified in query and is different from default
let units = if order_by.is_empty() {
WindowFrameUnits::Rows
} else {
WindowFrameUnits::Range
};

let bound_units =
match BoundsType::try_from(window.bounds_type).map_err(|e| {
plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type)
})? {
BoundsType::Rows => WindowFrameUnits::Rows,
BoundsType::Range => WindowFrameUnits::Range,
BoundsType::Unspecified => {
// If the plan does not specify the bounds type, fall back to simple logic to determine the units:
// if there is no `ORDER BY`, the frame by default counts each row from the lower up to the upper boundary;
// if there is an `ORDER BY`, each frame is by default a range starting from unbounded preceding to the current row.
if order_by.is_empty() {
WindowFrameUnits::Rows
} else {
WindowFrameUnits::Range
}
}
};
Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction {
fun: fun?.unwrap(),
args: from_substriat_func_args(
fun,
args: from_substrait_func_args(
ctx,
&window.arguments,
input_schema,
Expand All @@ -1292,7 +1223,7 @@ pub async fn from_substrait_rex(
.await?,
order_by,
window_frame: datafusion::logical_expr::WindowFrame::new_bounds(
units,
bound_units,
from_substrait_bound(&window.lower_bound, true)?,
from_substrait_bound(&window.upper_bound, false)?,
),
Expand Down
Loading

0 comments on commit 1172975

Please sign in to comment.