From c045eec8c740e2d9f0eebfadaf8098e01303baf3 Mon Sep 17 00:00:00 2001 From: discord9 Date: Thu, 16 May 2024 11:45:45 +0800 Subject: [PATCH] fix(WIP): avg eval --- src/flow/src/compute/render.rs | 6 ++ src/flow/src/compute/render/reduce.rs | 102 +++++++++++++++++++++++++- src/flow/src/compute/types.rs | 3 + src/flow/src/expr/relation/func.rs | 14 ++-- src/flow/src/transform/aggr.rs | 6 +- 5 files changed, 121 insertions(+), 10 deletions(-) diff --git a/src/flow/src/compute/render.rs b/src/flow/src/compute/render.rs index 0476c8a6e5ac..bf298e86bc30 100644 --- a/src/flow/src/compute/render.rs +++ b/src/flow/src/compute/render.rs @@ -238,6 +238,12 @@ mod test { for now in time_range { state.set_current_ts(now); state.run_available_with_schedule(df); + if !state.get_err_collector().is_empty() { + panic!( + "Errors occur: {:?}", + state.get_err_collector().get_all_blocking() + ) + } assert!(state.get_err_collector().is_empty()); if let Some(expected) = expected.get(&now) { assert_eq!(*output.borrow(), *expected, "at ts={}", now); diff --git a/src/flow/src/compute/render/reduce.rs b/src/flow/src/compute/render/reduce.rs index 46b2dc196f00..da2bb11f4b42 100644 --- a/src/flow/src/compute/render/reduce.rs +++ b/src/flow/src/compute/render/reduce.rs @@ -729,15 +729,113 @@ mod test { use std::cell::RefCell; use std::rc::Rc; - use datatypes::data_type::ConcreteDataType; + use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT}; use hydroflow::scheduled::graph::Hydroflow; use super::*; use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check}; use crate::compute::state::DataflowState; - use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject}; + use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc}; use crate::repr::{ColumnType, RelationType}; + /// select avg(number) from number; + #[test] + fn test_avg_eval() { + let mut df = Hydroflow::new(); + let mut state = 
DataflowState::default(); + let mut ctx = harness_test_ctx(&mut df, &mut state); + + let rows = vec![ + (Row::new(vec![1u32.into()]), 1, 1), + (Row::new(vec![2u32.into()]), 1, 1), + (Row::new(vec![3u32.into()]), 1, 1), + (Row::new(vec![1u32.into()]), 1, 1), + (Row::new(vec![2u32.into()]), 1, 1), + (Row::new(vec![3u32.into()]), 1, 1), + ]; + let collection = ctx.render_constant(rows.clone()); + ctx.insert_global(GlobalId::User(1), collection); + + let aggr_exprs = vec![ + AggregateExpr { + func: AggregateFunc::SumUInt32, + expr: ScalarExpr::Column(0), + distinct: false, + }, + AggregateExpr { + func: AggregateFunc::Count, + expr: ScalarExpr::Column(0), + distinct: false, + }, + ]; + let avg_expr = ScalarExpr::If { + cond: Box::new(ScalarExpr::Column(1).call_binary( + ScalarExpr::Literal(Value::from(0u32), CDT::int64_datatype()), + BinaryFunc::NotEq, + )), + then: Box::new(ScalarExpr::Column(0).call_binary( + ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), + BinaryFunc::DivUInt64, + )), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), + }; + let expected = TypedPlan { + typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]), + plan: Plan::Mfp { + input: Box::new( + Plan::Reduce { + input: Box::new( + Plan::Get { + id: crate::expr::Id::Global(GlobalId::User(1)), + } + .with_types(RelationType::new(vec![ + ColumnType::new(ConcreteDataType::int64_datatype(), false), + ])), + ), + key_val_plan: KeyValPlan { + key_plan: MapFilterProject::new(1) + .project(vec![]) + .unwrap() + .into_safe(), + val_plan: MapFilterProject::new(1) + .project(vec![0]) + .unwrap() + .into_safe(), + }, + reduce_plan: ReducePlan::Accumulable(AccumulablePlan { + full_aggrs: aggr_exprs.clone(), + simple_aggrs: vec![ + AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), + AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1), + ], + distinct_aggrs: vec![], + }), + } + .with_types(RelationType::new(vec![ + 
ColumnType::new(ConcreteDataType::uint32_datatype(), true), + ColumnType::new(ConcreteDataType::int64_datatype(), true), + ])), + ), + mfp: MapFilterProject::new(2) + .map(vec![ + avg_expr, + // TODO(discord9): optimize mfp so to remove indirect ref + ScalarExpr::Column(2), + ]) + .unwrap() + .project(vec![3]) + .unwrap(), + }, + }; + + let bundle = ctx.render_plan(expected).unwrap(); + + let output = get_output_handle(&mut ctx, bundle); + drop(ctx); + let expected = BTreeMap::from([(1, vec![(Row::new(vec![2u64.into()]), 1, 1)])]); + run_and_check(&mut state, &mut df, 1..2, expected, output); + } + /// SELECT DISTINCT col FROM table /// /// table schema: diff --git a/src/flow/src/compute/types.rs b/src/flow/src/compute/types.rs index fa8c7315cb4f..f2276ba755eb 100644 --- a/src/flow/src/compute/types.rs +++ b/src/flow/src/compute/types.rs @@ -153,6 +153,9 @@ pub struct ErrCollector { } impl ErrCollector { + pub fn get_all_blocking(&self) -> Vec { + self.inner.blocking_lock().drain(..).collect_vec() + } pub async fn get_all(&self) -> Vec { self.inner.lock().await.drain(..).collect_vec() } diff --git a/src/flow/src/expr/relation/func.rs b/src/flow/src/expr/relation/func.rs index 4506bf7a5507..f1f69e365477 100644 --- a/src/flow/src/expr/relation/func.rs +++ b/src/flow/src/expr/relation/func.rs @@ -136,11 +136,13 @@ impl AggregateFunc { /// Generate signature for each aggregate function macro_rules! 
generate_signature { - ($value:ident, { $($user_arm:tt)* }, - [ $( - $auto_arm:ident=>($con_type:ident,$generic:ident) - ),* - ]) => { + ($value:ident, + { $($user_arm:tt)* }, + [ $( + $auto_arm:ident=>($con_type:ident,$generic:ident) + ),* + ] + ) => { match $value { $($user_arm)*, $( @@ -223,6 +225,8 @@ impl AggregateFunc { /// all concrete datatypes with precision types will be returned with largest possible variant /// as a exception, count have a signature of `null -> i64`, but it's actually `anytype -> i64` + /// + /// TODO(discord9): fix signature for sum: unsigned -> u64, signed -> i64 pub fn signature(&self) -> Signature { generate_signature!(self, { AggregateFunc::Count => Signature { diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index d4a370bb58be..f81e0e908da4 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -222,7 +222,7 @@ impl AggregateExpr { // make sure we wouldn't divide by zero let zero = ScalarExpr::literal(arg_type.default_value(), arg_type.clone()); let non_zero = ScalarExpr::If { - cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::Eq)), + cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)), then: Box::new(avg_output), els: Box::new(ScalarExpr::literal(Value::Null, arg_type.clone())), }; @@ -436,7 +436,7 @@ mod test { let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(2).call_binary( ScalarExpr::Literal(Value::from(0u32), CDT::uint32_datatype()), - BinaryFunc::Eq, + BinaryFunc::NotEq, )), then: Box::new(ScalarExpr::Column(1).call_binary( ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint32_datatype())), @@ -532,7 +532,7 @@ mod test { let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(1).call_binary( ScalarExpr::Literal(Value::from(0u32), CDT::uint32_datatype()), - BinaryFunc::Eq, + BinaryFunc::NotEq, )), then: Box::new(ScalarExpr::Column(0).call_binary( 
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint32_datatype())),