Skip to content

Commit

Permalink
fix(WIP): avg eval
Browse files Browse the repository at this point in the history
  • Loading branch information
discord9 committed May 16, 2024
1 parent 54c477c commit c045eec
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 10 deletions.
6 changes: 6 additions & 0 deletions src/flow/src/compute/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ mod test {
for now in time_range {
state.set_current_ts(now);
state.run_available_with_schedule(df);
if !state.get_err_collector().is_empty() {
panic!(
"Errors occur: {:?}",
state.get_err_collector().get_all_blocking()
)
}
assert!(state.get_err_collector().is_empty());
if let Some(expected) = expected.get(&now) {
assert_eq!(*output.borrow(), *expected, "at ts={}", now);
Expand Down
102 changes: 100 additions & 2 deletions src/flow/src/compute/render/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -729,15 +729,113 @@ mod test {
use std::cell::RefCell;
use std::rc::Rc;

use datatypes::data_type::ConcreteDataType;
use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT};
use hydroflow::scheduled::graph::Hydroflow;

use super::*;
use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check};
use crate::compute::state::DataflowState;
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject};
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc};
use crate::repr::{ColumnType, RelationType};

/// select avg(number) from number;
///
/// avg is decomposed into two accumulable aggregates (sum, count) computed by
/// a `Reduce` stage, followed by an mfp stage that evaluates
/// `if count != 0 { sum / count } else { null }`.
#[test]
fn test_avg_eval() {
    let mut df = Hydroflow::new();
    let mut state = DataflowState::default();
    let mut ctx = harness_test_ctx(&mut df, &mut state);

    // Input at ts=1: [1, 2, 3, 1, 2, 3] -> sum = 12, count = 6, avg = 2.
    let rows = vec![
        (Row::new(vec![1u32.into()]), 1, 1),
        (Row::new(vec![2u32.into()]), 1, 1),
        (Row::new(vec![3u32.into()]), 1, 1),
        (Row::new(vec![1u32.into()]), 1, 1),
        (Row::new(vec![2u32.into()]), 1, 1),
        (Row::new(vec![3u32.into()]), 1, 1),
    ];
    let collection = ctx.render_constant(rows.clone());
    ctx.insert_global(GlobalId::User(1), collection);

    // The two underlying aggregates avg is decomposed into:
    // output column 0 = sum(number), output column 1 = count(number).
    let aggr_exprs = vec![
        AggregateExpr {
            func: AggregateFunc::SumUInt32,
            expr: ScalarExpr::Column(0),
            distinct: false,
        },
        AggregateExpr {
            func: AggregateFunc::Count,
            expr: ScalarExpr::Column(0),
            distinct: false,
        },
    ];
    // Guarded division: null when the input is empty (count == 0).
    let avg_expr = ScalarExpr::If {
        cond: Box::new(ScalarExpr::Column(1).call_binary(
            // Column(1) is the count output, which is int64 (see the reduce
            // output types below), so the literal's value type must match its
            // declared datatype (was `Value::from(0u32)` paired with int64).
            ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
            BinaryFunc::NotEq,
        )),
        then: Box::new(ScalarExpr::Column(0).call_binary(
            ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
            BinaryFunc::DivUInt64,
        )),
        els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
    };
    let expected = TypedPlan {
        typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
        plan: Plan::Mfp {
            input: Box::new(
                Plan::Reduce {
                    input: Box::new(
                        Plan::Get {
                            id: crate::expr::Id::Global(GlobalId::User(1)),
                        }
                        .with_types(RelationType::new(vec![
                            ColumnType::new(ConcreteDataType::int64_datatype(), false),
                        ])),
                    ),
                    // No group-by key: the key plan projects to the empty row,
                    // the value plan forwards the single input column.
                    key_val_plan: KeyValPlan {
                        key_plan: MapFilterProject::new(1)
                            .project(vec![])
                            .unwrap()
                            .into_safe(),
                        val_plan: MapFilterProject::new(1)
                            .project(vec![0])
                            .unwrap()
                            .into_safe(),
                    },
                    reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                        full_aggrs: aggr_exprs.clone(),
                        simple_aggrs: vec![
                            AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                            AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
                        ],
                        distinct_aggrs: vec![],
                    }),
                }
                .with_types(RelationType::new(vec![
                    ColumnType::new(ConcreteDataType::uint32_datatype(), true),
                    ColumnType::new(ConcreteDataType::int64_datatype(), true),
                ])),
            ),
            mfp: MapFilterProject::new(2)
                .map(vec![
                    avg_expr,
                    // TODO(discord9): optimize mfp so to remove indirect ref
                    ScalarExpr::Column(2),
                ])
                .unwrap()
                .project(vec![3])
                .unwrap(),
        },
    };

    let bundle = ctx.render_plan(expected).unwrap();

    let output = get_output_handle(&mut ctx, bundle);
    drop(ctx);
    // avg([1,2,3,1,2,3]) = 12 / 6 = 2 (as u64, per the cast in avg_expr).
    let expected = BTreeMap::from([(1, vec![(Row::new(vec![2u64.into()]), 1, 1)])]);
    run_and_check(&mut state, &mut df, 1..2, expected, output);
}

/// SELECT DISTINCT col FROM table
///
/// table schema:
Expand Down
3 changes: 3 additions & 0 deletions src/flow/src/compute/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ pub struct ErrCollector {
}

impl ErrCollector {
/// Drain and return every collected error, blocking the current thread
/// until the lock is acquired (for use outside of async contexts).
pub fn get_all_blocking(&self) -> Vec<EvalError> {
    let mut errors = self.inner.blocking_lock();
    errors.drain(..).collect()
}
/// Drain and return every collected error, awaiting the lock asynchronously.
pub async fn get_all(&self) -> Vec<EvalError> {
    let mut errors = self.inner.lock().await;
    errors.drain(..).collect()
}
Expand Down
14 changes: 9 additions & 5 deletions src/flow/src/expr/relation/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,13 @@ impl AggregateFunc {

/// Generate signature for each aggregate function
macro_rules! generate_signature {
($value:ident, { $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($con_type:ident,$generic:ident)
),*
]) => {
($value:ident,
{ $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($con_type:ident,$generic:ident)
),*
]
) => {
match $value {
$($user_arm)*,
$(
Expand Down Expand Up @@ -223,6 +225,8 @@ impl AggregateFunc {

/// all concrete datatypes with precision types will be returned with the largest possible variant;
/// as an exception, count has a signature of `null -> i64`, but it's actually `anytype -> i64`
///
/// TODO(discord9): fix signature: sum of unsigned -> u64, sum of signed -> i64
pub fn signature(&self) -> Signature {
generate_signature!(self, {
AggregateFunc::Count => Signature {
Expand Down
6 changes: 3 additions & 3 deletions src/flow/src/transform/aggr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl AggregateExpr {
// make sure we wouldn't divide by zero
let zero = ScalarExpr::literal(arg_type.default_value(), arg_type.clone());
let non_zero = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::Eq)),
cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)),
then: Box::new(avg_output),
els: Box::new(ScalarExpr::literal(Value::Null, arg_type.clone())),
};
Expand Down Expand Up @@ -436,7 +436,7 @@ mod test {
let avg_expr = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(2).call_binary(
ScalarExpr::Literal(Value::from(0u32), CDT::uint32_datatype()),
BinaryFunc::Eq,
BinaryFunc::NotEq,
)),
then: Box::new(ScalarExpr::Column(1).call_binary(
ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint32_datatype())),
Expand Down Expand Up @@ -532,7 +532,7 @@ mod test {
let avg_expr = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(1).call_binary(
ScalarExpr::Literal(Value::from(0u32), CDT::uint32_datatype()),
BinaryFunc::Eq,
BinaryFunc::NotEq,
)),
then: Box::new(ScalarExpr::Column(0).call_binary(
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint32_datatype())),
Expand Down

0 comments on commit c045eec

Please sign in to comment.