From 0a6fd042b82d2099daa2bc584b70c64e442445f9 Mon Sep 17 00:00:00 2001 From: discord9 Date: Thu, 16 May 2024 13:14:34 +0800 Subject: [PATCH] fix: sum ret correct type --- src/flow/src/expr/relation/func.rs | 50 ++++++++++++++++++-------- src/flow/src/transform/aggr.rs | 56 ++++++++++++++++-------------- 2 files changed, 64 insertions(+), 42 deletions(-) diff --git a/src/flow/src/expr/relation/func.rs b/src/flow/src/expr/relation/func.rs index f1f69e365477..190dab4ae1cb 100644 --- a/src/flow/src/expr/relation/func.rs +++ b/src/flow/src/expr/relation/func.rs @@ -139,26 +139,46 @@ macro_rules! generate_signature { ($value:ident, { $($user_arm:tt)* }, [ $( - $auto_arm:ident=>($con_type:ident,$generic:ident) + $auto_arm:ident=>($($arg:ident),*) ),* ] ) => { match $value { $($user_arm)*, $( - Self::$auto_arm => Signature { - input: smallvec![ - ConcreteDataType::$con_type(), - ConcreteDataType::$con_type(), - ], - output: ConcreteDataType::$con_type(), - generic_fn: GenericFn::$generic, - }, + Self::$auto_arm => gen_one_siginature!($($arg),*), )* } }; } +/// Generate one match arm with optional arguments +macro_rules! gen_one_siginature { + ( + $con_type:ident, $generic:ident + ) => { + Signature { + input: smallvec![ + ConcreteDataType::$con_type(), + ConcreteDataType::$con_type(), + ], + output: ConcreteDataType::$con_type(), + generic_fn: GenericFn::$generic, + } + }; + ( + $in_type:ident, $out_type:ident, $generic:ident + ) => { + Signature { + input: smallvec![ + ConcreteDataType::$in_type() + ], + output: ConcreteDataType::$out_type(), + generic_fn: GenericFn::$generic, + } + }; +} + static SPECIALIZATION: OnceLock> = OnceLock::new(); @@ -267,12 +287,12 @@ impl AggregateFunc { MinTime => (time_second_datatype, Min), MinDuration => (duration_second_datatype, Min), MinInterval => (interval_year_month_datatype, Min), - SumInt16 => (int16_datatype, Sum), - SumInt32 => (int32_datatype, Sum), - SumInt64 => (int64_datatype, Sum), - SumUInt16 => (uint16_datatype, Sum), - SumUInt32 => (uint32_datatype, Sum), - SumUInt64 => (uint64_datatype, Sum), + SumInt16 => (int16_datatype, int64_datatype, Sum), + SumInt32 => (int32_datatype, int64_datatype, Sum), + SumInt64 => (int64_datatype, int64_datatype, Sum), + SumUInt16 => (uint16_datatype, uint64_datatype, Sum), + SumUInt32 => (uint32_datatype, uint64_datatype, Sum), + SumUInt64 => (uint64_datatype, uint64_datatype, Sum), SumFloat32 => (float32_datatype, Sum), SumFloat64 => (float64_datatype, Sum), Any => (boolean_datatype, Any), diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index f81e0e908da4..3f3bf3fb7c9f 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -210,21 +210,23 @@ impl AggregateExpr { expr: arg.expr.clone(), distinct: false, }; + let sum_out_type = sum.func.signature().output.clone(); let count = AggregateExpr { func: AggregateFunc::Count, expr: arg.expr.clone(), distinct: false, }; + let count_out_type = count.func.signature().output.clone(); let avg_output = ScalarExpr::Column(0).call_binary( - ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(arg_type.clone())), - BinaryFunc::div(arg_type.clone())?, + ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(sum_out_type.clone())), + BinaryFunc::div(sum_out_type.clone())?, ); // make sure we wouldn't divide by zero - let zero = ScalarExpr::literal(arg_type.default_value(), arg_type.clone()); + let zero = ScalarExpr::literal(count_out_type.default_value(), count_out_type.clone()); let non_zero = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)), then: Box::new(avg_output), - els: Box::new(ScalarExpr::literal(Value::Null, arg_type.clone())), + els: Box::new(ScalarExpr::literal(Value::Null, sum_out_type.clone())), }; let ret_aggr_exprs = vec![sum, count]; let ret_mfp = Some(non_zero); @@ -435,19 +437,19 @@ mod test { ]; let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(2).call_binary( - ScalarExpr::Literal(Value::from(0u32), CDT::uint32_datatype()), + ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()), BinaryFunc::NotEq, )), then: Box::new(ScalarExpr::Column(1).call_binary( - ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint32_datatype())), - BinaryFunc::DivUInt32, + ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), + BinaryFunc::DivUInt64, )), - els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint32_datatype())), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), }; let expected = TypedPlan { typ: RelationType::new(vec![ - ColumnType::new(CDT::uint32_datatype(), true), - ColumnType::new(CDT::uint32_datatype(), false), + ColumnType::new(CDT::uint64_datatype(), true), // sum(number) -> u64 + ColumnType::new(CDT::uint32_datatype(), false), // number ]), plan: Plan::Mfp { input: Box::new( @@ -484,7 +486,7 @@ mod test { .with_types( RelationType::new(vec![ ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number - ColumnType::new(ConcreteDataType::uint32_datatype(), true), // sum + ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum ColumnType::new(ConcreteDataType::int64_datatype(), true), // count ]) .with_key(vec![0]), @@ -513,10 +515,7 @@ mod test { let mut ctx = create_test_ctx(); let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); - let typ = RelationType::new(vec![ - ColumnType::new(ConcreteDataType::uint32_datatype(), true), - ColumnType::new(ConcreteDataType::int64_datatype(), true), - ]); + let aggr_exprs = vec![ AggregateExpr { func: AggregateFunc::SumUInt32, @@ -531,17 +530,17 @@ mod test { ]; let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(1).call_binary( - ScalarExpr::Literal(Value::from(0u32), CDT::uint32_datatype()), + ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()), BinaryFunc::NotEq, )), then: Box::new(ScalarExpr::Column(0).call_binary( - ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint32_datatype())), - BinaryFunc::DivUInt32, + ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())), + BinaryFunc::DivUInt64, )), - els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint32_datatype())), + els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())), }; let expected = TypedPlan { - typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]), + typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]), plan: Plan::Mfp { input: Box::new( Plan::Reduce { @@ -572,7 +571,10 @@ mod test { distinct_aggrs: vec![], }), } - .with_types(typ), + .with_types(RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::int64_datatype(), true), + ])), ), mfp: MapFilterProject::new(2) .map(vec![ @@ -597,7 +599,7 @@ mod test { let mut ctx = create_test_ctx(); let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); let typ = RelationType::new(vec![ColumnType::new( - ConcreteDataType::uint32_datatype(), + ConcreteDataType::uint64_datatype(), true, )]); let aggr_expr = AggregateExpr { @@ -606,7 +608,7 @@ mod test { distinct: false, }; let expected = TypedPlan { - typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]), + typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]), plan: Plan::Mfp { input: Box::new( Plan::Reduce { @@ -662,7 +664,7 @@ mod test { }; let expected = TypedPlan { typ: RelationType::new(vec![ - ColumnType::new(CDT::uint32_datatype(), true), // col sum(number) + ColumnType::new(CDT::uint64_datatype(), true), // col sum(number) ColumnType::new(CDT::uint32_datatype(), false), // col number ]), plan: Plan::Mfp { @@ -697,7 +699,7 @@ mod test { .with_types( RelationType::new(vec![ ColumnType::new(CDT::uint32_datatype(), false), // col number - ColumnType::new(CDT::uint32_datatype(), true), // col sum(number) + ColumnType::new(CDT::uint64_datatype(), true), // col sum(number) ]) .with_key(vec![0]), ), @@ -732,7 +734,7 @@ mod test { distinct: false, }; let expected = TypedPlan { - typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]), + typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]), plan: Plan::Mfp { input: Box::new( Plan::Reduce { @@ -764,7 +766,7 @@ mod test { }), } .with_types(RelationType::new(vec![ColumnType::new( - CDT::uint32_datatype(), + CDT::uint64_datatype(), true, )])), ),