diff --git a/src/common/substrait/src/lib.rs b/src/common/substrait/src/lib.rs index e0c3046b0868..51f8119dbb41 100644 --- a/src/common/substrait/src/lib.rs +++ b/src/common/substrait/src/lib.rs @@ -17,12 +17,12 @@ mod df_substrait; pub mod error; pub mod extension_serializer; - use std::sync::Arc; use async_trait::async_trait; use bytes::{Buf, Bytes}; use datafusion::catalog::CatalogList; +pub use substrait_proto; pub use crate::df_substrait::DFLogicalSubstraitConvertor; diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml index d5be926121a8..4e9f2e0002ab 100644 --- a/src/flow/Cargo.toml +++ b/src/flow/Cargo.toml @@ -29,6 +29,7 @@ servers.workspace = true smallvec.workspace = true snafu.workspace = true strum.workspace = true +substrait.workspace = true tokio.workspace = true tonic.workspace = true @@ -39,5 +40,4 @@ prost.workspace = true query.workspace = true serde_json = "1.0" session.workspace = true -substrait.workspace = true table.workspace = true diff --git a/src/flow/src/adapter/error.rs b/src/flow/src/adapter/error.rs index ea5ea39f1356..dc384fe6c9ff 100644 --- a/src/flow/src/adapter/error.rs +++ b/src/flow/src/adapter/error.rs @@ -73,6 +73,13 @@ pub enum Error { extra: String, location: Location, }, + + #[snafu(display("Datafusion error: {raw:?} in context: {context}"))] + Datafusion { + raw: datafusion_common::DataFusionError, + context: String, + location: Location, + }, } /// Result type for flow module @@ -81,7 +88,9 @@ pub type Result = std::result::Result; impl ErrorExt for Error { fn status_code(&self) -> StatusCode { match self { - Self::Eval { .. } | &Self::JoinTask { .. } => StatusCode::Internal, + Self::Eval { .. } | &Self::JoinTask { .. } | &Self::Datafusion { .. } => { + StatusCode::Internal + } &Self::TableAlreadyExist { .. } => StatusCode::TableAlreadyExists, Self::TableNotFound { .. } => StatusCode::TableNotFound, &Self::InvalidQuery { .. } | &Self::Plan { .. } | &Self::Datatypes { .. 
} => { diff --git a/src/flow/src/compute/render.rs b/src/flow/src/compute/render.rs index 708297d56f54..f2b02e219adc 100644 --- a/src/flow/src/compute/render.rs +++ b/src/flow/src/compute/render.rs @@ -344,7 +344,7 @@ mod test { (Row::new(vec![2i64.into()]), 2, 1), (Row::new(vec![3i64.into()]), 3, 1), ]; - let collection = ctx.render_constant(rows.clone()); + let collection = ctx.render_constant(rows); ctx.insert_global(GlobalId::User(1), collection); let input_plan = Plan::Get { id: expr::Id::Global(GlobalId::User(1)), @@ -440,7 +440,7 @@ mod test { (Row::new(vec![2.into()]), 2, 1), (Row::new(vec![3.into()]), 3, 1), ]; - let collection = ctx.render_constant(rows.clone()); + let collection = ctx.render_constant(rows); ctx.insert_global(GlobalId::User(1), collection); let input_plan = Plan::Get { id: expr::Id::Global(GlobalId::User(1)), @@ -490,7 +490,7 @@ mod test { (Row::empty(), 2, 1), (Row::empty(), 3, 1), ]; - let collection = ctx.render_constant(rows.clone()); + let collection = ctx.render_constant(rows); let collection = collection.collection.clone(ctx.df); let cnt = Rc::new(RefCell::new(0)); let cnt_inner = cnt.clone(); diff --git a/src/flow/src/expr.rs b/src/flow/src/expr.rs index 4550234b4e2e..7fb2ba7f29ae 100644 --- a/src/flow/src/expr.rs +++ b/src/flow/src/expr.rs @@ -27,4 +27,4 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc} pub(crate) use id::{GlobalId, Id, LocalId}; pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan}; pub(crate) use relation::{AggregateExpr, AggregateFunc}; -pub(crate) use scalar::ScalarExpr; +pub(crate) use scalar::{ScalarExpr, TypedExpr}; diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs index 06b64f7c4a35..518bb14aded9 100644 --- a/src/flow/src/expr/func.rs +++ b/src/flow/src/expr/func.rs @@ -501,8 +501,8 @@ impl BinaryFunc { let spec_fn = Self::specialization(generic_fn, query_input_type)?; let signature = Signature { - input: smallvec![arg_type.clone(), 
arg_type.clone()], - output: spec_fn.signature().output.clone(), + input: smallvec![arg_type.clone(), arg_type], + output: spec_fn.signature().output, generic_fn, }; @@ -767,7 +767,7 @@ fn test_num_ops() { assert_eq!(res, Value::from(30)); let res = div::(left.clone(), right.clone()).unwrap(); assert_eq!(res, Value::from(3)); - let res = rem::(left.clone(), right.clone()).unwrap(); + let res = rem::(left, right).unwrap(); assert_eq!(res, Value::from(1)); let values = vec![Value::from(true), Value::from(false)]; diff --git a/src/flow/src/expr/relation/func.rs b/src/flow/src/expr/relation/func.rs index 8c765024c6d9..17751423aab0 100644 --- a/src/flow/src/expr/relation/func.rs +++ b/src/flow/src/expr/relation/func.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashMap; +use std::str::FromStr; use std::sync::OnceLock; use common_time::{Date, DateTime}; @@ -20,10 +21,10 @@ use datatypes::prelude::ConcreteDataType; use datatypes::value::{OrderedF32, OrderedF64, Value}; use serde::{Deserialize, Serialize}; use smallvec::smallvec; -use snafu::OptionExt; +use snafu::{OptionExt, ResultExt}; use strum::{EnumIter, IntoEnumIterator}; -use crate::adapter::error::{Error, InvalidQuerySnafu}; +use crate::adapter::error::{DatafusionSnafu, Error, InvalidQuerySnafu}; use crate::expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu}; use crate::expr::relation::accum::{Accum, Accumulator}; use crate::expr::signature::{GenericFn, Signature}; @@ -172,17 +173,32 @@ impl AggregateFunc { } spec }); + use datafusion_expr::aggregate_function::AggregateFunction as DfAggrFunc; + let df_aggr_func = DfAggrFunc::from_str(name).or_else(|err| { + if let datafusion_common::DataFusionError::NotImplemented(msg) = err { + InvalidQuerySnafu { + reason: format!("Unsupported aggregate function: {}", msg), + } + .fail() + } else { + DatafusionSnafu { + raw: err, + context: "Error when parsing aggregate function", + } + .fail() + } + })?; - let generic_fn = match name { - 
"max" => GenericFn::Max, - "min" => GenericFn::Min, - "sum" => GenericFn::Sum, - "count" => GenericFn::Count, - "any" => GenericFn::Any, - "all" => GenericFn::All, + let generic_fn = match df_aggr_func { + DfAggrFunc::Max => GenericFn::Max, + DfAggrFunc::Min => GenericFn::Min, + DfAggrFunc::Sum => GenericFn::Sum, + DfAggrFunc::Count => GenericFn::Count, + DfAggrFunc::BoolOr => GenericFn::Any, + DfAggrFunc::BoolAnd => GenericFn::All, _ => { return InvalidQuerySnafu { - reason: format!("Unknown binary function: {}", name), + reason: format!("Unknown aggregate function: {}", name), } .fail(); } diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 772bb06a4a90..0f979f65edb7 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -24,6 +24,22 @@ use snafu::ensure; use crate::adapter::error::{Error, InvalidQuerySnafu, UnsupportedTemporalFilterSnafu}; use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu}; use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}; +use crate::repr::ColumnType; + +/// A scalar expression with a known type. +#[derive(Debug, Clone)] +pub struct TypedExpr { + /// The expression. + pub expr: ScalarExpr, + /// The type of the expression. + pub typ: ColumnType, +} + +impl TypedExpr { + pub fn new(expr: ScalarExpr, typ: ColumnType) -> Self { + Self { expr, typ } + } +} /// A scalar expression, which can be evaluated to a value. #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -64,6 +80,38 @@ pub enum ScalarExpr { }, } +impl ScalarExpr { + /// apply optimization to the expression, like flatten variadic function + pub fn optimize(&mut self) { + self.flatten_varidic_fn(); + } + + /// Because Substrait's `And`/`Or` function is binary, but FlowPlan's + /// `And`/`Or` function is variadic, we need to flatten the `And` function if multiple `And`/`Or` functions are nested. 
+ fn flatten_varidic_fn(&mut self) { + // only a variadic call has arguments that can be flattened + if let ScalarExpr::CallVariadic { func, exprs } = self { + let mut new_exprs = vec![]; + for mut expr in std::mem::take(exprs) { + // flatten the argument first, so that a nested chain of the same function + // is fully collapsed before it is merged into this call + expr.flatten_varidic_fn(); + match expr { + // an argument calling the same function is replaced by its own arguments + ScalarExpr::CallVariadic { + func: inner_func, + exprs: inner_exprs, + } if *func == inner_func => new_exprs.extend(inner_exprs), + // anything else is kept as is, including a variadic call of a different + // function, which must not be dropped + other => new_exprs.push(other), + } + } + *exprs = new_exprs; + } + } +} + impl ScalarExpr { /// Call a unary function on this expression. pub fn call_unary(self, func: UnaryFunc) -> Self { diff --git a/src/flow/src/lib.rs b/src/flow/src/lib.rs index e16bcf45ec58..a7392a69f498 100644 --- a/src/flow/src/lib.rs +++ b/src/flow/src/lib.rs @@ -27,4 +27,5 @@ mod compute; mod expr; mod plan; mod repr; +mod transform; mod utils; diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs index 2bf1301d0719..51a73d81a48d 100644 --- a/src/flow/src/plan.rs +++ b/src/flow/src/plan.rs @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize}; pub(crate) use self::reduce::{AccumulablePlan, KeyValPlan, ReducePlan}; use crate::adapter::error::Error; use crate::expr::{ - AggregateExpr, EvalError, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr, + AggregateExpr, EvalError, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr, TypedExpr, }; use crate::plan::join::JoinPlan; use crate::repr::{ColumnType, DiffRow, RelationType}; @@ -61,10 +61,13 @@ impl TypedPlan { } /// project the plan to the given expressions - pub fn projection(self, exprs: Vec<(ScalarExpr, ColumnType)>) -> Result { + pub fn projection(self, exprs: Vec) -> Result { let input_arity = self.typ.column_types.len(); let output_arity = exprs.len(); - let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs.into_iter().unzip(); + let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs + .into_iter() + .map(|TypedExpr { expr, typ }| (expr, typ)) + .unzip(); let mfp = MapFilterProject::new(input_arity) .map(exprs)? 
.project(input_arity..input_arity + output_arity)?; @@ -87,18 +90,19 @@ impl TypedPlan { } /// Add a new filter to the plan, will filter out the records that do not satisfy the filter - pub fn filter(self, filter: (ScalarExpr, ColumnType)) -> Result { + pub fn filter(self, filter: TypedExpr) -> Result { let plan = match self.plan { Plan::Mfp { input, mfp: old_mfp, } => Plan::Mfp { input, - mfp: old_mfp.filter(vec![filter.0])?, + mfp: old_mfp.filter(vec![filter.expr])?, }, _ => Plan::Mfp { input: Box::new(self.plan), - mfp: MapFilterProject::new(self.typ.column_types.len()).filter(vec![filter.0])?, + mfp: MapFilterProject::new(self.typ.column_types.len()) + .filter(vec![filter.expr])?, }, }; Ok(TypedPlan { diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs new file mode 100644 index 000000000000..bc1b84cb04fa --- /dev/null +++ b/src/flow/src/transform.rs @@ -0,0 +1,179 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Transform Substrait into execution plan +use std::collections::HashMap; + +use datatypes::data_type::ConcreteDataType as CDT; + +use crate::adapter::error::{Error, NotImplementedSnafu, TableNotFoundSnafu}; +use crate::expr::GlobalId; +use crate::repr::RelationType; +/// a simple macro to generate a not implemented error +macro_rules! not_impl_err { + ($($arg:tt)*) => { + NotImplementedSnafu { + reason: format!($($arg)*), + }.fail() + }; +} + +/// generate a plan error +macro_rules! 
plan_err { + ($($arg:tt)*) => { + PlanSnafu { + reason: format!($($arg)*), + }.fail() + }; +} + +mod aggr; +mod expr; +mod literal; +mod plan; + +use literal::{from_substrait_literal, from_substrait_type}; +use snafu::OptionExt; +use substrait::substrait_proto::proto::extensions::simple_extension_declaration::MappingType; +use substrait::substrait_proto::proto::extensions::SimpleExtensionDeclaration; + +/// In Substrait, a function can be defined by a u32 anchor, and the anchor can be mapped to a name +/// +/// So in a Substrait plan, a reference to a function can be a single u32 anchor instead of a full name string +pub struct FunctionExtensions { + anchor_to_name: HashMap, +} + +impl FunctionExtensions { + /// Create a new FunctionExtensions from a list of SimpleExtensionDeclaration + pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result { + let mut anchor_to_name = HashMap::new(); + for e in extensions { + match &e.mapping_type { + Some(ext) => match ext { + MappingType::ExtensionFunction(ext_f) => { + anchor_to_name.insert(ext_f.function_anchor, ext_f.name.clone()); + } + _ => not_impl_err!("Extension type not supported: {ext:?}")?, + }, + None => not_impl_err!("Cannot parse empty extension")?, + } + } + Ok(Self { anchor_to_name }) + } + + /// Get the name of a function by its anchor + pub fn get(&self, anchor: &u32) -> Option<&String> { + self.anchor_to_name.get(anchor) + } +} + +/// A context that holds the information of the dataflow +pub struct DataflowContext { + /// `id` refers to any source table in the dataflow, and `name` is the name of the table + /// which is a `Vec<String>` in substrait + id_to_name: HashMap>, + /// see `id_to_name` + name_to_id: HashMap, GlobalId>, + /// the schema of the table + schema: HashMap, +} + +impl DataflowContext { + /// Retrieves a GlobalId and table schema representing a table previously registered by calling the [register_table] function. 
+ /// + /// Returns an error if no table has been registered with the provided names + pub fn table(&self, name: &Vec) -> Result<(GlobalId, RelationType), Error> { + let id = self + .name_to_id + .get(name) + .copied() + .with_context(|| TableNotFoundSnafu { + name: name.join("."), + })?; + let schema = self + .schema + .get(&id) + .cloned() + .with_context(|| TableNotFoundSnafu { + name: name.join("."), + })?; + Ok((id, schema)) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use catalog::RegisterTableRequest; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID}; + use prost::Message; + use query::parser::QueryLanguageParser; + use query::plan::LogicalPlan; + use query::QueryEngine; + use session::context::QueryContext; + use substrait::substrait_proto::proto; + use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; + use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME}; + + use super::*; + use crate::repr::ColumnType; + + pub fn create_test_ctx() -> DataflowContext { + let gid = GlobalId::User(0); + let name = vec!["numbers".to_string()]; + let schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]); + + DataflowContext { + id_to_name: HashMap::from([(gid, name.clone())]), + name_to_id: HashMap::from([(name.clone(), gid)]), + schema: HashMap::from([(gid, schema)]), + } + } + + pub fn create_test_query_engine() -> Arc { + let catalog_list = catalog::memory::new_memory_catalog_manager().unwrap(); + let req = RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: NUMBERS_TABLE_NAME.to_string(), + table_id: NUMBERS_TABLE_ID, + table: NumbersTable::table(NUMBERS_TABLE_ID), + }; + catalog_list.register_table_sync(req).unwrap(); + let factory = query::QueryEngineFactory::new(catalog_list, None, None, None, false); + + let engine = factory.query_engine(); + + assert_eq!("datafusion", engine.name()); + engine + } 
+ + pub async fn sql_to_substrait(engine: Arc, sql: &str) -> proto::Plan { + // let engine = create_test_query_engine(); + let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap(); + let plan = engine + .planner() + .plan(stmt, QueryContext::arc()) + .await + .unwrap(); + let LogicalPlan::DfPlan(plan) = plan; + + // encode then decode so to rely on the impl of conversion from logical plan to substrait plan + let bytes = DFLogicalSubstraitConvertor {}.encode(&plan).unwrap(); + + proto::Plan::decode(bytes).unwrap() + } +} diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs new file mode 100644 index 000000000000..7a320dddc80f --- /dev/null +++ b/src/flow/src/transform/aggr.rs @@ -0,0 +1,446 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::HashMap; + +use common_decimal::Decimal128; +use common_time::{Date, Timestamp}; +use datafusion_substrait::variation_const::{ + DATE_32_TYPE_REF, DATE_64_TYPE_REF, DEFAULT_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, + TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, + UNSIGNED_INTEGER_TYPE_REF, +}; +use datatypes::arrow::compute::kernels::window; +use datatypes::arrow::ipc::Binary; +use datatypes::data_type::ConcreteDataType as CDT; +use datatypes::value::Value; +use hydroflow::futures::future::Map; +use itertools::Itertools; +use snafu::{OptionExt, ResultExt}; +use substrait::substrait_proto::proto::aggregate_function::AggregationInvocation; +use substrait::substrait_proto::proto::aggregate_rel::{Grouping, Measure}; +use substrait::substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference; +use substrait::substrait_proto::proto::expression::literal::LiteralType; +use substrait::substrait_proto::proto::expression::reference_segment::ReferenceType::StructField; +use substrait::substrait_proto::proto::expression::{ + IfThen, Literal, MaskExpression, RexType, ScalarFunction, +}; +use substrait::substrait_proto::proto::extensions::simple_extension_declaration::MappingType; +use substrait::substrait_proto::proto::extensions::SimpleExtensionDeclaration; +use substrait::substrait_proto::proto::function_argument::ArgType; +use substrait::substrait_proto::proto::r#type::Kind; +use substrait::substrait_proto::proto::read_rel::ReadType; +use substrait::substrait_proto::proto::rel::RelType; +use substrait::substrait_proto::proto::{self, plan_rel, Expression, Plan as SubPlan, Rel}; + +use crate::adapter::error::{ + DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu, + TableNotFoundSnafu, +}; +use crate::expr::{ + AggregateExpr, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, SafeMfpPlan, ScalarExpr, + TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc, +}; +use 
crate::plan::{AccumulablePlan, KeyValPlan, Plan, ReducePlan, TypedPlan}; +use crate::repr::{self, ColumnType, RelationType}; +use crate::transform::{DataflowContext, FunctionExtensions}; + +impl TypedExpr { + fn from_substrait_agg_grouping( + ctx: &mut DataflowContext, + groupings: &[Grouping], + typ: &RelationType, + extensions: &FunctionExtensions, + ) -> Result, Error> { + let _ = ctx; + let mut group_expr = vec![]; + match groupings.len() { + 1 => { + for e in &groupings[0].grouping_expressions { + let x = TypedExpr::from_substrait_rex(e, typ, extensions)?; + group_expr.push(x); + } + } + _ => { + return not_impl_err!( + "Grouping sets not support yet, use union all with group by instead." + ); + } + }; + Ok(group_expr) + } +} + +impl AggregateExpr { + fn from_substrait_agg_measures( + ctx: &mut DataflowContext, + measures: &[Measure], + typ: &RelationType, + extensions: &FunctionExtensions, + ) -> Result, Error> { + let _ = ctx; + let mut aggr_exprs = vec![]; + + for m in measures { + let filter = &m + .filter + .as_ref() + .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions)) + .transpose()?; + + let agg_func = match &m.measure { + Some(f) => { + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => true, + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false, + }; + AggregateExpr::from_substrait_agg_func( + f, typ, extensions, filter, // TODO(discord9): impl order_by + &None, distinct, + ) + } + None => not_impl_err!("Aggregate without aggregate function is not supported"), + }?; + aggr_exprs.push(agg_func); + } + Ok(aggr_exprs) + } + + /// Convert AggregateFunction into Flow's AggregateExpr + pub fn from_substrait_agg_func( + f: &proto::AggregateFunction, + input_schema: &RelationType, + extensions: &FunctionExtensions, + filter: &Option, + order_by: &Option>, + distinct: bool, + ) -> Result { + // TODO(discord9): impl filter + let _ = filter; + let _ = order_by; + let mut 
args = vec![]; + for arg in &f.arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => { + TypedExpr::from_substrait_rex(e, input_schema, extensions) + } + _ => not_impl_err!("Aggregated function argument non-Value type not supported"), + }?; + args.push(arg_expr); + } + + let arg = if let Some(first) = args.first() { + first + } else { + return not_impl_err!("Aggregated function without arguments is not supported"); + }; + + let func = match extensions.get(&f.function_reference) { + Some(function_name) => { + AggregateFunc::from_str_and_type(function_name, Some(arg.typ.scalar_type.clone())) + } + None => not_impl_err!( + "Aggregated function not found: function anchor = {:?}", + f.function_reference + ), + }?; + Ok(AggregateExpr { + func, + expr: arg.expr.clone(), + distinct, + }) + } +} + +impl KeyValPlan { + /// Generate KeyValPlan from AggregateExpr and group_exprs + /// + /// will also change aggregate expr to use column ref if necessary + fn from_substrait_gen_key_val_plan( + aggr_exprs: &mut [AggregateExpr], + group_exprs: &[TypedExpr], + input_arity: usize, + ) -> Result { + let group_expr_val = group_exprs + .iter() + .cloned() + .map(|expr| expr.expr.clone()) + .collect_vec(); + let output_arity = group_expr_val.len(); + let key_plan = MapFilterProject::new(input_arity) + .map(group_expr_val)? 
+ .project(input_arity..input_arity + output_arity)?; + + // val_plan is extracted from aggr_exprs to give aggr function it's necessary input + // and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref + let val_plan = { + let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none()); + if need_mfp { + // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp + let input_exprs = aggr_exprs + .iter_mut() + .enumerate() + .map(|(idx, aggr)| { + let ret = aggr.expr.clone(); + aggr.expr = ScalarExpr::Column(idx); + ret + }) + .collect_vec(); + let aggr_arity = aggr_exprs.len(); + + MapFilterProject::new(input_arity) + .map(input_exprs)? + .project(input_arity..input_arity + aggr_arity)? + } else { + // simply take all inputs as value + MapFilterProject::new(input_arity) + } + }; + Ok(KeyValPlan { + key_plan: key_plan.into_safe(), + val_plan: val_plan.into_safe(), + }) + } +} + +impl TypedPlan { + /// Convert AggregateRel into Flow's TypedPlan + pub fn from_substrait_agg_rel( + ctx: &mut DataflowContext, + agg: &proto::AggregateRel, + extensions: &FunctionExtensions, + ) -> Result { + let input = if let Some(input) = agg.input.as_ref() { + TypedPlan::from_substrait_rel(ctx, input, extensions)? 
+ } else { + return not_impl_err!("Aggregate without an input is not supported"); + }; + + let group_expr = + TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.typ, extensions)?; + + let mut aggr_exprs = + AggregateExpr::from_substrait_agg_measures(ctx, &agg.measures, &input.typ, extensions)?; + + let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan( + &mut aggr_exprs, + &group_expr, + input.typ.column_types.len(), + )?; + + let output_type = { + let mut output_types = Vec::new(); + // first append group_expr as key, then aggr_expr as value + for expr in &group_expr { + output_types.push(expr.typ.clone()); + } + + for aggr in &aggr_exprs { + output_types.push(ColumnType::new_nullable( + aggr.func.signature().output.clone(), + )); + } + RelationType::new(output_types) + }; + + // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs + // also set them input/output column + let full_aggrs = aggr_exprs; + let mut simple_aggrs = Vec::new(); + let mut distinct_aggrs = Vec::new(); + for (output_column, aggr_expr) in full_aggrs.iter().enumerate() { + let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu { + reason: "Expect aggregate argument to be transformed into a column at this point", + })?; + if aggr_expr.distinct { + distinct_aggrs.push((output_column, input_column, aggr_expr.clone())); + } else { + simple_aggrs.push((output_column, input_column, aggr_expr.clone())); + } + } + let accum_plan = AccumulablePlan { + full_aggrs, + simple_aggrs, + distinct_aggrs, + }; + let plan = Plan::Reduce { + input: Box::new(input.plan), + key_val_plan, + reduce_plan: ReducePlan::Accumulable(accum_plan), + }; + Ok(TypedPlan { + typ: output_type, + plan, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::plan::{Plan, TypedPlan}; + use crate::repr::{self, ColumnType, RelationType}; + use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait}; + + #[tokio::test] + 
async fn test_sum() { + let engine = create_test_query_engine(); + let sql = "SELECT sum(number) FROM numbers"; + let plan = sql_to_substrait(engine.clone(), sql).await; + + let mut ctx = create_test_ctx(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + + let aggr_expr = AggregateExpr { + func: AggregateFunc::SumUInt32, + expr: ScalarExpr::Column(0), + distinct: false, + }; + let expected = TypedPlan { + typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]), + plan: Plan::Mfp { + input: Box::new(Plan::Reduce { + input: Box::new(Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(0)), + }), + key_val_plan: KeyValPlan { + key_plan: MapFilterProject::new(1) + .project(vec![]) + .unwrap() + .into_safe(), + val_plan: MapFilterProject::new(1) + .project(vec![0]) + .unwrap() + .into_safe(), + }, + reduce_plan: ReducePlan::Accumulable(AccumulablePlan { + full_aggrs: vec![aggr_expr.clone()], + simple_aggrs: vec![(0, 0, aggr_expr.clone())], + distinct_aggrs: vec![], + }), + }), + mfp: MapFilterProject::new(1) + .map(vec![ScalarExpr::Column(0)]) + .unwrap() + .project(vec![1]) + .unwrap(), + }, + }; + assert_eq!(flow_plan.unwrap(), expected); + } + + #[tokio::test] + async fn test_sum_group_by() { + let engine = create_test_query_engine(); + let sql = "SELECT sum(number), number FROM numbers GROUP BY number"; + let plan = sql_to_substrait(engine.clone(), sql).await; + + let mut ctx = create_test_ctx(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + + let aggr_expr = AggregateExpr { + func: AggregateFunc::SumUInt32, + expr: ScalarExpr::Column(0), + distinct: false, + }; + let expected = TypedPlan { + typ: RelationType::new(vec![ + ColumnType::new(CDT::uint32_datatype(), true), + ColumnType::new(CDT::uint32_datatype(), false), + ]), + plan: Plan::Mfp { + input: Box::new(Plan::Reduce { + input: Box::new(Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(0)), + }), + key_val_plan: KeyValPlan { + 
key_plan: MapFilterProject::new(1) + .map(vec![ScalarExpr::Column(0)]) + .unwrap() + .project(vec![1]) + .unwrap() + .into_safe(), + val_plan: MapFilterProject::new(1) + .project(vec![0]) + .unwrap() + .into_safe(), + }, + reduce_plan: ReducePlan::Accumulable(AccumulablePlan { + full_aggrs: vec![aggr_expr.clone()], + simple_aggrs: vec![(0, 0, aggr_expr.clone())], + distinct_aggrs: vec![], + }), + }), + mfp: MapFilterProject::new(2) + .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)]) + .unwrap() + .project(vec![2, 3]) + .unwrap(), + }, + }; + + assert_eq!(flow_plan, expected); + } + + #[tokio::test] + async fn test_sum_add() { + let engine = create_test_query_engine(); + let sql = "SELECT sum(number+number) FROM numbers"; + let plan = sql_to_substrait(engine.clone(), sql).await; + + let mut ctx = create_test_ctx(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + + let aggr_expr = AggregateExpr { + func: AggregateFunc::SumUInt32, + expr: ScalarExpr::Column(0), + distinct: false, + }; + let expected = TypedPlan { + typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]), + plan: Plan::Mfp { + input: Box::new(Plan::Reduce { + input: Box::new(Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(0)), + }), + key_val_plan: KeyValPlan { + key_plan: MapFilterProject::new(1) + .project(vec![]) + .unwrap() + .into_safe(), + val_plan: MapFilterProject::new(1) + .map(vec![ScalarExpr::Column(0) + .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)]) + .unwrap() + .project(vec![1]) + .unwrap() + .into_safe(), + }, + reduce_plan: ReducePlan::Accumulable(AccumulablePlan { + full_aggrs: vec![aggr_expr.clone()], + simple_aggrs: vec![(0, 0, aggr_expr.clone())], + distinct_aggrs: vec![], + }), + }), + mfp: MapFilterProject::new(1) + .map(vec![ScalarExpr::Column(0)]) + .unwrap() + .project(vec![1]) + .unwrap(), + }, + }; + assert_eq!(flow_plan.unwrap(), expected); + } +} diff --git a/src/flow/src/transform/expr.rs 
b/src/flow/src/transform/expr.rs new file mode 100644 index 000000000000..3f65c4b607fd --- /dev/null +++ b/src/flow/src/transform/expr.rs @@ -0,0 +1,449 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![warn(unused_imports)] + +use datatypes::data_type::ConcreteDataType as CDT; +use itertools::Itertools; +use snafu::{OptionExt, ResultExt}; +use substrait::substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference; +use substrait::substrait_proto::proto::expression::reference_segment::ReferenceType::StructField; +use substrait::substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction}; +use substrait::substrait_proto::proto::function_argument::ArgType; +use substrait::substrait_proto::proto::Expression; + +use crate::adapter::error::{ + DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu, +}; +use crate::expr::{ + BinaryFunc, ScalarExpr, TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc, +}; +use crate::repr::{ColumnType, RelationType}; +use crate::transform::literal::{from_substrait_literal, from_substrait_type}; +use crate::transform::FunctionExtensions; + +impl TypedExpr { + /// Convert ScalarFunction into Flow's ScalarExpr + pub fn from_substrait_scalar_func( + f: &ScalarFunction, + input_schema: &RelationType, + extensions: &FunctionExtensions, + ) -> Result { + let fn_name = + extensions + .get(&f.function_reference) + .with_context(|| 
NotImplementedSnafu { + reason: format!( + "Scalar function not found: function reference = {:?}", + f.function_reference + ), + })?; + let arg_len = f.arguments.len(); + let arg_exprs: Vec = f + .arguments + .iter() + .map(|arg| match &arg.arg_type { + Some(ArgType::Value(e)) => { + TypedExpr::from_substrait_rex(e, input_schema, extensions) + } + _ => not_impl_err!("Scalar function argument non-Value type not supported"), + }) + .try_collect()?; + + // literal's type is determined by the function and type of other args + let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_exprs + .into_iter() + .map( + |TypedExpr { + expr: arg_val, + typ: arg_type, + }| { + if arg_val.is_literal() { + (arg_val, None) + } else { + (arg_val, Some(arg_type.scalar_type)) + } + }, + ) + .unzip(); + + match arg_len { + // because a variadic function can also have 1 argument, we need to check if it's a variadic function first + 1 if VariadicFunc::from_str_and_types(fn_name, &arg_types).is_err() => { + let func = UnaryFunc::from_str_and_type(fn_name, None)?; + let arg = arg_exprs[0].clone(); + let ret_type = ColumnType::new_nullable(func.signature().output.clone()); + + Ok(TypedExpr::new(arg.call_unary(func), ret_type)) + } + // because a variadic function can also have 2 arguments, we need to check if it's a variadic function first + 2 if VariadicFunc::from_str_and_types(fn_name, &arg_types).is_err() => { + let (func, signature) = + BinaryFunc::from_str_expr_and_type(fn_name, &arg_exprs, &arg_types[0..2])?; + + // constant folding here + let is_all_literal = arg_exprs.iter().all(|arg| arg.is_literal()); + if is_all_literal { + let res = func + .eval(&[], &arg_exprs[0], &arg_exprs[1]) + .context(EvalSnafu)?; + + // if output type is null, it should be inferred from the input types + let con_typ = signature.output.clone(); + let typ = ColumnType::new_nullable(con_typ.clone()); + return Ok(TypedExpr::new(ScalarExpr::Literal(res, con_typ), typ)); + } + + let mut arg_exprs = arg_exprs; + 
for (idx, arg_expr) in arg_exprs.iter_mut().enumerate() {
+                    if let ScalarExpr::Literal(val, typ) = arg_expr {
+                        let dest_type = signature.input[idx].clone();
+
+                        // cast val to target_type
+                        let dest_val = if !dest_type.is_null() {
+                            datatypes::types::cast(val.clone(), &dest_type)
+                        .with_context(|_|
+                            DatatypesSnafu{
+                                extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}")
+                            })?
+                        } else {
+                            val.clone()
+                        };
+                        *val = dest_val;
+                        *typ = dest_type;
+                    }
+                }
+
+                let ret_type = ColumnType::new_nullable(func.signature().output.clone());
+                let ret_expr = arg_exprs[0].clone().call_binary(arg_exprs[1].clone(), func);
+                Ok(TypedExpr::new(ret_expr, ret_type))
+            }
+            _var => {
+                if let Ok(func) = VariadicFunc::from_str_and_types(fn_name, &arg_types) {
+                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
+                    let mut expr = ScalarExpr::CallVariadic {
+                        func,
+                        exprs: arg_exprs,
+                    };
+                    expr.optimize();
+                    Ok(TypedExpr::new(expr, ret_type))
+                } else if let Ok(func) = UnmaterializableFunc::from_str(fn_name) {
+                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
+                    Ok(TypedExpr::new(
+                        ScalarExpr::CallUnmaterializable(func),
+                        ret_type,
+                    ))
+                } else {
+                    not_impl_err!("Unsupported function {fn_name} with {arg_len} arguments")
+                }
+            }
+        }
+    }
+
+    /// Convert a Substrait `IfThen` expression into a chain of Flow `ScalarExpr::If`.
+    ///
+    /// Each `if`/`then` clause becomes one `ScalarExpr::If` whose `els` branch holds the
+    /// conversion of the clauses that follow it. The innermost `els` is the plan's
+    /// `else` expression, or a null literal when the plan has none.
+    ///
+    /// The resulting type is the type of the first `then` expression; with no clauses at
+    /// all the `else` expression (or the null literal) is returned as is.
+    ///
+    /// # Errors
+    /// Returns an invalid-query error when a clause is missing its `if` or `then` part,
+    /// and propagates any error raised while converting the nested expressions.
+    pub fn from_substrait_ifthen_rex(
+        if_then: &IfThen,
+        input_schema: &RelationType,
+        extensions: &FunctionExtensions,
+    ) -> Result<TypedExpr, Error> {
+        let ifs: Vec<_> = if_then
+            .ifs
+            .iter()
+            .map(|if_clause| {
+                let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu {
+                    reason: "IfThen clause without if",
+                })?;
+                let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu {
+                    reason: "IfThen clause without then",
+                })?;
+                let cond = TypedExpr::from_substrait_rex(proto_if, input_schema, extensions)?;
+                let then = TypedExpr::from_substrait_rex(proto_then, input_schema, extensions)?;
+                Ok((cond, then))
+            })
+            .try_collect()?;
+        // If no `else` is present, fall back to a null literal.
+        let els = if_then
+            .r#else
+            .as_ref()
+            .map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions))
+            .transpose()?
+            .unwrap_or_else(|| {
+                TypedExpr::new(
+                    ScalarExpr::literal_null(),
+                    ColumnType::new_nullable(CDT::null_datatype()),
+                )
+            });
+
+        // Build the chain from the last clause outwards so the first clause ends up as
+        // the outermost `If`. The type of each `If` is assumed to be the type of its
+        // `then` expression.
+        let expr_if = ifs.into_iter().rev().fold(els, |els, (cond, then)| {
+            TypedExpr::new(
+                ScalarExpr::If {
+                    cond: Box::new(cond.expr),
+                    then: Box::new(then.expr),
+                    els: Box::new(els.expr),
+                },
+                then.typ,
+            )
+        });
+        Ok(expr_if)
+    }
+
+    /// Convert a Substrait `Expression` (Rex) into Flow's typed `ScalarExpr`.
+    ///
+    /// Handled expression kinds: literals, `SingularOrList` without options, direct
+    /// struct-field column references, scalar functions, `IfThen` and `Cast`.
+    ///
+    /// # Errors
+    /// * invalid-query error when a required sub-message is missing or a column
+    ///   reference does not address a column of `input_schema`;
+    /// * plan error for window functions;
+    /// * not-implemented error for every other expression kind.
+    pub fn from_substrait_rex(
+        e: &Expression,
+        input_schema: &RelationType,
+        extensions: &FunctionExtensions,
+    ) -> Result<TypedExpr, Error> {
+        match &e.rex_type {
+            Some(RexType::Literal(lit)) => {
+                let lit = from_substrait_literal(lit)?;
+                Ok(TypedExpr::new(
+                    ScalarExpr::Literal(lit.0, lit.1.clone()),
+                    ColumnType::new_nullable(lit.1),
+                ))
+            }
+            Some(RexType::SingularOrList(s)) => {
+                let substrait_expr = s.value.as_ref().with_context(|| InvalidQuerySnafu {
+                    reason: "SingularOrList expression without value",
+                })?;
+                // Note that `IN (...)` list expressions are not implemented yet.
+                if !s.options.is_empty() {
+                    return not_impl_err!("In list expression is not supported");
+                }
+                TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions)
+            }
+            Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
+                Some(DirectReference(direct)) => match direct.reference_type.as_ref() {
+                    Some(StructField(x)) => match x.child.as_ref() {
+                        Some(_) => {
+                            not_impl_err!(
+                                "Direct reference StructField with child is not supported"
+                            )
+                        }
+                        None => {
+                            // The field index comes straight from the plan. Reject negative
+                            // or out-of-range values instead of panicking on the index.
+                            let column = usize::try_from(x.field).ok();
+                            let column_type =
+                                column.and_then(|idx| input_schema.column_types.get(idx));
+                            match (column, column_type) {
+                                (Some(column), Some(column_type)) => Ok(TypedExpr::new(
+                                    ScalarExpr::Column(column),
+                                    column_type.clone(),
+                                )),
+                                _ => InvalidQuerySnafu {
+                                    reason: format!(
+                                        "Column reference {} is out of bounds for an input with {} columns",
+                                        x.field,
+                                        input_schema.column_types.len()
+                                    ),
+                                }
+                                .fail(),
+                            }
+                        }
+                    },
+                    _ => not_impl_err!(
+                        "Direct reference with types other than StructField is not supported"
+                    ),
+                },
+                _ => not_impl_err!("unsupported field ref type"),
+            },
+            Some(RexType::ScalarFunction(f)) => {
+                TypedExpr::from_substrait_scalar_func(f, input_schema, extensions)
+            }
+            Some(RexType::IfThen(if_then)) => {
+                TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions)
+            }
+            Some(RexType::Cast(cast)) => {
+                let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu {
+                    reason: "Cast expression without input",
+                })?;
+                let input = TypedExpr::from_substrait_rex(input, input_schema, extensions)?;
+                let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| {
+                    InvalidQuerySnafu {
+                        reason: "Cast expression without type",
+                    }
+                })?)?;
+                let func = UnaryFunc::from_str_and_type("cast", Some(cast_type.clone()))?;
+                Ok(TypedExpr::new(
+                    input.expr.call_unary(func),
+                    ColumnType::new_nullable(cast_type),
+                ))
+            }
+            Some(RexType::WindowFunction(_)) => PlanSnafu {
+                reason:
+                    "Window function is not supported yet. Please use aggregation function instead."
+                        .to_string(),
+            }
+            .fail(),
+            _ => not_impl_err!("unsupported rex_type"),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use datatypes::value::Value;
+
+    use super::*;
+    use crate::expr::{GlobalId, MapFilterProject};
+    use crate::plan::{Plan, TypedPlan};
+    use crate::repr::{self, ColumnType, RelationType};
+    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
+    /// test if `WHERE` condition can be converted to Flow's ScalarExpr in mfp's filter
+    #[tokio::test]
+    async fn test_where_and() {
+        let engine = create_test_query_engine();
+        let sql = "SELECT number FROM numbers WHERE number >= 1 AND number <= 3 AND number!=2";
+        let plan = sql_to_substrait(engine.clone(), sql).await;
+
+        let mut ctx = create_test_ctx();
+        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
+
+        // optimize binary and to variadic and
+        let filter = ScalarExpr::CallVariadic {
+            func: VariadicFunc::And,
+            exprs: vec![
+                ScalarExpr::Column(0).call_binary(
+                    ScalarExpr::Literal(Value::from(1u32), CDT::uint32_datatype()),
+                    BinaryFunc::Gte,
+                ),
+                ScalarExpr::Column(0).call_binary(
+                    ScalarExpr::Literal(Value::from(3u32), CDT::uint32_datatype()),
+                    BinaryFunc::Lte,
+                ),
+                ScalarExpr::Column(0).call_binary(
+                    ScalarExpr::Literal(Value::from(2u32), CDT::uint32_datatype()),
+                    BinaryFunc::NotEq,
+                ),
+            ],
+        };
+        let expected = TypedPlan {
+            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]),
+            plan: Plan::Mfp {
+                input: Box::new(Plan::Get {
+                    id: crate::expr::Id::Global(GlobalId::User(0)),
+                }),
+                mfp: MapFilterProject::new(1)
+                    .map(vec![ScalarExpr::Column(0)])
+                    .unwrap()
+                    .filter(vec![filter])
+                    .unwrap()
+                    .project(vec![1])
+                    .unwrap(),
+            },
+        };
+        assert_eq!(flow_plan.unwrap(), expected);
+    }
+
+    /// case: binary functions&constant folding can happen in converting substrait plan
+    #[tokio::test]
+    async fn test_binary_func_and_constant_folding() {
+        let engine = create_test_query_engine();
+        let sql = "SELECT
1+1*2-1/1+1%2==3 FROM numbers";
+        let plan = sql_to_substrait(engine.clone(), sql).await;
+
+        let mut ctx = create_test_ctx();
+        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
+
+        // The whole expression is made of literals, so the expected plan is a single
+        // constant row holding `true` rather than an Mfp over the table.
+        let expected = TypedPlan {
+            typ: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)]),
+            plan: Plan::Constant {
+                rows: vec![(
+                    repr::Row::new(vec![Value::from(true)]),
+                    repr::Timestamp::MIN,
+                    1,
+                )],
+            },
+        };
+
+        assert_eq!(flow_plan.unwrap(), expected);
+    }
+
+    /// test that the type of a literal is inferred from the other operand: the
+    /// expected plan carries the `1` in `number+1` as a uint32 literal next to the
+    /// uint32 column and uses `BinaryFunc::AddUInt32`
+    #[tokio::test]
+    async fn test_implicitly_cast() {
+        let engine = create_test_query_engine();
+        let sql = "SELECT number+1 FROM numbers";
+        let plan = sql_to_substrait(engine.clone(), sql).await;
+
+        let mut ctx = create_test_ctx();
+        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
+
+        let expected = TypedPlan {
+            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
+            plan: Plan::Mfp {
+                input: Box::new(Plan::Get {
+                    id: crate::expr::Id::Global(GlobalId::User(0)),
+                }),
+                mfp: MapFilterProject::new(1)
+                    .map(vec![ScalarExpr::Column(0).call_binary(
+                        ScalarExpr::Literal(Value::from(1u32), CDT::uint32_datatype()),
+                        BinaryFunc::AddUInt32,
+                    )])
+                    .unwrap()
+                    .project(vec![1])
+                    .unwrap(),
+            },
+        };
+        assert_eq!(flow_plan.unwrap(), expected);
+    }
+
+    /// test that an explicit `CAST` is kept as a `UnaryFunc::Cast` call around the
+    /// original int64 literal; the expected plan is an Mfp, not a folded constant
+    #[tokio::test]
+    async fn test_cast() {
+        let engine = create_test_query_engine();
+        let sql = "SELECT CAST(1 AS INT16) FROM numbers";
+        let plan = sql_to_substrait(engine.clone(), sql).await;
+
+        let mut ctx = create_test_ctx();
+        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
+
+        let expected = TypedPlan {
+            typ: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)]),
+            plan: Plan::Mfp {
+                input: Box::new(Plan::Get {
+                    id: crate::expr::Id::Global(GlobalId::User(0)),
+                }),
+                mfp: MapFilterProject::new(1)
+                    .map(vec![ScalarExpr::Literal(
+                        Value::Int64(1),
+                        CDT::int64_datatype(),
+                    )
+                    .call_unary(UnaryFunc::Cast(CDT::int16_datatype()))])
+                    .unwrap()
+                    .project(vec![1])
+                    .unwrap(),
+            },
+        };
+        assert_eq!(flow_plan.unwrap(), expected);
+    }
+
+    /// test that adding a column to itself maps to a binary `AddUInt32` call whose two
+    /// operands are both column 0
+    #[tokio::test]
+    async fn test_select_add() {
+        let engine = create_test_query_engine();
+        let sql = "SELECT number+number FROM numbers";
+        let plan = sql_to_substrait(engine.clone(), sql).await;
+
+        let mut ctx = create_test_ctx();
+        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
+
+        let expected = TypedPlan {
+            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
+            plan: Plan::Mfp {
+                input: Box::new(Plan::Get {
+                    id: crate::expr::Id::Global(GlobalId::User(0)),
+                }),
+                mfp: MapFilterProject::new(1)
+                    .map(vec![ScalarExpr::Column(0)
+                        .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)])
+                    .unwrap()
+                    .project(vec![1])
+                    .unwrap(),
+            },
+        };
+
+        assert_eq!(flow_plan.unwrap(), expected);
+    }
+}
diff --git a/src/flow/src/transform/literal.rs b/src/flow/src/transform/literal.rs
new file mode 100644
index 000000000000..b41a82e26a4b
--- /dev/null
+++ b/src/flow/src/transform/literal.rs
@@ -0,0 +1,191 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +use common_decimal::Decimal128; +use common_time::{Date, Timestamp}; +use datafusion_substrait::variation_const::{ + DATE_32_TYPE_REF, DATE_64_TYPE_REF, DEFAULT_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, + TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, + UNSIGNED_INTEGER_TYPE_REF, +}; +use datatypes::data_type::ConcreteDataType as CDT; +use datatypes::value::Value; +use substrait::substrait_proto::proto::expression::literal::LiteralType; +use substrait::substrait_proto::proto::expression::Literal; +use substrait::substrait_proto::proto::r#type::Kind; + +use crate::adapter::error::{Error, NotImplementedSnafu, PlanSnafu}; + +/// Convert a Substrait literal into a Value and its ConcreteDataType (So that we can know type even if the value is null) +pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<(Value, CDT), Error> { + let scalar_value = match &lit.literal_type { + Some(LiteralType::Boolean(b)) => (Value::from(*b), CDT::boolean_datatype()), + Some(LiteralType::I8(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => (Value::from(*n as i8), CDT::int8_datatype()), + UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u8), CDT::uint8_datatype()), + others => not_impl_err!("Unknown type variation reference {others}",)?, + }, + Some(LiteralType::I16(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => (Value::from(*n as i16), CDT::int16_datatype()), + UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u16), CDT::uint16_datatype()), + others => not_impl_err!("Unknown type variation reference {others}",)?, + }, + Some(LiteralType::I32(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => (Value::from(*n), CDT::int32_datatype()), + UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u32), CDT::uint32_datatype()), + others => not_impl_err!("Unknown type variation reference {others}",)?, + }, + Some(LiteralType::I64(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => (Value::from(*n), 
CDT::int64_datatype()), + UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u64), CDT::uint64_datatype()), + others => not_impl_err!("Unknown type variation reference {others}",)?, + }, + Some(LiteralType::Fp32(f)) => (Value::from(*f), CDT::float32_datatype()), + Some(LiteralType::Fp64(f)) => (Value::from(*f), CDT::float64_datatype()), + Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference { + TIMESTAMP_SECOND_TYPE_REF => ( + Value::from(Timestamp::new_second(*t)), + CDT::timestamp_second_datatype(), + ), + TIMESTAMP_MILLI_TYPE_REF => ( + Value::from(Timestamp::new_millisecond(*t)), + CDT::timestamp_millisecond_datatype(), + ), + TIMESTAMP_MICRO_TYPE_REF => ( + Value::from(Timestamp::new_microsecond(*t)), + CDT::timestamp_microsecond_datatype(), + ), + TIMESTAMP_NANO_TYPE_REF => ( + Value::from(Timestamp::new_nanosecond(*t)), + CDT::timestamp_nanosecond_datatype(), + ), + others => not_impl_err!("Unknown type variation reference {others}",)?, + }, + Some(LiteralType::Date(d)) => (Value::from(Date::new(*d)), CDT::date_datatype()), + Some(LiteralType::String(s)) => (Value::from(s.clone()), CDT::string_datatype()), + Some(LiteralType::Binary(b)) | Some(LiteralType::FixedBinary(b)) => { + (Value::from(b.clone()), CDT::binary_datatype()) + } + Some(LiteralType::Decimal(d)) => { + let value: [u8; 16] = d.value.clone().try_into().map_err(|e| { + PlanSnafu { + reason: format!("Failed to parse decimal value from {e:?}"), + } + .build() + })?; + let p: u8 = d.precision.try_into().map_err(|e| { + PlanSnafu { + reason: format!("Failed to parse decimal precision: {e}"), + } + .build() + })?; + let s: i8 = d.scale.try_into().map_err(|e| { + PlanSnafu { + reason: format!("Failed to parse decimal scale: {e}"), + } + .build() + })?; + let value = i128::from_le_bytes(value); + ( + Value::from(Decimal128::new(value, p, s)), + CDT::decimal128_datatype(p, s), + ) + } + Some(LiteralType::Null(ntype)) => (Value::Null, from_substrait_type(ntype)?), + _ => 
not_impl_err!("unsupported literal_type")?, + }; + Ok(scalar_value) +} + +/// convert a Substrait type into a ConcreteDataType +pub fn from_substrait_type( + null_type: &substrait::substrait_proto::proto::Type, +) -> Result { + if let Some(kind) = &null_type.kind { + match kind { + Kind::Bool(_) => Ok(CDT::boolean_datatype()), + Kind::I8(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(CDT::int8_datatype()), + UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint8_datatype()), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + Kind::I16(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(CDT::int16_datatype()), + UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint16_datatype()), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + Kind::I32(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(CDT::int32_datatype()), + UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint32_datatype()), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + Kind::I64(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(CDT::int64_datatype()), + UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint64_datatype()), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + Kind::Fp32(_) => Ok(CDT::float32_datatype()), + Kind::Fp64(_) => Ok(CDT::float64_datatype()), + Kind::Timestamp(ts) => match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_REF => Ok(CDT::timestamp_second_datatype()), + TIMESTAMP_MILLI_TYPE_REF => Ok(CDT::timestamp_millisecond_datatype()), + TIMESTAMP_MICRO_TYPE_REF => Ok(CDT::timestamp_microsecond_datatype()), + TIMESTAMP_NANO_TYPE_REF => Ok(CDT::timestamp_nanosecond_datatype()), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + Kind::Date(date) => match date.type_variation_reference { + DATE_32_TYPE_REF => 
Ok(CDT::date_datatype()), + DATE_64_TYPE_REF => Ok(CDT::date_datatype()), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + Kind::Binary(_) => Ok(CDT::binary_datatype()), + Kind::String(_) => Ok(CDT::string_datatype()), + Kind::Decimal(d) => Ok(CDT::decimal128_datatype(d.precision as u8, d.scale as i8)), + _ => not_impl_err!("Unsupported Substrait type: {kind:?}"), + } + } else { + not_impl_err!("Null type without kind is not supported") + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::plan::{Plan, TypedPlan}; + use crate::repr::{self, ColumnType, RelationType}; + use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait}; + /// test if literal in substrait plan can be correctly converted to flow plan + #[tokio::test] + async fn test_literal() { + let engine = create_test_query_engine(); + let sql = "SELECT 1 FROM numbers"; + let plan = sql_to_substrait(engine.clone(), sql).await; + + let mut ctx = create_test_ctx(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + + let expected = TypedPlan { + typ: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)]), + plan: Plan::Constant { + rows: vec![( + repr::Row::new(vec![Value::Int64(1)]), + repr::Timestamp::MIN, + 1, + )], + }, + }; + + assert_eq!(flow_plan.unwrap(), expected); + } +} diff --git a/src/flow/src/transform/plan.rs b/src/flow/src/transform/plan.rs new file mode 100644 index 000000000000..fd73bb33d08d --- /dev/null +++ b/src/flow/src/transform/plan.rs @@ -0,0 +1,190 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use itertools::Itertools; +use snafu::OptionExt; +use substrait::substrait_proto::proto::expression::MaskExpression; +use substrait::substrait_proto::proto::read_rel::ReadType; +use substrait::substrait_proto::proto::rel::RelType; +use substrait::substrait_proto::proto::{plan_rel, Plan as SubPlan, Rel}; + +use crate::adapter::error::{Error, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu}; +use crate::expr::{MapFilterProject, TypedExpr}; +use crate::plan::{Plan, TypedPlan}; +use crate::repr::{self, RelationType}; +use crate::transform::{DataflowContext, FunctionExtensions}; + +impl TypedPlan { + /// Convert Substrait Plan into Flow's TypedPlan + pub fn from_substrait_plan( + ctx: &mut DataflowContext, + plan: &SubPlan, + ) -> Result { + // Register function extension + let function_extension = FunctionExtensions::try_from_proto(&plan.extensions)?; + + // Parse relations + match plan.relations.len() { + 1 => { + match plan.relations[0].rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => { + Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension)?) + }, + plan_rel::RelType::Root(root) => { + let input = root.input.as_ref().with_context(|| InvalidQuerySnafu { + reason: "Root relation without input", + })?; + Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension)?) + } + }, + None => plan_err!("Cannot parse plan relation: None") + } + }, + _ => not_impl_err!( + "Substrait plan with more than 1 relation trees not supported. 
Number of relation trees: {:?}", + plan.relations.len() + ) + } + } + + /// Convert Substrait Rel into Flow's TypedPlan + /// TODO: SELECT DISTINCT(does it get compile with something else?) + pub fn from_substrait_rel( + ctx: &mut DataflowContext, + rel: &Rel, + extensions: &FunctionExtensions, + ) -> Result { + match &rel.rel_type { + Some(RelType::Project(p)) => { + let input = if let Some(input) = p.input.as_ref() { + TypedPlan::from_substrait_rel(ctx, input, extensions)? + } else { + return not_impl_err!("Projection without an input is not supported"); + }; + let mut exprs: Vec = vec![]; + for e in &p.expressions { + let expr = TypedExpr::from_substrait_rex(e, &input.typ, extensions)?; + exprs.push(expr); + } + let is_literal = exprs.iter().all(|expr| expr.expr.is_literal()); + if is_literal { + let (literals, lit_types): (Vec<_>, Vec<_>) = exprs + .into_iter() + .map(|TypedExpr { expr, typ }| (expr, typ)) + .unzip(); + let typ = RelationType::new(lit_types); + let row = literals + .into_iter() + .map(|lit| lit.as_literal().expect("A literal")) + .collect_vec(); + let row = repr::Row::new(row); + let plan = Plan::Constant { + rows: vec![(row, repr::Timestamp::MIN, 1)], + }; + Ok(TypedPlan { typ, plan }) + } else { + input.projection(exprs) + } + } + Some(RelType::Filter(filter)) => { + let input = if let Some(input) = filter.input.as_ref() { + TypedPlan::from_substrait_rel(ctx, input, extensions)? + } else { + return not_impl_err!("Filter without an input is not supported"); + }; + + let expr = if let Some(condition) = filter.condition.as_ref() { + TypedExpr::from_substrait_rex(condition, &input.typ, extensions)? 
+ } else { + return not_impl_err!("Filter without an condition is not valid"); + }; + input.filter(expr) + } + Some(RelType::Read(read)) => { + if let Some(ReadType::NamedTable(nt)) = &read.as_ref().read_type { + let table_reference = nt.names.clone(); + let table = ctx.table(&table_reference)?; + let get_table = Plan::Get { + id: crate::expr::Id::Global(table.0), + }; + let get_table = TypedPlan { + typ: table.1, + plan: get_table, + }; + + if let Some(MaskExpression { + select: Some(projection), + .. + }) = &read.projection + { + let column_indices: Vec = projection + .struct_items + .iter() + .map(|item| item.field as usize) + .collect(); + let input_arity = get_table.typ.column_types.len(); + let mfp = + MapFilterProject::new(input_arity).project(column_indices.clone())?; + get_table.mfp(mfp) + } else { + Ok(get_table) + } + } else { + not_impl_err!("Only NamedTable reads are supported") + } + } + Some(RelType::Aggregate(agg)) => { + TypedPlan::from_substrait_agg_rel(ctx, agg, extensions) + } + _ => not_impl_err!("Unsupported relation type: {:?}", rel.rel_type), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::expr::{GlobalId, ScalarExpr}; + use crate::plan::{Plan, TypedPlan}; + use crate::repr::{self, ColumnType, RelationType}; + use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait}; + use crate::transform::CDT; + + #[tokio::test] + async fn test_select() { + let engine = create_test_query_engine(); + let sql = "SELECT number FROM numbers"; + let plan = sql_to_substrait(engine.clone(), sql).await; + + let mut ctx = create_test_ctx(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + + let expected = TypedPlan { + typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]), + plan: Plan::Mfp { + input: Box::new(Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(0)), + }), + mfp: MapFilterProject::new(1) + .map(vec![ScalarExpr::Column(0)]) + .unwrap() + 
.project(vec![1]) + .unwrap(), + }, + }; + + assert_eq!(flow_plan.unwrap(), expected); + } +}