Skip to content

Commit

Permalink
feat(flow): transform substrait SELECT&WHERE&GROUP BY to Flow Plan (#…
Browse files Browse the repository at this point in the history
…3690)

* feat: transform substrait SELECT&WHERE&GROUP BY to Flow Plan

* chore: reexport from common/substrait

* feat: use datafusion Aggr Func to map to Flow aggr func

* chore: remove unwrap&split literal

* refactor: split transform.rs into smaller files

* feat: apply optimize for variadic fn

* refactor: split unit test

* chore: per review
  • Loading branch information
discord9 authored Apr 12, 2024
1 parent 544c4a7 commit db329f6
Show file tree
Hide file tree
Showing 15 changed files with 1,559 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/common/substrait/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
mod df_substrait;
pub mod error;
pub mod extension_serializer;

use std::sync::Arc;

use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion::catalog::CatalogList;
pub use substrait_proto;

pub use crate::df_substrait::DFLogicalSubstraitConvertor;

Expand Down
2 changes: 1 addition & 1 deletion src/flow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ servers.workspace = true
smallvec.workspace = true
snafu.workspace = true
strum.workspace = true
substrait.workspace = true
tokio.workspace = true
tonic.workspace = true

Expand All @@ -39,5 +40,4 @@ prost.workspace = true
query.workspace = true
serde_json = "1.0"
session.workspace = true
substrait.workspace = true
table.workspace = true
11 changes: 10 additions & 1 deletion src/flow/src/adapter/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ pub enum Error {
extra: String,
location: Location,
},

#[snafu(display("Datafusion error: {raw:?} in context: {context}"))]
Datafusion {
raw: datafusion_common::DataFusionError,
context: String,
location: Location,
},
}

/// Result type for flow module
Expand All @@ -81,7 +88,9 @@ pub type Result<T> = std::result::Result<T, Error>;
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Self::Eval { .. } | &Self::JoinTask { .. } => StatusCode::Internal,
Self::Eval { .. } | &Self::JoinTask { .. } | &Self::Datafusion { .. } => {
StatusCode::Internal
}
&Self::TableAlreadyExist { .. } => StatusCode::TableAlreadyExists,
Self::TableNotFound { .. } => StatusCode::TableNotFound,
&Self::InvalidQuery { .. } | &Self::Plan { .. } | &Self::Datatypes { .. } => {
Expand Down
6 changes: 3 additions & 3 deletions src/flow/src/compute/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ mod test {
(Row::new(vec![2i64.into()]), 2, 1),
(Row::new(vec![3i64.into()]), 3, 1),
];
let collection = ctx.render_constant(rows.clone());
let collection = ctx.render_constant(rows);
ctx.insert_global(GlobalId::User(1), collection);
let input_plan = Plan::Get {
id: expr::Id::Global(GlobalId::User(1)),
Expand Down Expand Up @@ -440,7 +440,7 @@ mod test {
(Row::new(vec![2.into()]), 2, 1),
(Row::new(vec![3.into()]), 3, 1),
];
let collection = ctx.render_constant(rows.clone());
let collection = ctx.render_constant(rows);
ctx.insert_global(GlobalId::User(1), collection);
let input_plan = Plan::Get {
id: expr::Id::Global(GlobalId::User(1)),
Expand Down Expand Up @@ -490,7 +490,7 @@ mod test {
(Row::empty(), 2, 1),
(Row::empty(), 3, 1),
];
let collection = ctx.render_constant(rows.clone());
let collection = ctx.render_constant(rows);
let collection = collection.collection.clone(ctx.df);
let cnt = Rc::new(RefCell::new(0));
let cnt_inner = cnt.clone();
Expand Down
2 changes: 1 addition & 1 deletion src/flow/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}
pub(crate) use id::{GlobalId, Id, LocalId};
pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan};
pub(crate) use relation::{AggregateExpr, AggregateFunc};
pub(crate) use scalar::ScalarExpr;
pub(crate) use scalar::{ScalarExpr, TypedExpr};
6 changes: 3 additions & 3 deletions src/flow/src/expr/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,8 @@ impl BinaryFunc {
let spec_fn = Self::specialization(generic_fn, query_input_type)?;

let signature = Signature {
input: smallvec![arg_type.clone(), arg_type.clone()],
output: spec_fn.signature().output.clone(),
input: smallvec![arg_type.clone(), arg_type],
output: spec_fn.signature().output,
generic_fn,
};

Expand Down Expand Up @@ -767,7 +767,7 @@ fn test_num_ops() {
assert_eq!(res, Value::from(30));
let res = div::<i32>(left.clone(), right.clone()).unwrap();
assert_eq!(res, Value::from(3));
let res = rem::<i32>(left.clone(), right.clone()).unwrap();
let res = rem::<i32>(left, right).unwrap();
assert_eq!(res, Value::from(1));

let values = vec![Value::from(true), Value::from(false)];
Expand Down
36 changes: 26 additions & 10 deletions src/flow/src/expr/relation/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
// limitations under the License.

use std::collections::HashMap;
use std::str::FromStr;
use std::sync::OnceLock;

use common_time::{Date, DateTime};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{OrderedF32, OrderedF64, Value};
use serde::{Deserialize, Serialize};
use smallvec::smallvec;
use snafu::OptionExt;
use snafu::{OptionExt, ResultExt};
use strum::{EnumIter, IntoEnumIterator};

use crate::adapter::error::{Error, InvalidQuerySnafu};
use crate::adapter::error::{DatafusionSnafu, Error, InvalidQuerySnafu};
use crate::expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu};
use crate::expr::relation::accum::{Accum, Accumulator};
use crate::expr::signature::{GenericFn, Signature};
Expand Down Expand Up @@ -172,17 +173,32 @@ impl AggregateFunc {
}
spec
});
use datafusion_expr::aggregate_function::AggregateFunction as DfAggrFunc;
let df_aggr_func = DfAggrFunc::from_str(name).or_else(|err| {
if let datafusion_common::DataFusionError::NotImplemented(msg) = err {
InvalidQuerySnafu {
reason: format!("Unsupported aggregate function: {}", msg),
}
.fail()
} else {
DatafusionSnafu {
raw: err,
context: "Error when parsing aggregate function",
}
.fail()
}
})?;

let generic_fn = match name {
"max" => GenericFn::Max,
"min" => GenericFn::Min,
"sum" => GenericFn::Sum,
"count" => GenericFn::Count,
"any" => GenericFn::Any,
"all" => GenericFn::All,
let generic_fn = match df_aggr_func {
DfAggrFunc::Max => GenericFn::Max,
DfAggrFunc::Min => GenericFn::Min,
DfAggrFunc::Sum => GenericFn::Sum,
DfAggrFunc::Count => GenericFn::Count,
DfAggrFunc::BoolOr => GenericFn::Any,
DfAggrFunc::BoolAnd => GenericFn::All,
_ => {
return InvalidQuerySnafu {
reason: format!("Unknown binary function: {}", name),
reason: format!("Unknown aggregate function: {}", name),
}
.fail();
}
Expand Down
48 changes: 48 additions & 0 deletions src/flow/src/expr/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ use snafu::ensure;
use crate::adapter::error::{Error, InvalidQuerySnafu, UnsupportedTemporalFilterSnafu};
use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu};
use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
use crate::repr::ColumnType;

/// A [`ScalarExpr`] bundled together with the [`ColumnType`] it is known to have.
#[derive(Debug, Clone)]
pub struct TypedExpr {
    /// The scalar expression itself.
    pub expr: ScalarExpr,
    /// The column type associated with `expr`.
    pub typ: ColumnType,
}

impl TypedExpr {
    /// Build a [`TypedExpr`] by pairing `expr` with its type `typ`.
    pub fn new(expr: ScalarExpr, typ: ColumnType) -> Self {
        TypedExpr { expr, typ }
    }
}

/// A scalar expression, which can be evaluated to a value.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
Expand Down Expand Up @@ -64,6 +80,38 @@ pub enum ScalarExpr {
},
}

impl ScalarExpr {
    /// Apply optimizations to this expression in place.
    ///
    /// At the moment the only optimization is flattening nested calls to the
    /// same variadic function (see `flatten_varidic_fn`).
    pub fn optimize(&mut self) {
        self.flatten_varidic_fn();
    }

    /// Flatten nested calls to the same variadic function into a single call.
    ///
    /// Substrait's `And`/`Or` functions are binary, but FlowPlan's `And`/`Or`
    /// functions are variadic, so a chain such as `And(And(a, b), c)` is
    /// rewritten to `And(a, b, c)`.
    ///
    /// Arguments that are calls to a *different* variadic function, or that
    /// are not variadic calls at all, are kept as arguments in their original
    /// position; they are never discarded. Nested variadic calls are flattened
    /// recursively, so chains of arbitrary depth collapse into one call.
    /// If `self` is not a `CallVariadic`, this is a no-op.
    fn flatten_varidic_fn(&mut self) {
        if let ScalarExpr::CallVariadic { func, exprs } = self {
            let mut new_exprs = Vec::with_capacity(exprs.len());
            for mut expr in std::mem::take(exprs) {
                // Flatten the argument first: if it turns out to be a call to
                // the same function, its own arguments are then already flat
                // and can be spliced in directly.
                expr.flatten_varidic_fn();
                match expr {
                    ScalarExpr::CallVariadic {
                        func: inner_func,
                        exprs: inner_exprs,
                    } if *func == inner_func => {
                        new_exprs.extend(inner_exprs);
                    }
                    // A different variadic function (e.g. `Or` inside `And`)
                    // or any non-variadic expression must be preserved as is.
                    other => new_exprs.push(other),
                }
            }
            *exprs = new_exprs;
        }
    }
}

impl ScalarExpr {
/// Call a unary function on this expression.
pub fn call_unary(self, func: UnaryFunc) -> Self {
Expand Down
1 change: 1 addition & 0 deletions src/flow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ mod compute;
mod expr;
mod plan;
mod repr;
mod transform;
mod utils;
16 changes: 10 additions & 6 deletions src/flow/src/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
pub(crate) use self::reduce::{AccumulablePlan, KeyValPlan, ReducePlan};
use crate::adapter::error::Error;
use crate::expr::{
AggregateExpr, EvalError, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr,
AggregateExpr, EvalError, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr, TypedExpr,
};
use crate::plan::join::JoinPlan;
use crate::repr::{ColumnType, DiffRow, RelationType};
Expand Down Expand Up @@ -61,10 +61,13 @@ impl TypedPlan {
}

/// project the plan to the given expressions
pub fn projection(self, exprs: Vec<(ScalarExpr, ColumnType)>) -> Result<Self, Error> {
pub fn projection(self, exprs: Vec<TypedExpr>) -> Result<Self, Error> {
let input_arity = self.typ.column_types.len();
let output_arity = exprs.len();
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs.into_iter().unzip();
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs
.into_iter()
.map(|TypedExpr { expr, typ }| (expr, typ))
.unzip();
let mfp = MapFilterProject::new(input_arity)
.map(exprs)?
.project(input_arity..input_arity + output_arity)?;
Expand All @@ -87,18 +90,19 @@ impl TypedPlan {
}

/// Add a new filter to the plan, will filter out the records that do not satisfy the filter
pub fn filter(self, filter: (ScalarExpr, ColumnType)) -> Result<Self, Error> {
pub fn filter(self, filter: TypedExpr) -> Result<Self, Error> {
let plan = match self.plan {
Plan::Mfp {
input,
mfp: old_mfp,
} => Plan::Mfp {
input,
mfp: old_mfp.filter(vec![filter.0])?,
mfp: old_mfp.filter(vec![filter.expr])?,
},
_ => Plan::Mfp {
input: Box::new(self.plan),
mfp: MapFilterProject::new(self.typ.column_types.len()).filter(vec![filter.0])?,
mfp: MapFilterProject::new(self.typ.column_types.len())
.filter(vec![filter.expr])?,
},
};
Ok(TypedPlan {
Expand Down
Loading

0 comments on commit db329f6

Please sign in to comment.