Skip to content

Commit

Permalink
fix: sum ret correct type
Browse files Browse the repository at this point in the history
  • Loading branch information
discord9 committed May 16, 2024
1 parent c045eec commit 0a6fd04
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 42 deletions.
50 changes: 35 additions & 15 deletions src/flow/src/expr/relation/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,26 +139,46 @@ macro_rules! generate_signature {
($value:ident,
{ $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($con_type:ident,$generic:ident)
$auto_arm:ident=>($($arg:ident),*)
),*
]
) => {
match $value {
$($user_arm)*,
$(
Self::$auto_arm => Signature {
input: smallvec![
ConcreteDataType::$con_type(),
ConcreteDataType::$con_type(),
],
output: ConcreteDataType::$con_type(),
generic_fn: GenericFn::$generic,
},
Self::$auto_arm => gen_one_siginature!($($arg),*),
)*
}
};
}

/// Generate one match arm with optional arguments
macro_rules! gen_one_siginature {
(
$con_type:ident, $generic:ident
) => {
Signature {
input: smallvec![
ConcreteDataType::$con_type(),
ConcreteDataType::$con_type(),
],
output: ConcreteDataType::$con_type(),
generic_fn: GenericFn::$generic,
}
};
(
$in_type:ident, $out_type:ident, $generic:ident
) => {
Signature {
input: smallvec![
ConcreteDataType::$in_type()
],
output: ConcreteDataType::$out_type(),
generic_fn: GenericFn::$generic,
}
};
}

static SPECIALIZATION: OnceLock<HashMap<(GenericFn, ConcreteDataType), AggregateFunc>> =
OnceLock::new();

Expand Down Expand Up @@ -267,12 +287,12 @@ impl AggregateFunc {
MinTime => (time_second_datatype, Min),
MinDuration => (duration_second_datatype, Min),
MinInterval => (interval_year_month_datatype, Min),
SumInt16 => (int16_datatype, Sum),
SumInt32 => (int32_datatype, Sum),
SumInt64 => (int64_datatype, Sum),
SumUInt16 => (uint16_datatype, Sum),
SumUInt32 => (uint32_datatype, Sum),
SumUInt64 => (uint64_datatype, Sum),
SumInt16 => (int16_datatype, int64_datatype, Sum),
SumInt32 => (int32_datatype, int64_datatype, Sum),
SumInt64 => (int64_datatype, int64_datatype, Sum),
SumUInt16 => (uint16_datatype, uint64_datatype, Sum),
SumUInt32 => (uint32_datatype, uint64_datatype, Sum),
SumUInt64 => (uint64_datatype, uint64_datatype, Sum),
SumFloat32 => (float32_datatype, Sum),
SumFloat64 => (float64_datatype, Sum),
Any => (boolean_datatype, Any),
Expand Down
56 changes: 29 additions & 27 deletions src/flow/src/transform/aggr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,21 +210,23 @@ impl AggregateExpr {
expr: arg.expr.clone(),
distinct: false,
};
let sum_out_type = sum.func.signature().output.clone();
let count = AggregateExpr {
func: AggregateFunc::Count,
expr: arg.expr.clone(),
distinct: false,
};
let count_out_type = count.func.signature().output.clone();
let avg_output = ScalarExpr::Column(0).call_binary(
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(arg_type.clone())),
BinaryFunc::div(arg_type.clone())?,
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(sum_out_type.clone())),
BinaryFunc::div(sum_out_type.clone())?,
);
// make sure we wouldn't divide by zero
let zero = ScalarExpr::literal(arg_type.default_value(), arg_type.clone());
let zero = ScalarExpr::literal(count_out_type.default_value(), count_out_type.clone());
let non_zero = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)),
then: Box::new(avg_output),
els: Box::new(ScalarExpr::literal(Value::Null, arg_type.clone())),
els: Box::new(ScalarExpr::literal(Value::Null, sum_out_type.clone())),
};
let ret_aggr_exprs = vec![sum, count];
let ret_mfp = Some(non_zero);
Expand Down Expand Up @@ -435,19 +437,19 @@ mod test {
];
let avg_expr = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(2).call_binary(
ScalarExpr::Literal(Value::from(0u32), CDT::uint32_datatype()),
ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
BinaryFunc::NotEq,
)),
then: Box::new(ScalarExpr::Column(1).call_binary(
ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint32_datatype())),
BinaryFunc::DivUInt32,
ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
BinaryFunc::DivUInt64,
)),
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint32_datatype())),
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
};
let expected = TypedPlan {
typ: RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), true),
ColumnType::new(CDT::uint32_datatype(), false),
ColumnType::new(CDT::uint64_datatype(), true), // sum(number) -> u64
ColumnType::new(CDT::uint32_datatype(), false), // number
]),
plan: Plan::Mfp {
input: Box::new(
Expand Down Expand Up @@ -484,7 +486,7 @@ mod test {
.with_types(
RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number
ColumnType::new(ConcreteDataType::uint32_datatype(), true), // sum
ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum
ColumnType::new(ConcreteDataType::int64_datatype(), true), // count
])
.with_key(vec![0]),
Expand Down Expand Up @@ -513,10 +515,7 @@ mod test {

let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let typ = RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), true),
ColumnType::new(ConcreteDataType::int64_datatype(), true),
]);

let aggr_exprs = vec![
AggregateExpr {
func: AggregateFunc::SumUInt32,
Expand All @@ -531,17 +530,17 @@ mod test {
];
let avg_expr = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(1).call_binary(
ScalarExpr::Literal(Value::from(0u32), CDT::uint32_datatype()),
ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
BinaryFunc::NotEq,
)),
then: Box::new(ScalarExpr::Column(0).call_binary(
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint32_datatype())),
BinaryFunc::DivUInt32,
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
BinaryFunc::DivUInt64,
)),
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint32_datatype())),
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
Expand Down Expand Up @@ -572,7 +571,10 @@ mod test {
distinct_aggrs: vec![],
}),
}
.with_types(typ),
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint64_datatype(), true),
ColumnType::new(ConcreteDataType::int64_datatype(), true),
])),
),
mfp: MapFilterProject::new(2)
.map(vec![
Expand All @@ -597,7 +599,7 @@ mod test {
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let typ = RelationType::new(vec![ColumnType::new(
ConcreteDataType::uint32_datatype(),
ConcreteDataType::uint64_datatype(),
true,
)]);
let aggr_expr = AggregateExpr {
Expand All @@ -606,7 +608,7 @@ mod test {
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
Expand Down Expand Up @@ -662,7 +664,7 @@ mod test {
};
let expected = TypedPlan {
typ: RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), true), // col sum(number)
ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
ColumnType::new(CDT::uint32_datatype(), false), // col number
]),
plan: Plan::Mfp {
Expand Down Expand Up @@ -697,7 +699,7 @@ mod test {
.with_types(
RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), false), // col number
ColumnType::new(CDT::uint32_datatype(), true), // col sum(number)
ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
])
.with_key(vec![0]),
),
Expand Down Expand Up @@ -732,7 +734,7 @@ mod test {
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
Expand Down Expand Up @@ -764,7 +766,7 @@ mod test {
}),
}
.with_types(RelationType::new(vec![ColumnType::new(
CDT::uint32_datatype(),
CDT::uint64_datatype(),
true,
)])),
),
Expand Down

0 comments on commit 0a6fd04

Please sign in to comment.