diff --git a/src/flow/src/adapter.rs b/src/flow/src/adapter.rs
index 60336f548f24..3940cacc884c 100644
--- a/src/flow/src/adapter.rs
+++ b/src/flow/src/adapter.rs
@@ -49,10 +49,10 @@ use crate::adapter::table_source::TableSource;
 use crate::adapter::util::column_schemas_to_proto;
 use crate::adapter::worker::{create_worker, Worker, WorkerHandle};
 use crate::compute::ErrCollector;
+use crate::df_optimizer::sql_to_flow_plan;
 use crate::error::{ExternalSnafu, InternalSnafu, TableNotFoundSnafu, UnexpectedSnafu};
 use crate::expr::GlobalId;
 use crate::repr::{self, DiffRow, Row, BATCH_SIZE};
-use crate::transform::sql_to_flow_plan;
 
 mod flownode_impl;
 mod parse_expr;
diff --git a/src/flow/src/df_optimizer.rs b/src/flow/src/df_optimizer.rs
new file mode 100644
index 000000000000..c08680bce63e
--- /dev/null
+++ b/src/flow/src/df_optimizer.rs
@@ -0,0 +1,595 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Datafusion optimizer for flow plan
+
+#![warn(unused)]
+
+use std::collections::{HashMap, HashSet};
+use std::sync::Arc;
+
+use common_error::ext::BoxedError;
+use common_telemetry::debug;
+use datafusion::config::ConfigOptions;
+use datafusion::error::DataFusionError;
+use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
+use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
+use datafusion::optimizer::optimize_projections::OptimizeProjections;
+use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
+use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
+use datafusion::optimizer::utils::NamePreserver;
+use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext};
+use datafusion_common::tree_node::{
+    Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
+};
+use datafusion_common::{Column, DFSchema, ScalarValue};
+use datafusion_expr::aggregate_function::AggregateFunction;
+use datafusion_expr::expr::AggregateFunctionDefinition;
+use datafusion_expr::utils::merge_schema;
+use datafusion_expr::{
+    BinaryExpr, Expr, Operator, Projection, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use query::parser::QueryLanguageParser;
+use query::plan::LogicalPlan;
+use query::query_engine::DefaultSerializer;
+use query::QueryEngine;
+use snafu::ResultExt;
+/// note that here we use the `substrait_proto_df` crate from the `substrait` module,
+/// renamed to `substrait_proto`
+use substrait::DFLogicalSubstraitConvertor;
+
+use crate::adapter::FlownodeContext;
+use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu};
+use crate::plan::TypedPlan;
+
+// TODO(discord9): use `Analyzer` to manage rules if more `AnalyzerRule`s are needed
+pub async fn apply_df_optimizer(
+    plan: datafusion_expr::LogicalPlan,
+) -> Result<datafusion_expr::LogicalPlan, Error> {
+    let cfg = ConfigOptions::new();
+    let analyzer = Analyzer::with_rules(vec![
+        Arc::new(AvgExpandRule::new()),
+        Arc::new(TumbleExpandRule::new()),
+        Arc::new(CheckGroupByRule::new()),
+        Arc::new(TypeCoercion::new()),
+    ]);
+    let plan = analyzer
+        .execute_and_check(plan, &cfg, |p, r| {
+            debug!("rule {} applied, new plan: {:?}", r.name(), p);
+        })
+        .context(DatafusionSnafu {
+            context: "Fail to apply analyzer",
+        })?;
+
+    let ctx = OptimizerContext::new();
+    let optimizer = Optimizer::with_rules(vec![
+        Arc::new(OptimizeProjections::new()),
+        Arc::new(CommonSubexprEliminate::new()),
+        Arc::new(SimplifyExpressions::new()),
+        Arc::new(UnwrapCastInComparison::new()),
+    ]);
+    let plan = optimizer
+        .optimize(plan, &ctx, |_, _| {})
+        .context(DatafusionSnafu {
+            context: "Fail to apply optimizer",
+        })?;
+
+    Ok(plan)
+}
+
+/// To reuse existing code for parsing SQL, the SQL is first parsed into a datafusion logical plan,
+/// then into a substrait plan, and finally into a flow plan.
+pub async fn sql_to_flow_plan(
+    ctx: &mut FlownodeContext,
+    engine: &Arc<dyn QueryEngine>,
+    sql: &str,
+) -> Result<TypedPlan, Error> {
+    let query_ctx = ctx.query_context.clone().ok_or_else(|| {
+        UnexpectedSnafu {
+            reason: "Query context is missing",
+        }
+        .build()
+    })?;
+    let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx)
+        .map_err(BoxedError::new)
+        .context(ExternalSnafu)?;
+    let plan = engine
+        .planner()
+        .plan(stmt, query_ctx)
+        .await
+        .map_err(BoxedError::new)
+        .context(ExternalSnafu)?;
+    let LogicalPlan::DfPlan(plan) = plan;
+
+    let opted_plan = apply_df_optimizer(plan).await?;
+
+    // TODO(discord9): add df optimization
+    let sub_plan = DFLogicalSubstraitConvertor {}
+        .to_sub_plan(&opted_plan, DefaultSerializer)
+        .map_err(BoxedError::new)
+        .context(ExternalSnafu)?;
+
+    let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;
+
+    Ok(flow_plan)
+}
+
+struct AvgExpandRule {}
+
+impl AvgExpandRule {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl AnalyzerRule for AvgExpandRule {
+    fn analyze(
+        &self,
+        plan: datafusion_expr::LogicalPlan,
+        _config: &ConfigOptions,
+    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
+        let transformed = plan
+            .transform_up_with_subqueries(expand_avg_analyzer)?
+            .data
+            .transform_down_with_subqueries(put_aggr_to_proj_analyzer)?
+            .data;
+        Ok(transformed)
+    }
+
+    fn name(&self) -> &str {
+        "avg_expand"
+    }
+}
+
+/// lift an aggr's composite aggr_expr into the outer proj, leaving the aggr with only simple direct aggr exprs,
+/// i.e.
+/// ```ignore
+/// proj: avg(x)
+/// -- aggr: [sum(x)/count(x) as avg(x)]
+/// ```
+/// becomes:
+/// ```ignore
+/// proj: sum(x)/count(x) as avg(x)
+/// -- aggr: [sum(x), count(x)]
+/// ```
+fn put_aggr_to_proj_analyzer(
+    plan: datafusion_expr::LogicalPlan,
+) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
+    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
+        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
+            let mut replace_old_proj_exprs = HashMap::new();
+            let mut expanded_aggr_exprs = vec![];
+            for aggr_expr in &aggr.aggr_expr {
+                let mut is_composite = false;
+                if let Expr::AggregateFunction(_) = &aggr_expr {
+                    expanded_aggr_exprs.push(aggr_expr.clone());
+                } else {
+                    let old_name = aggr_expr.name_for_alias()?;
+                    let new_proj_expr = aggr_expr
+                        .clone()
+                        .transform(|ch| {
+                            if let Expr::AggregateFunction(_) = &ch {
+                                is_composite = true;
+                                expanded_aggr_exprs.push(ch.clone());
+                                Ok(Transformed::yes(Expr::Column(Column::from_qualified_name(
+                                    ch.name_for_alias()?,
+                                ))))
+                            } else {
+                                Ok(Transformed::no(ch))
+                            }
+                        })?
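+                        // the transformed expr is the original projection expr with its
+                        // composite aggregates (e.g. sum(x)/count(x)) replaced by column refs
+                        // to the aggr exprs lifted into `expanded_aggr_exprs`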
+                        .data;
+                    replace_old_proj_exprs.insert(old_name, new_proj_expr);
+                }
+            }
+
+            if expanded_aggr_exprs.len() > aggr.aggr_expr.len() {
+                let mut aggr = aggr.clone();
+                aggr.aggr_expr = expanded_aggr_exprs;
+                let mut aggr_plan = datafusion_expr::LogicalPlan::Aggregate(aggr);
+                // important to recompute schema after changing aggr_expr
+                aggr_plan = aggr_plan.recompute_schema()?;
+
+                // reconstruct proj with new proj_exprs
+                let mut new_proj_exprs = proj.expr.clone();
+                for proj_expr in new_proj_exprs.iter_mut() {
+                    if let Some(new_proj_expr) =
+                        replace_old_proj_exprs.get(&proj_expr.name_for_alias()?)
+                    {
+                        *proj_expr = new_proj_expr.clone();
+                    }
+                    *proj_expr = proj_expr
+                        .clone()
+                        .transform(|expr| {
+                            if let Some(new_expr) =
+                                replace_old_proj_exprs.get(&expr.name_for_alias()?)
+                            {
+                                Ok(Transformed::yes(new_expr.clone()))
+                            } else {
+                                Ok(Transformed::no(expr))
+                            }
+                        })?
+                        .data;
+                }
+                let proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
+                    new_proj_exprs,
+                    Arc::new(aggr_plan),
+                )?);
+                return Ok(Transformed::yes(proj));
+            }
+        }
+    }
+    Ok(Transformed::no(plan))
+}
+
+/// expand the `avg(x)` function into `cast(sum(x) AS f64) / count(x)`
+fn expand_avg_analyzer(
+    plan: datafusion_expr::LogicalPlan,
+) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
+    let mut schema = merge_schema(plan.inputs());
+
+    if let datafusion_expr::LogicalPlan::TableScan(ts) = &plan {
+        let source_schema =
+            DFSchema::try_from_qualified_schema(ts.table_name.clone(), &ts.source.schema())?;
+        schema.merge(&source_schema);
+    }
+
+    let mut expr_rewrite = ExpandAvgRewriter::new(&schema);
+
+    let name_preserver = NamePreserver::new(&plan);
+    // rewrite all expressions in the plan individually
+    plan.map_expressions(|expr| {
+        let original_name = name_preserver.save(&expr)?;
+        expr.rewrite(&mut expr_rewrite)?
+            .map_data(|expr| original_name.restore(expr))
+    })?
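+    // recompute the schema afterwards, since avg exprs were rewritten into sum/count combinations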
+    .map_data(|plan| plan.recompute_schema())
+}
+
+/// rewrite the `avg(x)` function into
+/// `CASE WHEN count(x) != 0 THEN cast(sum(x) AS avg_return_type) / count(x) ELSE NULL END`
+///
+/// TODO(discord9): support avg return type decimal128
+///
+/// see impl details at https://github.com/apache/datafusion/blob/4ad4f90d86c57226a4e0fb1f79dfaaf0d404c273/datafusion/expr/src/type_coercion/aggregates.rs#L457-L462
+pub(crate) struct ExpandAvgRewriter<'a> {
+    /// schema of the plan
+    #[allow(unused)]
+    pub(crate) schema: &'a DFSchema,
+}
+
+impl<'a> ExpandAvgRewriter<'a> {
+    fn new(schema: &'a DFSchema) -> Self {
+        Self { schema }
+    }
+}
+
+impl<'a> TreeNodeRewriter for ExpandAvgRewriter<'a> {
+    type Node = Expr;
+
+    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>, DataFusionError> {
+        if let Expr::AggregateFunction(aggr_func) = &expr {
+            if let AggregateFunctionDefinition::BuiltIn(AggregateFunction::Avg) =
+                &aggr_func.func_def
+            {
+                let sum_expr = {
+                    let mut tmp = aggr_func.clone();
+                    tmp.func_def = AggregateFunctionDefinition::BuiltIn(AggregateFunction::Sum);
+                    Expr::AggregateFunction(tmp)
+                };
+                let sum_cast = {
+                    let mut tmp = sum_expr.clone();
+                    tmp = Expr::Cast(datafusion_expr::Cast {
+                        expr: Box::new(tmp),
+                        data_type: arrow_schema::DataType::Float64,
+                    });
+                    tmp
+                };
+
+                let count_expr = {
+                    let mut tmp = aggr_func.clone();
+                    tmp.func_def = AggregateFunctionDefinition::BuiltIn(AggregateFunction::Count);
+
+                    Expr::AggregateFunction(tmp)
+                };
+                let count_expr_ref =
+                    Expr::Column(Column::from_qualified_name(count_expr.name_for_alias()?));
+
+                let div =
+                    BinaryExpr::new(Box::new(sum_cast), Operator::Divide, Box::new(count_expr));
+                let div_expr = Box::new(Expr::BinaryExpr(div));
+
+                let zero = Box::new(Expr::Literal(ScalarValue::Int64(Some(0))));
+                let not_zero =
+                    BinaryExpr::new(Box::new(count_expr_ref), Operator::NotEq, zero.clone());
+                let not_zero = Box::new(Expr::BinaryExpr(not_zero));
+                let null = Box::new(Expr::Literal(ScalarValue::Null));
+
+                let case_when =
+                    datafusion_expr::Case::new(None, vec![(not_zero, div_expr)], Some(null));
+                let case_when_expr = Expr::Case(case_when);
+
+                return Ok(Transformed::yes(case_when_expr));
+            }
+        }
+
+        Ok(Transformed::no(expr))
+    }
+}
+
+/// expand `tumble` in group exprs into `tumble_start` and `tumble_end`, with column names like `window_start`
+struct TumbleExpandRule {}
+
+impl TumbleExpandRule {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl AnalyzerRule for TumbleExpandRule {
+    fn analyze(
+        &self,
+        plan: datafusion_expr::LogicalPlan,
+        _config: &ConfigOptions,
+    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
+        let transformed = plan
+            .transform_up_with_subqueries(expand_tumble_analyzer)?
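+            // every `tumble(ts, window)` group expr is now expanded into a
+            // `tumble_start`/`tumble_end` pair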
+            .data;
+        Ok(transformed)
+    }
+
+    fn name(&self) -> &str {
+        "tumble_expand"
+    }
+}
+
+/// expand `tumble` in the aggr's group exprs into `tumble_start` and `tumble_end`, and expand related aliases and column refs as well
+///
+/// will add `tumble_start` and `tumble_end` to the outer projection if they do not already exist
+fn expand_tumble_analyzer(
+    plan: datafusion_expr::LogicalPlan,
+) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
+    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
+        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
+            let mut new_group_expr = vec![];
+            let mut alias_to_expand = HashMap::new();
+            for expr in aggr.group_expr.iter() {
+                match expr {
+                    datafusion_expr::Expr::ScalarFunction(func) => {
+                        if func.name() == "tumble" {
+                            let tumble_start = TumbleExpand::new("tumble_start");
+                            let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf(
+                                Arc::new(tumble_start.into()),
+                                func.args.clone(),
+                            );
+                            let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start);
+                            let start_col_name = tumble_start.name_for_alias()?;
+                            new_group_expr.push(tumble_start);
+
+                            let tumble_end = TumbleExpand::new("tumble_end");
+                            let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf(
+                                Arc::new(tumble_end.into()),
+                                func.args.clone(),
+                            );
+                            let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end);
+                            let end_col_name = tumble_end.name_for_alias()?;
+                            new_group_expr.push(tumble_end);
+
+                            alias_to_expand
+                                .insert(expr.name_for_alias()?, (start_col_name, end_col_name));
+                        }
+                    }
+                    _ => new_group_expr.push(expr.clone()),
+                }
+            }
+
+            let mut new_aggr = aggr.clone();
+            new_aggr.group_expr = new_group_expr;
+            let new_aggr = datafusion_expr::LogicalPlan::Aggregate(new_aggr).recompute_schema()?;
+            // replace aliases in the projection if needed, and add new column refs if necessary
+            let mut new_proj_expr = vec![];
+            let mut have_expanded = false;
+
+            for proj_expr in proj.expr.iter() {
+                if let Some((start_col_name, end_col_name)) =
+                    alias_to_expand.get(&proj_expr.name_for_alias()?)
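+                // this projection expr refers to the original `tumble(...)` alias;
+                // replace it with refs to the expanded start/end columns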
+                {
+                    let start_col = Column::from_qualified_name(start_col_name);
+                    let end_col = Column::from_qualified_name(end_col_name);
+                    new_proj_expr.push(datafusion_expr::Expr::Column(start_col));
+                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col));
+                    have_expanded = true;
+                } else {
+                    new_proj_expr.push(proj_expr.clone());
+                }
+            }
+
+            // append to the end of the projection if they do not already exist
+            if !have_expanded {
+                for (start_col_name, end_col_name) in alias_to_expand.values() {
+                    let start_col = Column::from_qualified_name(start_col_name);
+                    let end_col = Column::from_qualified_name(end_col_name);
+                    new_proj_expr
+                        .push(datafusion_expr::Expr::Column(start_col).alias("window_start"));
+                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col).alias("window_end"));
+                }
+            }
+
+            let new_proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
+                new_proj_expr,
+                Arc::new(new_aggr),
+            )?);
+            return Ok(Transformed::yes(new_proj));
+        }
+    }
+
+    Ok(Transformed::no(plan))
+}
+
+#[derive(Debug)]
+pub struct TumbleExpand {
+    signature: Signature,
+    name: String,
+}
+
+impl TumbleExpand {
+    pub fn new(name: &str) -> Self {
+        Self {
+            signature: Signature::new(TypeSignature::UserDefined, Volatility::Immutable),
+            name: name.to_string(),
+        }
+    }
+}
+
+impl ScalarUDFImpl for TumbleExpand {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// elide the signature for now
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn coerce_types(
+        &self,
+        arg_types: &[arrow_schema::DataType],
+    ) -> datafusion_common::Result<Vec<arrow_schema::DataType>> {
+        match (arg_types.first(), arg_types.get(1), arg_types.get(2)) {
+            (Some(ts), Some(window), opt) => {
+                use arrow_schema::DataType::*;
+                if !matches!(ts, Date32 | Date64 | Timestamp(_, _)) {
+                    return Err(DataFusionError::Plan(
+                        format!("Expect a timestamp column as the first arg for tumble_start, found {:?}", ts)
+                    ));
+                }
+                if !matches!(window, Utf8 | Interval(_)) {
+                    return Err(DataFusionError::Plan(
+                        format!("Expect the second arg (window size) for tumble_start to be an interval, found {:?}", window),
+                    ));
+                }
+
+                if let Some(start_time) = opt {
+                    if !matches!(start_time, Utf8 | Date32 | Date64 | Timestamp(_, _)) {
+                        return Err(DataFusionError::Plan(
+                            format!("Expect start_time to be a date, timestamp or string, found {:?}", start_time)
+                        ));
+                    }
+                }
+
+                Ok(arg_types.to_vec())
+            }
+            _ => Err(DataFusionError::Plan(
+                "Expect the tumble function to have at least two args (a timestamp column and a window size), plus an optional third arg for the start time".to_string(),
+            )),
+        }
+    }
+
+    fn return_type(
+        &self,
+        arg_types: &[arrow_schema::DataType],
+    ) -> Result<arrow_schema::DataType, DataFusionError> {
+        arg_types.first().cloned().ok_or_else(|| {
+            DataFusionError::Plan(
+                "Expect the tumble function to have at least two args (a timestamp column and a window size)"
+                    .to_string(),
+            )
+        })
+    }
+
+    fn invoke(
+        &self,
+        _args: &[datafusion_expr::ColumnarValue],
+    ) -> Result<datafusion_expr::ColumnarValue, DataFusionError> {
+        Err(DataFusionError::Plan(
+            "This function should not be executed by datafusion".to_string(),
+        ))
+    }
+}
+
+struct CheckGroupByRule {}
+
+impl CheckGroupByRule {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl AnalyzerRule for CheckGroupByRule {
+    fn analyze(
+        &self,
+        plan: datafusion_expr::LogicalPlan,
+        _config: &ConfigOptions,
+    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
+        let transformed = plan
+            .transform_up_with_subqueries(check_group_by_analyzer)?
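+            // the whole analyze pass fails with a plan error if a GROUP BY expr
+            // is missing from the select list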
+            .data;
+        Ok(transformed)
+    }
+
+    fn name(&self) -> &str {
+        "check_groupby"
+    }
+}
+
+/// make sure everything in the group by exprs also exists in the select list
+fn check_group_by_analyzer(
+    plan: datafusion_expr::LogicalPlan,
+) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
+    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
+        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
+            let mut found_column_used = FindColumn::new();
+            proj.expr
+                .iter()
+                .map(|i| i.visit(&mut found_column_used))
+                .count();
+            for expr in aggr.group_expr.iter() {
+                if !found_column_used
+                    .names_for_alias
+                    .contains(&expr.name_for_alias()?)
+                {
+                    return Err(DataFusionError::Plan(format!("Expect the group by expr {} to also exist in the select list, but the select list only contains {:?}", expr.name_for_alias()?, found_column_used.names_for_alias)));
+                }
+            }
+        }
+    }
+
+    Ok(Transformed::no(plan))
+}
+
+#[derive(Debug, Default)]
+struct FindColumn {
+    names_for_alias: HashSet<String>,
+}
+
+impl FindColumn {
+    fn new() -> Self {
+        Default::default()
+    }
+}
+
+impl TreeNodeVisitor<'_> for FindColumn {
+    type Node = datafusion_expr::Expr;
+    fn f_down(
+        &mut self,
+        node: &datafusion_expr::Expr,
+    ) -> Result<TreeNodeRecursion, DataFusionError> {
+        if let datafusion_expr::Expr::Column(_) = node {
+            self.names_for_alias.insert(node.name_for_alias()?);
+        }
+        Ok(TreeNodeRecursion::Continue)
+    }
+}
diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs
index 39b469207169..54971d4a7580 100644
--- a/src/flow/src/expr/func.rs
+++ b/src/flow/src/expr/func.rs
@@ -221,6 +221,57 @@ impl UnaryFunc {
         }
     }
 
+    pub fn from_tumble_func(name: &str, args: &[TypedExpr]) -> Result<(Self, TypedExpr), Error> {
+        match name {
+            "tumble_start" | "tumble_end" => {
+                let ts = args.first().context(InvalidQuerySnafu {
+                    reason: "Tumble window function requires a timestamp argument",
+                })?;
+                let window_size = args
+                    .get(1)
+                    .and_then(|expr| expr.expr.as_literal())
+                    .context(InvalidQuerySnafu {
+                        reason: "Tumble window function requires a window size argument"
+                    })?.as_string() // TODO(discord9): since the df-to-substrait convertor does not support the interval type yet, we take a string and cast it to an interval instead
+                    .map(|s| cast(Value::from(s), &ConcreteDataType::interval_month_day_nano_datatype())).transpose().map_err(BoxedError::new).context(
+                        ExternalSnafu
+                    )?.and_then(|v| v.as_interval())
+                    .with_context(|| InvalidQuerySnafu {
+                        reason: format!("Tumble window function requires the window size argument to be a string describing an interval, found {:?}", args.get(1))
+                    })?;
+                let start_time = match args.get(2) {
+                    Some(start_time) => start_time.expr.as_literal(),
+                    None => None,
+                }
+                .map(|s| cast(s.clone(), &ConcreteDataType::datetime_datatype())).transpose().map_err(BoxedError::new).context(ExternalSnafu)?.map(|v| v.as_datetime().with_context(
+                    || InvalidQuerySnafu {
+                        reason: format!("Tumble window function requires the start time argument to be a datetime described by a string, found {:?}", args.get(2))
+                    }
+                )).transpose()?;
+                if name == "tumble_start" {
+                    Ok((
+                        Self::TumbleWindowFloor {
+                            window_size,
+                            start_time,
+                        },
+                        ts.clone(),
+                    ))
+                } else if name == "tumble_end" {
+                    Ok((
+                        Self::TumbleWindowCeiling {
+                            window_size,
+                            start_time,
+                        },
+                        ts.clone(),
+                    ))
+                } else {
+                    unreachable!()
+                }
+            }
+            _ => todo!(),
+        }
+    }
+
     /// Evaluate the function with given values and expression
     ///
     /// # Arguments
@@ -571,8 +622,8 @@ impl BinaryFunc {
             t1 == t2,
             InvalidQuerySnafu {
                 reason: format!(
-                    "Binary function {:?} requires both arguments to have the same type",
-                    generic
+                    "Binary function {:?} requires both arguments to have the same type, left={:?}, right={:?}",
+                    generic, t1, t2
                 ),
             }
         );
diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs
index 8a3290a932f1..3a0f57b40de4 100644
--- a/src/flow/src/expr/scalar.rs
+++ b/src/flow/src/expr/scalar.rs
@@ -58,77 +58,6 @@ impl TypedExpr {
     }
 }
 
-impl TypedExpr {
-    /// expand a multi-value expression into multiple expressions with new indices
-    ///
-    /// Currently this just means expanding `TumbleWindow` into `TumbleWindowFloor` and `TumbleWindowCeiling`
-    ///
-    /// TODO(discord9): test if nested reduce combined with a df scalar function would cause problems
-    pub fn expand_multi_value(
-        input_typ: &RelationType,
-        exprs: &[TypedExpr],
-    ) -> Result<Vec<TypedExpr>, Error> {
-        // old indices in mfp, expanded expr
-        let mut ret = vec![];
-        let input_arity = input_typ.column_types.len();
-        for (old_idx, expr) in exprs.iter().enumerate() {
-            if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::TumbleWindow {
-                ts,
-                window_size,
-                start_time,
-            }) = &expr.expr
-            {
-                let floor = UnaryFunc::TumbleWindowFloor {
-                    window_size: *window_size,
-                    start_time: *start_time,
-                };
-                let ceil = UnaryFunc::TumbleWindowCeiling {
-                    window_size: *window_size,
-                    start_time: *start_time,
-                };
-                let floor = ScalarExpr::CallUnary {
-                    func: floor,
-                    expr: Box::new(ts.expr.clone()),
-                }
-                .with_type(ts.typ.clone());
-                ret.push((None, floor));
-
-                let ceil = ScalarExpr::CallUnary {
-                    func: ceil,
-                    expr: Box::new(ts.expr.clone()),
-                }
-                .with_type(ts.typ.clone());
-                ret.push((None, ceil));
-            } else {
-                ret.push((Some(input_arity + old_idx), expr.clone()))
-            }
-        }
-
-        // get the shuffled indices (old_idx -> new_idx)
-        // note the index is offset by input_arity because mfp is designed to first include input columns, then intermediate columns
-        let shuffle = ret
-            .iter()
-            .map(|(old_idx, _)| *old_idx) // [Option<usize>]
-            .enumerate()
-            .map(|(new, old)| (old, new + input_arity))
-            .flat_map(|(old, new)| old.map(|o| (o, new)))
-            .chain((0..input_arity).map(|i| (i, i))) // also remember to chain the input columns as unchanged
-            .collect::<BTreeMap<_, _>>();
-
-        // shuffle the exprs' indices
-        let exprs = ret
-            .into_iter()
-            .map(|(_, mut expr)| {
-                // invariant: it is expected that no expr will try to refer to the column being expanded
-                expr.expr.permute_map(&shuffle)?;
-                Ok(expr)
-            })
-            .collect::<Result<Vec<_>, _>>()?;
-
-        Ok(exprs)
-    }
-}
-
 /// A scalar expression, which can be evaluated to a value.
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
 pub enum ScalarExpr {
@@ -379,6 +308,13 @@ impl ScalarExpr {
     }
 }
 
 impl ScalarExpr {
+    pub fn cast(self, typ: ConcreteDataType) -> Self {
+        ScalarExpr::CallUnary {
+            func: UnaryFunc::Cast(typ),
+            expr: Box::new(self),
+        }
+    }
+
     /// apply optimization to the expression, like flattening variadic functions
     pub fn optimize(&mut self) {
         self.flatten_varidic_fn();
diff --git a/src/flow/src/expr/signature.rs b/src/flow/src/expr/signature.rs
index d61a60dea5e2..82506d1293c9 100644
--- a/src/flow/src/expr/signature.rs
+++ b/src/flow/src/expr/signature.rs
@@ -19,6 +19,8 @@ use serde::{Deserialize, Serialize};
 use smallvec::SmallVec;
 
 /// Function signature
+///
+/// TODO(discord9): use the `common_query::signature::Signature` type instead
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
 pub struct Signature {
     /// the input types, usually no more than two input args
diff --git a/src/flow/src/lib.rs b/src/flow/src/lib.rs
index d01e5ea28346..0539ac5b18ba 100644
--- a/src/flow/src/lib.rs
+++ b/src/flow/src/lib.rs
@@ -25,6 +25,7 @@
 // allow unused for now because it should be used later
 mod adapter;
 mod compute;
+mod df_optimizer;
 pub mod error;
 mod expr;
 pub mod heartbeat;
diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs
index c31ddb652e3b..6fc3b093c1c8 100644
--- a/src/flow/src/plan.rs
+++ b/src/flow/src/plan.rs
@@ -121,6 +121,8 @@ impl TypedPlan {
 /// TODO(discord9): support `TableFunc` (by defining a FlatMap that maps 1 to n)
 /// Plan describes how to transform data in the dataflow
+///
+/// This can be considered the physical plan of the dataflow, describing how data is transformed
 #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
 pub enum Plan {
     /// A constant collection of rows.
diff --git a/src/flow/src/repr/relation.rs b/src/flow/src/repr/relation.rs
index e470ad9dbdbf..2fd8e6c8734c 100644
--- a/src/flow/src/repr/relation.rs
+++ b/src/flow/src/repr/relation.rs
@@ -376,10 +376,8 @@ impl RelationDesc {
             .collect();
         let arrow_schema = arrow_schema::Schema::new(fields);
 
-        DFSchema::try_from(arrow_schema.clone()).context({
-            DatafusionSnafu {
-                context: format!("Error when converting to DFSchema: {:?}", arrow_schema),
-            }
+        DFSchema::try_from(arrow_schema.clone()).with_context(|_e| DatafusionSnafu {
+            context: format!("Error when converting to DFSchema: {:?}", arrow_schema),
         })
     }
diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs
index f8075b5dc221..e029bd68d1f7 100644
--- a/src/flow/src/transform.rs
+++ b/src/flow/src/transform.rs
@@ -18,9 +18,12 @@ use std::sync::Arc;
 
 use bytes::buf::IntoIter;
 use common_error::ext::BoxedError;
+use common_query::error::InvalidFuncArgsSnafu;
 use common_telemetry::info;
+use datafusion::config::ConfigOptions;
+use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
 use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
-use datafusion::optimizer::{OptimizerContext, OptimizerRule};
+use datafusion::optimizer::{AnalyzerRule, OptimizerContext, OptimizerRule};
 use datatypes::data_type::ConcreteDataType as CDT;
 use literal::{from_substrait_literal, from_substrait_type};
 use prost::Message;
@@ -114,68 +117,39 @@ impl FunctionExtensions {
     }
 }
 
-/// To reuse existing code for parsing SQL, the SQL is first parsed into a datafusion logical plan,
-/// then into a substrait plan, and finally into a flow plan.
-pub async fn sql_to_flow_plan(
-    ctx: &mut FlownodeContext,
-    engine: &Arc<dyn QueryEngine>,
-    sql: &str,
-) -> Result<TypedPlan, Error> {
-    let query_ctx = ctx.query_context.clone().ok_or_else(|| {
-        UnexpectedSnafu {
-            reason: "Query context is missing",
-        }
-        .build()
-    })?;
-    let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx)
-        .map_err(BoxedError::new)
-        .context(ExternalSnafu)?;
-    let plan = engine
-        .planner()
-        .plan(stmt, query_ctx)
-        .await
-        .map_err(BoxedError::new)
-        .context(ExternalSnafu)?;
-    let LogicalPlan::DfPlan(plan) = plan;
-    let plan = SimplifyExpressions::new()
-        .rewrite(plan, &OptimizerContext::default())
-        .context(DatafusionSnafu {
-            context: "Fail to apply `SimplifyExpressions` optimization",
-        })?
-        .data;
-    let sub_plan = DFLogicalSubstraitConvertor {}
-        .to_sub_plan(&plan, DefaultSerializer)
-        .map_err(BoxedError::new)
-        .context(ExternalSnafu)?;
-
-    let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;
-
-    Ok(flow_plan)
-}
-
 /// register flow-specific functions to the query engine
 pub fn register_function_to_query_engine(engine: &Arc<dyn QueryEngine>) {
-    engine.register_function(Arc::new(TumbleFunction {}));
+    engine.register_function(Arc::new(TumbleFunction::new("tumble")));
+    engine.register_function(Arc::new(TumbleFunction::new("tumble_start")));
+    engine.register_function(Arc::new(TumbleFunction::new("tumble_end")));
 }
 
 #[derive(Debug)]
-pub struct TumbleFunction {}
+pub struct TumbleFunction {
+    name: String,
+}
 
-const TUMBLE_NAME: &str = "tumble";
+impl TumbleFunction {
+    fn new(name: &str) -> Self {
+        Self {
+            name: name.to_string(),
+        }
+    }
+}
 
 impl std::fmt::Display for TumbleFunction {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        write!(f, "{}", TUMBLE_NAME.to_ascii_uppercase())
+        write!(f, "{}", self.name.to_ascii_uppercase())
     }
 }
 
 impl common_function::function::Function for TumbleFunction {
     fn name(&self) -> &str {
-        TUMBLE_NAME
+        &self.name
     }
 
     fn return_type(&self, _input_types: &[CDT]) -> common_query::error::Result<CDT> {
-        Ok(CDT::datetime_datatype())
+        Ok(CDT::timestamp_millisecond_datatype())
     }
 
     fn signature(&self) -> common_query::prelude::Signature {
@@ -219,6 +193,7 @@ mod test {
 
     use super::*;
     use crate::adapter::node_context::IdToNameMap;
+    use crate::df_optimizer::apply_df_optimizer;
     use crate::repr::ColumnType;
 
     pub fn create_test_ctx() -> FlownodeContext {
@@ -303,7 +278,7 @@ mod test {
         let factory = query::QueryEngineFactory::new(catalog_list, None, None, None, None, false);
         let engine = factory.query_engine();
 
-        engine.register_function(Arc::new(TumbleFunction {}));
+        register_function_to_query_engine(&engine);
 
         assert_eq!("datafusion", engine.name());
         engine
@@ -318,6 +293,7 @@ mod test {
             .await
             .unwrap();
         let LogicalPlan::DfPlan(plan) = plan;
+        let plan = apply_df_optimizer(plan).await.unwrap();
 
         // encode then decode so to rely on the impl of conversion from logical plan to substrait plan
         let bytes = DFLogicalSubstraitConvertor {}
@@ -326,4 +302,22 @@ mod test {
 
         proto::Plan::decode(bytes).unwrap()
     }
+
+    /// TODO(discord9): add more illegal sql tests
+    #[tokio::test]
+    async fn test_missing_key_check() {
+        let engine = create_test_query_engine();
+        let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";
+
+        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
+        let plan = engine
+            .planner()
+            .plan(stmt, QueryContext::arc())
+            .await
+            .unwrap();
+        let LogicalPlan::DfPlan(plan) = plan;
+        let plan = apply_df_optimizer(plan).await;
+
+        assert!(plan.is_err());
+    }
 }
diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs
index 64ecc3eec506..d38bdeeefaea 100644
--- a/src/flow/src/transform/aggr.rs
+++ b/src/flow/src/transform/aggr.rs
@@ -92,10 +92,9 @@ impl AggregateExpr {
         measures: &[Measure],
         typ: &RelationDesc,
         extensions: &FunctionExtensions,
-    ) -> Result<(Vec<AggregateExpr>, MapFilterProject), Error> {
+    ) -> Result<Vec<AggregateExpr>, Error> {
         let _ = ctx;
         let mut all_aggr_exprs = vec![];
-        let mut post_maps = vec![];
 
         for m in measures {
             let filter = match m
@@ -108,7 +107,7 @@ impl AggregateExpr {
             }
             .transpose()?;
 
-            let (aggr_expr, post_mfp) = match &m.measure {
+            let aggr_expr = match &m.measure {
                 Some(f) => {
                     let distinct = match f.invocation {
                         _ if f.invocation == AggregationInvocation::Distinct as i32 => true,
@@ -119,28 +118,17 @@ impl AggregateExpr {
                         f, typ, extensions, &filter, // TODO(discord9): impl order_by
                         &None, distinct,
                     )
-                    .await
+                    .await?
                 }
-                None => not_impl_err!("Aggregate without aggregate function is not supported"),
-            }?;
-            // permute col indices to refer to the output of post_mfp,
-            // to help construct an mfp at the end
-            let mut post_map = post_mfp.unwrap_or(ScalarExpr::Column(0));
-            let cur_arity = all_aggr_exprs.len();
-            let remap = (0..aggr_expr.len()).map(|i| i + cur_arity).collect_vec();
-            post_map.permute(&remap)?;
+                None => {
+                    return not_impl_err!("Aggregate without aggregate function is not supported")
+                }
+            };
 
             all_aggr_exprs.extend(aggr_expr);
-            post_maps.push(post_map);
         }
 
-        let input_arity = all_aggr_exprs.len();
-        let aggr_arity = post_maps.len();
-        let post_mfp_final = MapFilterProject::new(all_aggr_exprs.len())
-            .map(post_maps)?
-            .project(input_arity..input_arity + aggr_arity)?;
-
-        Ok((all_aggr_exprs, post_mfp_final))
+        Ok(all_aggr_exprs)
     }
 
     /// Convert AggregateFunction into Flow's AggregateExpr
@@ -154,7 +142,7 @@ impl AggregateExpr {
         filter: &Option<TypedExpr>,
         order_by: &Option<Vec<TypedExpr>>,
         distinct: bool,
-    ) -> Result<(Vec<AggregateExpr>, Option<ScalarExpr>), Error> {
+    ) -> Result<Vec<AggregateExpr>, Error> {
         // TODO(discord9): impl filter
         let _ = filter;
         let _ = order_by;
@@ -185,7 +173,6 @@ impl AggregateExpr {
             .map(|s| s.to_lowercase());
 
         match fn_name.as_ref().map(|s| s.as_ref()) {
-            Some(Self::AVG_NAME) => AggregateExpr::from_avg_aggr_func(arg),
             Some(function_name) => {
                 let func = AggregateFunc::from_str_and_type(
                     function_name,
@@ -196,8 +183,7 @@ impl AggregateExpr {
                     expr: arg.expr.clone(),
                     distinct,
                 }];
-                let ret_mfp = None;
-                Ok((exprs, ret_mfp))
+                Ok(exprs)
             }
             None => not_impl_err!(
                 "Aggregated function not found: function anchor = {:?}",
@@ -205,39 +191,6 @@ impl AggregateExpr {
             ),
         }
    }
-    const AVG_NAME: &'static str = "avg";
-    /// convert the `avg` function into `sum(x)/cast(count(x) as x_type)`
-    fn from_avg_aggr_func(
-        arg: &TypedExpr,
-    ) -> Result<(Vec<AggregateExpr>, Option<ScalarExpr>), Error> {
-        let arg_type = arg.typ.scalar_type.clone();
-        let sum = AggregateExpr {
-            func: AggregateFunc::from_str_and_type("sum", Some(arg_type.clone()))?,
-            expr: arg.expr.clone(),
-            distinct: false,
-        };
-        let sum_out_type = sum.func.signature().output.clone();
-        let count = AggregateExpr {
-            func: AggregateFunc::Count,
-            expr: arg.expr.clone(),
-            distinct: false,
-        };
-        let count_out_type = count.func.signature().output.clone();
-        let avg_output = ScalarExpr::Column(0).call_binary(
-            ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(sum_out_type.clone())),
-            BinaryFunc::div(sum_out_type.clone())?,
-        );
-        // make sure we wouldn't divide by zero
-        let zero = ScalarExpr::literal(count_out_type.default_value(), count_out_type.clone());
-        let non_zero = ScalarExpr::If {
-            cond:
Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)), - then: Box::new(avg_output), - els: Box::new(ScalarExpr::literal(Value::Null, sum_out_type.clone())), - }; - let ret_aggr_exprs = vec![sum, count]; - let ret_mfp = Some(non_zero); - Ok((ret_aggr_exprs, ret_mfp)) - } } impl KeyValPlan { @@ -323,21 +276,13 @@ impl TypedPlan { return not_impl_err!("Aggregate without an input is not supported"); }; - let group_exprs = { - let group_exprs = TypedExpr::from_substrait_agg_grouping( - ctx, - &agg.groupings, - &input.schema, - extensions, - ) - .await?; - - TypedExpr::expand_multi_value(&input.schema.typ, &group_exprs)? - }; + let group_exprs = + TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.schema, extensions) + .await?; let time_index = find_time_index_in_group_exprs(&group_exprs); - let (mut aggr_exprs, post_mfp) = AggregateExpr::from_substrait_agg_measures( + let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures( ctx, &agg.measures, &input.schema, @@ -356,24 +301,13 @@ impl TypedPlan { let mut output_types = Vec::new(); // give best effort to get column name let mut output_names = Vec::new(); - // mark all auto added cols - let mut auto_cols = vec![]; + // first append group_expr as key, then aggr_expr as value - for (idx, expr) in group_exprs.iter().enumerate() { + for (_idx, expr) in group_exprs.iter().enumerate() { output_types.push(expr.typ.clone()); let col_name = match &expr.expr { - ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowFloor { .. }, - .. - } => Some("window_start".to_string()), - ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowCeiling { .. }, - .. - } => { - auto_cols.push(idx); - Some("window_end".to_string()) - } ScalarExpr::Column(col) => input.schema.get_name(*col).clone(), + // TODO(discord9): impl& use ScalarExpr.display_name, which recursively build expr's name _ => None, }; output_names.push(col_name) @@ -393,7 +327,6 @@ impl TypedPlan { RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec()) } .with_time_index(time_index) - .with_autos(&auto_cols) .into_named(output_names) }; @@ -431,39 +364,11 @@ impl TypedPlan { reduce_plan: ReducePlan::Accumulable(accum_plan), }; // FIX(discord9): deal with key first - if post_mfp.is_identity() { - Ok(TypedPlan { - schema: output_type, - plan, - }) - } else { - // make post_mfp map identical mapping of keys - let input = TypedPlan { - schema: output_type.clone(), - plan, - }; - let key_arity = group_exprs.len(); - let mut post_mfp = post_mfp; - let val_arity = post_mfp.input_arity; - // offset post_mfp's col ref by `key_arity` - let shuffle = BTreeMap::from_iter((0..val_arity).map(|v| (v, v + key_arity))); - let new_arity = key_arity + val_arity; - post_mfp.permute(shuffle, new_arity)?; - // add key projection to post mfp - let (m, f, p) = post_mfp.into_map_filter_project(); - let p = (0..key_arity).chain(p).collect_vec(); - let post_mfp = MapFilterProject::new(new_arity) - .map(m)? - .filter(f)? 
- .project(p)?; - Ok(TypedPlan { - schema: output_type.apply_mfp(&post_mfp.clone().into_safe())?, - plan: Plan::Mfp { - input: Box::new(input), - mfp: post_mfp, - }, - }) - } + + return Ok(TypedPlan { + schema: output_type, + plan, + }); } } @@ -479,18 +384,6 @@ mod test { use crate::plan::{Plan, TypedPlan}; use crate::repr::{self, ColumnType, RelationType}; use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait}; - /// TODO(discord9): add more illegal sql tests - #[tokio::test] - async fn test_missing_key_check() { - let engine = create_test_query_engine(); - let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number"; - let plan = sql_to_substrait(engine.clone(), sql).await; - - let mut ctx = create_test_ctx(); - assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan) - .await - .is_err()); - } #[tokio::test] async fn test_df_func_basic() { @@ -504,21 +397,20 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![2]) .with_time_index(Some(1)) - .with_autos(&[2]) .into_named(vec![ - None, + Some("SUM(abs(numbers_with_ts.number))".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -538,7 +430,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -573,7 +467,7 @@ mod test { df_scalar_fn: DfScalarFunction::try_from_raw_fn( RawDfScalarFn { f: BytesMut::from( - b"\x08\x01\"\x08\x1a\x06\x12\x04\n\x02\x12\0" + b"\x08\x02\"\x08\x1a\x06\x12\x04\n\x02\x12\0" .as_ref(), ), input_schema: RelationType::new(vec![ColumnType::new( @@ -583,9 +477,10 @@ mod test { .into_unnamed(), extensions: FunctionExtensions { anchor_to_name: BTreeMap::from([ - (0, "tumble".to_string()), - (1, "abs".to_string()), - (2, "sum".to_string()), + (0, "tumble_start".to_string()), + (1, "tumble_end".to_string()), + (2, "abs".to_string()), + (3, "sum".to_string()), ]), }, }, @@ -593,7 +488,8 @@ mod test { .await .unwrap(), exprs: vec![ScalarExpr::Column(0)], - }]) + } + .cast(CDT::uint64_datatype())]) .unwrap() .project(vec![2]) .unwrap() @@ -607,33 +503,27 @@ mod test { } .with_types( RelationType::new(vec![ - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end - ColumnType::new(CDT::uint64_datatype(), true), //sum(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end + ColumnType::new(CDT::uint64_datatype(), true), //sum(number) ]) .with_key(vec![1]) .with_time_index(Some(0)) - .with_autos(&[1]) - .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), - None, - ]), + .into_unnamed(), ), ), mfp: MapFilterProject::new(3) .map(vec![ ScalarExpr::Column(2), - ScalarExpr::Column(3), ScalarExpr::Column(0), 
ScalarExpr::Column(1), ]) .unwrap() - .project(vec![4, 5, 6]) + .project(vec![3, 4, 5]) .unwrap(), }, }; - assert_eq!(expected, flow_plan); + assert_eq!(flow_plan, expected); } #[tokio::test] @@ -648,21 +538,20 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![2]) .with_time_index(Some(1)) - .with_autos(&[2]) .into_named(vec![ - None, + Some("abs(SUM(numbers_with_ts.number))".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -682,7 +571,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -713,7 +604,9 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(2) - .project(vec![0, 1]) + .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())]) + .unwrap() + .project(vec![2]) .unwrap() .into_safe(), }, @@ -725,23 +618,17 @@ mod test { } .with_types( RelationType::new(vec![ - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end - ColumnType::new(CDT::uint64_datatype(), true), //sum(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end + ColumnType::new(CDT::uint64_datatype(), true), //sum(number) ]) .with_key(vec![1]) .with_time_index(Some(0)) - .with_autos(&[1]) - .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), - None, - ]), + .into_named(vec![None, None, None]), ), ), mfp: MapFilterProject::new(3) .map(vec![ - ScalarExpr::Column(2), ScalarExpr::CallDf { df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn { f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()), @@ -753,24 +640,25 @@ mod test { extensions: FunctionExtensions { anchor_to_name: BTreeMap::from([ (0, "abs".to_string()), - (1, "tumble".to_string()), - (2, "sum".to_string()), + (1, "tumble_start".to_string()), + (2, "tumble_end".to_string()), + (3, "sum".to_string()), ]), }, }) .await .unwrap(), - exprs: vec![ScalarExpr::Column(3)], + exprs: vec![ScalarExpr::Column(2)], }, ScalarExpr::Column(0), ScalarExpr::Column(1), ]) .unwrap() - .project(vec![4, 5, 6]) + .project(vec![3, 4, 5]) .unwrap(), }, }; - assert_eq!(expected, flow_plan); + assert_eq!(flow_plan, expected); } /// TODO(discord9): add more illegal sql tests @@ -788,13 +676,13 @@ mod test { let aggr_exprs = vec![ AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }, AggregateExpr { func: AggregateFunc::Count, - expr: ScalarExpr::Column(0), + expr: ScalarExpr::Column(1), distinct: false, }, ]; @@ -803,11 +691,15 @@ mod test { ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()), BinaryFunc::NotEq, )), - then: Box::new(ScalarExpr::Column(3).call_binary( - 
ScalarExpr::Column(4).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), - BinaryFunc::DivUInt64, - )), - els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), + then: Box::new( + ScalarExpr::Column(3) + .cast(CDT::float64_datatype()) + .call_binary( + ScalarExpr::Column(4).cast(CDT::float64_datatype()), + BinaryFunc::DivFloat64, + ), + ), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())), }; let expected = TypedPlan { plan: Plan::Mfp { @@ -826,7 +718,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -858,7 +752,12 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(2) - .project(vec![0, 1]) + .map(vec![ + ScalarExpr::Column(0).cast(CDT::uint64_datatype()), + ScalarExpr::Column(0), + ]) + .unwrap() + .project(vec![2, 3]) .unwrap() .into_safe(), }, @@ -866,7 +765,7 @@ mod test { full_aggrs: aggr_exprs.clone(), simple_aggrs: vec![ AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1), + AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), ], distinct_aggrs: vec![], }), @@ -874,19 +773,18 @@ mod test { .with_types( RelationType::new(vec![ // keys - ColumnType::new(CDT::datetime_datatype(), false), // window start(time index) - ColumnType::new(CDT::datetime_datatype(), false), // window end(pk) - ColumnType::new(CDT::uint32_datatype(), false), // number(pk) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start(time index) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end(pk) + ColumnType::new(CDT::uint32_datatype(), false), // number(pk) // values ColumnType::new(CDT::uint64_datatype(), true), // avg.sum(number) ColumnType::new(CDT::int64_datatype(), true), // avg.count(number) ]) .with_key(vec![1, 2]) .with_time_index(Some(0)) - .with_autos(&[1]) .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), + None, + None, Some("number".to_string()), None, None, @@ -895,28 +793,26 @@ mod test { ), mfp: MapFilterProject::new(5) .map(vec![ - avg_expr, ScalarExpr::Column(2), // number(pk) - ScalarExpr::Column(5), // avg.sum(number) + avg_expr, ScalarExpr::Column(0), // window start ScalarExpr::Column(1), // window end ]) .unwrap() - .project(vec![6, 7, 8, 9]) + .project(vec![5, 6, 7, 8]) .unwrap(), }, schema: RelationType::new(vec![ ColumnType::new(CDT::uint32_datatype(), false), // number - ColumnType::new(CDT::uint64_datatype(), true), // avg(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::float64_datatype(), true), // avg(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![0, 3]) .with_time_index(Some(2)) - .with_autos(&[3]) .into_named(vec![ - Some("number".to_string()), - None, + Some("numbers_with_ts.number".to_string()), + Some("AVG(numbers_with_ts.number)".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -936,21 +832,20 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ 
ColumnType::new(CDT::uint64_datatype(), true), // sum(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![2]) .with_time_index(Some(1)) - .with_autos(&[2]) .into_named(vec![ - None, + Some("SUM(numbers_with_ts.number)".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -970,7 +865,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -1001,7 +898,9 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(2) - .project(vec![0, 1]) + .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())]) + .unwrap() + .project(vec![2]) .unwrap() .into_safe(), }, @@ -1013,29 +912,23 @@ mod test { } .with_types( RelationType::new(vec![ - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end - ColumnType::new(CDT::uint64_datatype(), true), //sum(number) + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end + ColumnType::new(CDT::uint64_datatype(), true), //sum(number) ]) .with_key(vec![1]) .with_time_index(Some(0)) - .with_autos(&[1]) - .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), - None, - ]), + .into_named(vec![None, None, None]), ), ), mfp: MapFilterProject::new(3) .map(vec![ ScalarExpr::Column(2), - ScalarExpr::Column(3), ScalarExpr::Column(0), ScalarExpr::Column(1), ]) .unwrap() - .project(vec![4, 5, 6]) + .project(vec![3, 4, 5]) .unwrap(), }, }; @@ -1054,21 +947,20 @@ mod test { .unwrap(); let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end ]) .with_key(vec![2]) .with_time_index(Some(1)) - .with_autos(&[2]) .into_named(vec![ - None, + Some("SUM(numbers_with_ts.number)".to_string()), Some("window_start".to_string()), Some("window_end".to_string()), ]), @@ -1088,7 +980,9 @@ mod test { Some("number".to_string()), Some("ts".to_string()), ]), - ), + ) + .mfp(MapFilterProject::new(2).into_safe()) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(2) @@ -1119,7 +1013,9 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(2) - .project(vec![0, 1]) + .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())]) + .unwrap() + .project(vec![2]) .unwrap() .into_safe(), }, @@ -1131,29 +1027,23 @@ mod test { } .with_types( RelationType::new(vec![ - ColumnType::new(CDT::datetime_datatype(), false), // window start - ColumnType::new(CDT::datetime_datatype(), false), // window end - ColumnType::new(CDT::uint64_datatype(), true), //sum(number) + 
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start + ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end + ColumnType::new(CDT::uint64_datatype(), true), //sum(number) ]) .with_key(vec![1]) .with_time_index(Some(0)) - .with_autos(&[1]) - .into_named(vec![ - Some("window_start".to_string()), - Some("window_end".to_string()), - None, - ]), + .into_unnamed(), ), ), mfp: MapFilterProject::new(3) .map(vec![ ScalarExpr::Column(2), - ScalarExpr::Column(3), ScalarExpr::Column(0), ScalarExpr::Column(1), ]) .unwrap() - .project(vec![4, 5, 6]) + .project(vec![3, 4, 5]) .unwrap(), }, }; @@ -1171,13 +1061,13 @@ mod test { let aggr_exprs = vec![ AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }, AggregateExpr { func: AggregateFunc::Count, - expr: ScalarExpr::Column(0), + expr: ScalarExpr::Column(1), distinct: false, }, ]; @@ -1186,19 +1076,26 @@ mod test { ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()), BinaryFunc::NotEq, )), - then: Box::new(ScalarExpr::Column(1).call_binary( - ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), - BinaryFunc::DivUInt64, - )), - els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), + then: Box::new( + ScalarExpr::Column(1) + .cast(CDT::float64_datatype()) + .call_binary( + ScalarExpr::Column(2).cast(CDT::float64_datatype()), + BinaryFunc::DivFloat64, + ), + ), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())), }; let expected = TypedPlan { schema: RelationType::new(vec![ - ColumnType::new(CDT::uint64_datatype(), true), // sum(number) -> u64 + ColumnType::new(CDT::float64_datatype(), true), // avg(number: u32) -> f64 ColumnType::new(CDT::uint32_datatype(), false), // number ]) .with_key(vec![1]) - .into_named(vec![None, Some("number".to_string())]), + .into_named(vec![ + Some("AVG(numbers.number)".to_string()), + Some("numbers.number".to_string()), + ]), plan: Plan::Mfp { input: Box::new( Plan::Reduce { @@ -1212,7 +1109,14 @@ mod test { false, )]) .into_named(vec![Some("number".to_string())]), - ), + ) + .mfp( + MapFilterProject::new(1) + .project(vec![0]) + .unwrap() + .into_safe(), + ) + .unwrap(), ), key_val_plan: KeyValPlan { key_plan: MapFilterProject::new(1) @@ -1222,7 +1126,12 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(1) - .project(vec![0]) + .map(vec![ + ScalarExpr::Column(0).cast(CDT::uint64_datatype()), + ScalarExpr::Column(0), + ]) + .unwrap() + .project(vec![1, 2]) .unwrap() .into_safe(), }, @@ -1230,7 +1139,7 @@ mod test { full_aggrs: aggr_exprs.clone(), simple_aggrs: vec![ AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1), + AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), ], distinct_aggrs: vec![], }), @@ -1252,12 +1161,11 @@ mod test { mfp: MapFilterProject::new(3) .map(vec![ avg_expr, // col 3 + ScalarExpr::Column(0), // TODO(discord9): optimize mfp so to remove indirect ref - ScalarExpr::Column(3), // col 4 - ScalarExpr::Column(0), // col 5 ]) .unwrap() - .project(vec![4, 5]) + .project(vec![3, 4]) .unwrap(), }, }; @@ -1278,13 +1186,13 @@ mod test { let aggr_exprs = vec![ AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }, AggregateExpr { func: AggregateFunc::Count, - expr: ScalarExpr::Column(0), + expr: ScalarExpr::Column(1), distinct: false, }, ]; @@ -1293,25 +1201,42 @@ mod test { 
ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()), BinaryFunc::NotEq, )), - then: Box::new(ScalarExpr::Column(0).call_binary( - ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), - BinaryFunc::DivUInt64, - )), - els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), + then: Box::new( + ScalarExpr::Column(0) + .cast(CDT::float64_datatype()) + .call_binary( + ScalarExpr::Column(1).cast(CDT::float64_datatype()), + BinaryFunc::DivFloat64, + ), + ), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())), }; + let input = Box::new( + Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(0)), + } + .with_types( + RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint32_datatype(), + false, + )]) + .into_named(vec![Some("number".to_string())]), + ), + ); let expected = TypedPlan { - schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) - .into_named(vec![None]), + schema: RelationType::new(vec![ColumnType::new(CDT::float64_datatype(), true)]) + .into_named(vec![Some("AVG(numbers.number)".to_string())]), plan: Plan::Mfp { input: Box::new( Plan::Reduce { input: Box::new( - Plan::Get { - id: crate::expr::Id::Global(GlobalId::User(0)), + Plan::Mfp { + input: input.clone(), + mfp: MapFilterProject::new(1).project(vec![0]).unwrap(), } .with_types( RelationType::new(vec![ColumnType::new( - ConcreteDataType::uint32_datatype(), + CDT::uint32_datatype(), false, )]) .into_named(vec![Some("number".to_string())]), @@ -1323,7 +1248,12 @@ mod test { .unwrap() .into_safe(), val_plan: MapFilterProject::new(1) - .project(vec![0]) + .map(vec![ + ScalarExpr::Column(0).cast(CDT::uint64_datatype()), + ScalarExpr::Column(0), + ]) + .unwrap() + .project(vec![1, 2]) .unwrap() .into_safe(), }, @@ -1331,7 +1261,7 @@ mod test { full_aggrs: aggr_exprs.clone(), simple_aggrs: vec![ AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1), + AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), ], distinct_aggrs: vec![], }), @@ -1348,10 +1278,9 @@ mod test { .map(vec![ avg_expr, // TODO(discord9): optimize mfp so to remove indirect ref - ScalarExpr::Column(2), ]) .unwrap() - .project(vec![3]) + .project(vec![2]) .unwrap(), }, }; @@ -1366,56 +1295,48 @@ mod test { let mut ctx = create_test_ctx(); let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; - let typ = RelationType::new(vec![ColumnType::new( - ConcreteDataType::uint64_datatype(), - true, - )]); + let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt32, + func: AggregateFunc::SumUInt64, expr: ScalarExpr::Column(0), distinct: false, }; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) - .into_unnamed(), - plan: Plan::Mfp { + .into_named(vec![Some("SUM(numbers.number)".to_string())]), + plan: Plan::Reduce { input: Box::new( - Plan::Reduce { - input: Box::new( - Plan::Get { - id: crate::expr::Id::Global(GlobalId::User(0)), - } - .with_types( - RelationType::new(vec![ColumnType::new( - ConcreteDataType::uint32_datatype(), - false, - )]) - .into_named(vec![Some("number".to_string())]), - ), - ), - key_val_plan: KeyValPlan { - key_plan: MapFilterProject::new(1) - .project(vec![]) - .unwrap() - .into_safe(), - val_plan: MapFilterProject::new(1) - .project(vec![0]) - .unwrap() - .into_safe(), - }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 
@@ -1366,56 +1295,48 @@ mod test {
         let mut ctx = create_test_ctx();
         let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

-        let typ = RelationType::new(vec![ColumnType::new(
-            ConcreteDataType::uint64_datatype(),
-            true,
-        )]);
+
         let aggr_expr = AggregateExpr {
-            func: AggregateFunc::SumUInt32,
+            func: AggregateFunc::SumUInt64,
             expr: ScalarExpr::Column(0),
             distinct: false,
         };
         let expected = TypedPlan {
             schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
-                .into_unnamed(),
-            plan: Plan::Mfp {
+                .into_named(vec![Some("SUM(numbers.number)".to_string())]),
+            plan: Plan::Reduce {
                 input: Box::new(
-                    Plan::Reduce {
-                        input: Box::new(
-                            Plan::Get {
-                                id: crate::expr::Id::Global(GlobalId::User(0)),
-                            }
-                            .with_types(
-                                RelationType::new(vec![ColumnType::new(
-                                    ConcreteDataType::uint32_datatype(),
-                                    false,
-                                )])
-                                .into_named(vec![Some("number".to_string())]),
-                            ),
-                        ),
-                        key_val_plan: KeyValPlan {
-                            key_plan: MapFilterProject::new(1)
-                                .project(vec![])
-                                .unwrap()
-                                .into_safe(),
-                            val_plan: MapFilterProject::new(1)
-                                .project(vec![0])
-                                .unwrap()
-                                .into_safe(),
-                        },
-                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
-                            full_aggrs: vec![aggr_expr.clone()],
-                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
-                            distinct_aggrs: vec![],
-                        }),
+                    Plan::Get {
+                        id: crate::expr::Id::Global(GlobalId::User(0)),
                     }
-                    .with_types(typ.into_unnamed()),
-                ),
-                mfp: MapFilterProject::new(1)
-                    .map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)])
-                    .unwrap()
-                    .project(vec![2])
+                    .with_types(
+                        RelationType::new(vec![ColumnType::new(
+                            ConcreteDataType::uint32_datatype(),
+                            false,
+                        )])
+                        .into_named(vec![Some("number".to_string())]),
+                    )
+                    .mfp(MapFilterProject::new(1).into_safe())
                     .unwrap(),
+                ),
+                key_val_plan: KeyValPlan {
+                    key_plan: MapFilterProject::new(1)
+                        .project(vec![])
+                        .unwrap()
+                        .into_safe(),
+                    val_plan: MapFilterProject::new(1)
+                        .map(vec![ScalarExpr::Column(0)
+                            .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
+                        .unwrap()
+                        .project(vec![1])
+                        .unwrap()
+                        .into_safe(),
+                },
+                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
+                    full_aggrs: vec![aggr_expr.clone()],
+                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
+                    distinct_aggrs: vec![],
+                }),
             },
         };
         assert_eq!(flow_plan.unwrap(), expected);
@@ -1433,7 +1354,7 @@ mod test {
             .unwrap();

         let aggr_expr = AggregateExpr {
-            func: AggregateFunc::SumUInt32,
+            func: AggregateFunc::SumUInt64,
             expr: ScalarExpr::Column(0),
             distinct: false,
         };
@@ -1443,7 +1364,10 @@ mod test {
                 ColumnType::new(CDT::uint32_datatype(), false), // col number
             ])
             .with_key(vec![1])
-            .into_named(vec![None, Some("number".to_string())]),
+            .into_named(vec![
+                Some("SUM(numbers.number)".to_string()),
+                Some("numbers.number".to_string()),
+            ]),
             plan: Plan::Mfp {
                 input: Box::new(
                     Plan::Reduce {
@@ -1457,7 +1381,9 @@ mod test {
                                 false,
                             )])
                             .into_named(vec![Some("number".to_string())]),
-                        ),
+                        )
+                        .mfp(MapFilterProject::new(1).into_safe())
+                        .unwrap(),
                         ),
                         key_val_plan: KeyValPlan {
                             key_plan: MapFilterProject::new(1)
@@ -1467,7 +1393,10 @@ mod test {
                                 .unwrap()
                                 .into_safe(),
                             val_plan: MapFilterProject::new(1)
-                                .project(vec![0])
+                                .map(vec![ScalarExpr::Column(0)
+                                    .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
+                                .unwrap()
+                                .project(vec![1])
+                                .unwrap()
                                 .into_safe(),
                         },
@@ -1487,13 +1416,9 @@ mod test {
                     ),
                 ),
                 mfp: MapFilterProject::new(2)
-                    .map(vec![
-                        ScalarExpr::Column(1),
-                        ScalarExpr::Column(2),
-                        ScalarExpr::Column(0),
-                    ])
+                    .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
                     .unwrap()
-                    .project(vec![3, 4])
+                    .project(vec![2, 3])
                     .unwrap(),
             },
         };
@@ -1511,16 +1436,18 @@ mod test {
         let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

         let aggr_expr = AggregateExpr {
-            func: AggregateFunc::SumUInt32,
+            func: AggregateFunc::SumUInt64,
             expr: ScalarExpr::Column(0),
             distinct: false,
         };
         let expected = TypedPlan {
             schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
-                .into_unnamed(),
+                .into_named(vec![Some(
+                    "SUM(numbers.number + numbers.number)".to_string(),
+                )]),
-            plan: Plan::Mfp {
+            plan: Plan::Reduce {
                 input: Box::new(
-                    Plan::Reduce {
+                    Plan::Mfp {
                         input: Box::new(
                             Plan::Get {
                                 id: crate::expr::Id::Global(GlobalId::User(0)),
@@ -1533,35 +1460,35 @@ mod test {
                             .into_named(vec![Some("number".to_string())]),
                         ),
                     ),
-                    key_val_plan: KeyValPlan {
-                        key_plan: MapFilterProject::new(1)
-                            .project(vec![])
-                            .unwrap()
-                            .into_safe(),
-                        val_plan: MapFilterProject::new(1)
-                            .map(vec![ScalarExpr::Column(0)
-                                .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)])
-                            .unwrap()
-                            .project(vec![1])
-                            .unwrap()
-                            .into_safe(),
-                    },
-                    reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
-                        full_aggrs: vec![aggr_expr.clone()],
-                        simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
-                        distinct_aggrs: vec![],
-                    }),
+                    mfp: MapFilterProject::new(1),
                 }
                 .with_types(
-                    RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
-                        .into_unnamed(),
+                    RelationType::new(vec![ColumnType::new(
+                        ConcreteDataType::uint32_datatype(),
+                        false,
+                    )])
+                    .into_named(vec![Some("number".to_string())]),
                 ),
             ),
-            mfp: MapFilterProject::new(1)
-                .map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)])
-                .unwrap()
-                .project(vec![2])
-                .unwrap(),
+            key_val_plan: KeyValPlan {
+                key_plan: MapFilterProject::new(1)
+                    .project(vec![])
+                    .unwrap()
+                    .into_safe(),
+                val_plan: MapFilterProject::new(1)
+                    .map(vec![ScalarExpr::Column(0)
+                        .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)
+                        .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
+                    .unwrap()
+                    .project(vec![1])
+                    .unwrap()
+                    .into_safe(),
+            },
+            reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
+                full_aggrs: vec![aggr_expr.clone()],
+                simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
+                distinct_aggrs: vec![],
+            }),
         },
     };
     assert_eq!(flow_plan.unwrap(), expected);
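
The reshaped SUM plans above also show the runtime contract of `KeyValPlan`: `key_plan` projects out the grouping key (empty here, since there is no GROUP BY) and `val_plan` widens `number: u32` to `u64` before the `SumUInt64` accumulator runs. A rough row-level sketch of that split, with simplified stand-ins for the crate's `Row`/`Value` types:

    // Stand-in value type; the real plan operates on the flow crate's Row/Value.
    #[derive(Debug, PartialEq)]
    enum Value {
        UInt32(u32),
        UInt64(u64),
    }

    /// key_plan: `MapFilterProject::new(1).project(vec![])` -- no GROUP BY keys,
    /// so every row lands in the single global group.
    fn key_plan(_row: &[Value]) -> Vec<Value> {
        vec![]
    }

    /// val_plan: map `[cast(col 0 as u64)]`, then project the mapped column,
    /// so the SumUInt64 accumulator only ever sees already-widened values.
    fn val_plan(row: &[Value]) -> Vec<Value> {
        match &row[0] {
            Value::UInt32(x) => vec![Value::UInt64(*x as u64)],
            other => panic!("unexpected input column: {other:?}"),
        }
    }

    fn main() {
        let row = [Value::UInt32(7)];
        assert_eq!(key_plan(&row), Vec::<Value>::new());
        assert_eq!(val_plan(&row), vec![Value::UInt64(7)]);
    }
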
diff --git a/src/flow/src/transform/expr.rs b/src/flow/src/transform/expr.rs
index 5848dc66b674..b2e24f5fc35d 100644
--- a/src/flow/src/transform/expr.rs
+++ b/src/flow/src/transform/expr.rs
@@ -20,7 +20,7 @@ use common_error::ext::BoxedError;
 use common_telemetry::debug;
 use datafusion_physical_expr::PhysicalExpr;
 use datatypes::data_type::ConcreteDataType as CDT;
-use snafu::{OptionExt, ResultExt};
+use snafu::{ensure, OptionExt, ResultExt};
 use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
 use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
 use substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction};
@@ -167,6 +167,16 @@ fn rewrite_scalar_function(
     arg_typed_exprs: &[TypedExpr],
 ) -> Result<ScalarFunction, Error> {
     let mut f_rewrite = f.clone();
+    ensure!(
+        f_rewrite.arguments.len() == arg_typed_exprs.len(),
+        crate::error::InternalSnafu {
+            reason: format!(
+                "Expect `f_rewrite` and `arg_typed_expr` to be same length, found {} and {}",
+                f_rewrite.arguments.len(),
+                arg_typed_exprs.len()
+            )
+        }
+    );
     for (idx, raw_expr) in f_rewrite.arguments.iter_mut().enumerate() {
         // only replace it with col(idx) if it is not literal
         // will try best to determine if it is literal, i.e. for function like `cast()` will try
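
The `ensure!` added above fails fast when substrait's raw argument list and the already-converted typed arguments disagree in length, instead of zipping them silently. A compact standalone equivalent of the check, using a plain `Result<(), String>` in place of the crate's snafu-based error type:

    /// Mirrors the shape of the added guard: equal lengths, or a descriptive error.
    fn check_same_arity(raw_args: &[&str], typed_args: &[&str]) -> Result<(), String> {
        if raw_args.len() != typed_args.len() {
            return Err(format!(
                "Expect `f_rewrite` and `arg_typed_expr` to be same length, found {} and {}",
                raw_args.len(),
                typed_args.len()
            ));
        }
        Ok(())
    }

    fn main() {
        assert!(check_same_arity(&["a", "b"], &["x", "y"]).is_ok());
        assert!(check_same_arity(&["a", "b"], &["x"]).is_err());
    }
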
@@ -351,7 +361,13 @@ impl TypedExpr {
                 Ok(TypedExpr::new(ret_expr, ret_type))
             }
             _var => {
-                if VariadicFunc::is_valid_func_name(fn_name) {
+                if fn_name == "tumble_start" || fn_name == "tumble_end" {
+                    let (func, arg) = UnaryFunc::from_tumble_func(fn_name, &arg_typed_exprs)?;
+
+                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
+
+                    Ok(TypedExpr::new(arg.expr.call_unary(func), ret_type))
+                } else if VariadicFunc::is_valid_func_name(fn_name) {
                     let func = VariadicFunc::from_str_and_types(fn_name, &arg_types)?;
                     let ret_type = ColumnType::new_nullable(func.signature().output.clone());
                     let mut expr = ScalarExpr::CallVariadic {
@@ -562,7 +578,7 @@ mod test {
         };
         let expected = TypedPlan {
             schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)])
-                .into_named(vec![Some("number".to_string())]),
+                .into_named(vec![Some("numbers.number".to_string())]),
             plan: Plan::Mfp {
                 input: Box::new(
                     Plan::Get {
@@ -576,13 +592,7 @@ mod test {
                         .into_named(vec![Some("number".to_string())]),
                     ),
                 ),
-                mfp: MapFilterProject::new(1)
-                    .map(vec![ScalarExpr::Column(0)])
-                    .unwrap()
-                    .filter(vec![filter])
-                    .unwrap()
-                    .project(vec![1])
-                    .unwrap(),
+                mfp: MapFilterProject::new(1).filter(vec![filter]).unwrap(),
             },
         };
         assert_eq!(flow_plan.unwrap(), expected);
@@ -600,7 +610,7 @@ mod test {

         let expected = TypedPlan {
             schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)])
-                .into_unnamed(),
+                .into_named(vec![Some("Int64(1) + Int64(1) * Int64(2) - Int64(1) / Int64(1) + Int64(1) % Int64(2) = Int64(3)".to_string())]),
             plan: Plan::Constant {
                 rows: vec![(
                     repr::Row::new(vec![Value::from(true)]),
@@ -624,8 +634,8 @@ mod test {
         let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

         let expected = TypedPlan {
-            schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)])
-                .into_unnamed(),
+            schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)])
+                .into_named(vec![Some("numbers.number + Int64(1)".to_string())]),
             plan: Plan::Mfp {
                 input: Box::new(
                     Plan::Get {
@@ -640,10 +650,12 @@ mod test {
                     ),
                 ),
                 mfp: MapFilterProject::new(1)
-                    .map(vec![ScalarExpr::Column(0).call_binary(
-                        ScalarExpr::Literal(Value::from(1u32), CDT::uint32_datatype()),
-                        BinaryFunc::AddUInt32,
-                    )])
+                    .map(vec![ScalarExpr::Column(0)
+                        .call_unary(UnaryFunc::Cast(CDT::int64_datatype()))
+                        .call_binary(
+                            ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()),
+                            BinaryFunc::AddInt64,
+                        )])
                     .unwrap()
                     .project(vec![1])
                     .unwrap(),
@@ -663,7 +675,9 @@ mod test {

         let expected = TypedPlan {
             schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)])
-                .into_unnamed(),
+                .into_named(vec![Some(
+                    "arrow_cast(Int64(1),Utf8(\"Int16\"))".to_string(),
+                )]),
             plan: Plan::Constant {
                 // cast of literal is constant folded
                 rows: vec![(repr::Row::new(vec![Value::from(1i16)]), i64::MIN, 1)],
@@ -683,7 +697,7 @@ mod test {

         let expected = TypedPlan {
             schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)])
-                .into_unnamed(),
+                .into_named(vec![Some("numbers.number + numbers.number".to_string())]),
             plan: Plan::Mfp {
                 input: Box::new(
                     Plan::Get {
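
With the new branch above, `tumble_start`/`tumble_end` are compiled to unary functions over the time column. A sketch of the window arithmetic they stand for, assuming millisecond timestamps and a fixed window size; this simplifies the real `TumbleWindowFloor`/`TumbleWindowCeiling` variants, which also carry the window configuration:

    /// tumble_start(ts, window): floor the timestamp onto the window grid.
    fn tumble_start(ts_ms: i64, window_ms: i64) -> i64 {
        ts_ms - ts_ms.rem_euclid(window_ms)
    }

    /// tumble_end(ts, window): the exclusive upper bound of the same window.
    fn tumble_end(ts_ms: i64, window_ms: i64) -> i64 {
        tumble_start(ts_ms, window_ms) + window_ms
    }

    fn main() {
        // 00:00:07.5 with a 5-second window falls in [00:00:05, 00:00:10).
        assert_eq!(tumble_start(7_500, 5_000), 5_000);
        assert_eq!(tumble_end(7_500, 5_000), 10_000);
    }
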
None") @@ -117,17 +124,6 @@ impl TypedPlan { plan, }) } else { - match input.plan.clone() { - Plan::Reduce { key_val_plan, .. } => { - rewrite_projection_after_reduce(key_val_plan, &input.schema, &mut exprs)?; - } - Plan::Mfp { input, mfp: _ } => { - if let Plan::Reduce { key_val_plan, .. } = input.plan { - rewrite_projection_after_reduce(key_val_plan, &input.schema, &mut exprs)?; - } - } - _ => (), - } input.projection(exprs) } } @@ -235,113 +231,6 @@ impl TypedPlan { } } -/// if reduce_plan contains the special function like tumble floor/ceiling, add them to the proj_exprs -/// so the effect is the window_start, window_end column are auto added to output rows -/// -/// This is to fix a problem that we have certain functions that return two values, but since substrait doesn't know that, it will assume it return one value -/// this function fix that and rewrite `proj_exprs` to correct form -fn rewrite_projection_after_reduce( - key_val_plan: KeyValPlan, - reduce_output_type: &RelationDesc, - proj_exprs: &mut Vec, -) -> Result<(), Error> { - // TODO(discord9): get keys correctly - let key_exprs = key_val_plan - .key_plan - .projection - .clone() - .into_iter() - .map(|i| { - if i < key_val_plan.key_plan.input_arity { - ScalarExpr::Column(i) - } else { - key_val_plan.key_plan.expressions[i - key_val_plan.key_plan.input_arity].clone() - } - }) - .collect_vec(); - let mut shift_offset = 0; - let mut shuffle: BTreeMap = BTreeMap::new(); - let special_keys = key_exprs - .clone() - .into_iter() - .enumerate() - .filter(|(idx, p)| { - shuffle.insert(*idx, *idx + shift_offset); - if matches!( - p, - ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowFloor { .. }, - .. - } | ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowCeiling { .. }, - .. - } - ) { - if matches!( - p, - ScalarExpr::CallUnary { - func: UnaryFunc::TumbleWindowFloor { .. }, - .. 
@@ -235,113 +231,6 @@ impl TypedPlan {
     }
 }

-/// if reduce_plan contains the special function like tumble floor/ceiling, add them to the proj_exprs
-/// so the effect is the window_start, window_end column are auto added to output rows
-///
-/// This is to fix a problem that we have certain functions that return two values, but since substrait doesn't know that, it will assume it return one value
-/// this function fix that and rewrite `proj_exprs` to correct form
-fn rewrite_projection_after_reduce(
-    key_val_plan: KeyValPlan,
-    reduce_output_type: &RelationDesc,
-    proj_exprs: &mut Vec<TypedExpr>,
-) -> Result<(), Error> {
-    // TODO(discord9): get keys correctly
-    let key_exprs = key_val_plan
-        .key_plan
-        .projection
-        .clone()
-        .into_iter()
-        .map(|i| {
-            if i < key_val_plan.key_plan.input_arity {
-                ScalarExpr::Column(i)
-            } else {
-                key_val_plan.key_plan.expressions[i - key_val_plan.key_plan.input_arity].clone()
-            }
-        })
-        .collect_vec();
-    let mut shift_offset = 0;
-    let mut shuffle: BTreeMap<usize, usize> = BTreeMap::new();
-    let special_keys = key_exprs
-        .clone()
-        .into_iter()
-        .enumerate()
-        .filter(|(idx, p)| {
-            shuffle.insert(*idx, *idx + shift_offset);
-            if matches!(
-                p,
-                ScalarExpr::CallUnary {
-                    func: UnaryFunc::TumbleWindowFloor { .. },
-                    ..
-                } | ScalarExpr::CallUnary {
-                    func: UnaryFunc::TumbleWindowCeiling { .. },
-                    ..
-                }
-            ) {
-                if matches!(
-                    p,
-                    ScalarExpr::CallUnary {
-                        func: UnaryFunc::TumbleWindowFloor { .. },
-                        ..
-                    }
-                ) {
-                    shift_offset += 1;
-                }
-                true
-            } else {
-                false
-            }
-        })
-        .collect_vec();
-    let spec_key_arity = special_keys.len();
-    if spec_key_arity == 0 {
-        return Ok(());
-    }
-
-    // shuffle proj_exprs
-    // because substrait use offset while assume `tumble` only return one value
-    for proj_expr in proj_exprs.iter_mut() {
-        proj_expr.expr.permute_map(&shuffle)?;
-    }
-    // add key to the end
-    for (key_idx, _key_expr) in special_keys {
-        // here we assume the output type of reduce operator(`reduce_output_type`) is just first keys columns, then append value columns
-        // so we can use `key_idx` to index `reduce_output_type` and get the keys we need to append to `proj_exprs`
-        proj_exprs.push(
-            ScalarExpr::Column(key_idx)
-                .with_type(reduce_output_type.typ().column_types[key_idx].clone()),
-        );
-    }
-
-    // check if normal expr in group exprs are all in proj_exprs
-    let all_cols_ref_in_proj: BTreeSet<usize> = proj_exprs
-        .iter()
-        .filter_map(|e| {
-            if let ScalarExpr::Column(i) = &e.expr {
-                Some(*i)
-            } else {
-                None
-            }
-        })
-        .collect();
-    for (key_idx, key_expr) in key_exprs.iter().enumerate() {
-        if let ScalarExpr::Column(_) = key_expr {
-            if !all_cols_ref_in_proj.contains(&key_idx) {
-                let err_msg = format!(
-                    "Expect normal column in group by also appear in projection, but column {}(name is {}) is missing",
-                    key_idx,
-                    reduce_output_type
-                        .get_name(key_idx)
-                        .clone()
-                        .map(|s| format!("'{}'", s))
-                        .unwrap_or("unknown".to_string())
-                );
-                return InvalidQuerySnafu { reason: err_msg }.fail();
-            }
-        }
-    }
-
-    Ok(())
-}
-
 #[cfg(test)]
 mod test {
     use datatypes::prelude::ConcreteDataType;
@@ -365,7 +254,7 @@ mod test {
         let expected = TypedPlan {
             schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)])
-                .into_named(vec![Some("number".to_string())]),
+                .into_named(vec![Some("numbers.number".to_string())]),
             plan: Plan::Mfp {
                 input: Box::new(
                     Plan::Get {
@@ -379,11 +268,7 @@ mod test {
                         .into_named(vec![Some("number".to_string())]),
                     ),
                 ),
-                mfp: MapFilterProject::new(1)
-                    .map(vec![ScalarExpr::Column(0)])
-                    .unwrap()
-                    .project(vec![1])
-                    .unwrap(),
+                mfp: MapFilterProject::new(1),
            },
        };
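
A pattern running through the test updates above: `map(vec![ScalarExpr::Column(0)])` followed by `project(vec![1])` is observationally the identity, so the optimized expected plans keep the bare `MapFilterProject::new(1)`. A tiny executable model (not the real `MapFilterProject`) of why the two agree on any row:

    /// Evaluates a map-then-project MFP over one row: `exprs` compute columns
    /// appended after the input columns; `projection` indexes input ++ mapped.
    fn eval_mfp(row: &[i64], exprs: &[fn(&[i64]) -> i64], projection: &[usize]) -> Vec<i64> {
        let mut vals = row.to_vec();
        for e in exprs {
            vals.push(e(&vals));
        }
        projection.iter().map(|i| vals[*i]).collect()
    }

    fn col0(row: &[i64]) -> i64 {
        row[0]
    }

    fn main() {
        let row = [42];
        // Old expected plan: map [col(0)], then project the mapped copy at index 1.
        let mapped = eval_mfp(&row, &[col0], &[1]);
        // New expected plan: the default identity projection, with no map stage.
        let identity = eval_mfp(&row, &[], &[0]);
        assert_eq!(mapped, identity);
    }
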