fix: tumble lose group expr
discord9 committed Aug 2, 2024
1 parent 822eeb8 commit 6abebca
Showing 2 changed files with 166 additions and 25 deletions.
56 changes: 31 additions & 25 deletions src/flow/src/df_optimizer.rs
@@ -20,6 +20,7 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::config::ConfigOptions;
use datafusion::error::DataFusionError;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
@@ -64,7 +65,9 @@ pub async fn apply_df_optimizer(
Arc::new(TypeCoercion::new()),
]);
let plan = analyzer
.execute_and_check(plan, &cfg, |_p, _r| {})
.execute_and_check(plan, &cfg, |p, r| {
debug!("After apply rule {}, get plan: \n{:?}", r.name(), p);
})
.context(DatafusionSnafu {
context: "Fail to apply analyzer",
})?;
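
The only behavioral change in this first hunk is that the analyzer's observer closure, previously an empty `|_p, _r| {}`, now logs the intermediate plan after each analyzer rule runs. As a rough illustration of the pattern (a minimal standalone sketch with made-up names, not DataFusion's actual types), an observer callback receives the plan produced by each rule together with the rule itself:

// Minimal sketch of the observer-callback pattern (illustrative names only):
// each rule rewrites the plan, and the caller-supplied closure is invoked
// with the intermediate result plus the rule that produced it, which is what
// the new `debug!` closure above uses for per-rule plan logging.
trait AnalyzerRuleLike {
    fn name(&self) -> &str;
    fn apply(&self, plan: String) -> String;
}

fn execute_and_observe(
    mut plan: String,
    rules: &[Box<dyn AnalyzerRuleLike>],
    mut observer: impl FnMut(&str, &dyn AnalyzerRuleLike),
) -> String {
    for rule in rules {
        plan = rule.apply(plan);
        // The old empty closure discarded this information; the new one logs it.
        observer(plan.as_str(), rule.as_ref());
    }
    plan
}
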
@@ -360,36 +363,39 @@ fn expand_tumble_analyzer(
if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
let mut new_group_expr = vec![];
let mut alias_to_expand = HashMap::new();
let mut encountered_tumble = false;
for expr in aggr.group_expr.iter() {
match expr {
datafusion_expr::Expr::ScalarFunction(func) => {
if func.name() == "tumble" {
let tumble_start = TumbleExpand::new("tumble_start");
let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf(
Arc::new(tumble_start.into()),
func.args.clone(),
);
let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start);
let start_col_name = tumble_start.name_for_alias()?;
new_group_expr.push(tumble_start);

let tumble_end = TumbleExpand::new("tumble_end");
let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf(
Arc::new(tumble_end.into()),
func.args.clone(),
);
let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end);
let end_col_name = tumble_end.name_for_alias()?;
new_group_expr.push(tumble_end);

alias_to_expand
.insert(expr.name_for_alias()?, (start_col_name, end_col_name));
}
datafusion_expr::Expr::ScalarFunction(func) if func.name() == "tumble" => {
encountered_tumble = true;

let tumble_start = TumbleExpand::new("tumble_start");
let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf(
Arc::new(tumble_start.into()),
func.args.clone(),
);
let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start);
let start_col_name = tumble_start.name_for_alias()?;
new_group_expr.push(tumble_start);

let tumble_end = TumbleExpand::new("tumble_end");
let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf(
Arc::new(tumble_end.into()),
func.args.clone(),
);
let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end);
let end_col_name = tumble_end.name_for_alias()?;
new_group_expr.push(tumble_end);

alias_to_expand
.insert(expr.name_for_alias()?, (start_col_name, end_col_name));
}
_ => new_group_expr.push(expr.clone()),
}
}

if !encountered_tumble {
return Ok(Transformed::no(plan));
}
let mut new_aggr = aggr.clone();
new_aggr.group_expr = new_group_expr;
let new_aggr = datafusion_expr::LogicalPlan::Aggregate(new_aggr).recompute_schema()?;
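
The substance of the fix is in the hunk above: the old arm matched every `ScalarFunction` and only handled the `tumble` case inside a nested `if`, so any other scalar function in the GROUP BY was silently dropped from `new_group_expr` (the "lose group expr" of the commit title). The rewrite uses a match guard so non-tumble expressions fall through to the catch-all arm, and the new `encountered_tumble` flag returns the plan untouched when no `tumble` call is present. A minimal standalone sketch of the difference between the two shapes (toy function, not the project's code):

// Toy illustration of the bug pattern: a catch-all match arm with a nested
// `if` drops the values the `if` does not handle, while a match guard lets
// them reach the fallback arm instead.
fn double_evens_buggy(nums: &[i32]) -> Vec<i32> {
    let mut out = Vec::new();
    for &n in nums {
        match n {
            // This arm matches *every* number; odd ones hit the arm but not
            // the `if`, so they are silently lost (old `expand_tumble_analyzer`).
            n => {
                if n % 2 == 0 {
                    out.push(n * 2);
                }
            }
        }
    }
    out
}

fn double_evens_fixed(nums: &[i32]) -> Vec<i32> {
    let mut out = Vec::new();
    for &n in nums {
        match n {
            // Match guard: only even numbers take this arm (new code)...
            n if n % 2 == 0 => out.push(n * 2),
            // ...everything else falls through and is preserved.
            other => out.push(other),
        }
    }
    out
}

With input `[1, 2, 3]`, the buggy version returns `[4]` while the fixed one returns `[1, 4, 3]` -- the analogue of non-tumble group expressions surviving the rewrite.
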
135 changes: 135 additions & 0 deletions src/flow/src/transform/aggr.rs
@@ -1493,4 +1493,139 @@ mod test {
};
assert_eq!(flow_plan.unwrap(), expected);
}

#[tokio::test]
async fn test_cast_max_min() {
let engine = create_test_query_engine();
let sql = "SELECT CAST((max(number) - min(number)) AS FLOAT)/30.0, date_bin(INTERVAL '30 second', CAST(ts AS TimestampMillisecond)) as time_window from numbers_with_ts GROUP BY time_window";
let plan = sql_to_substrait(engine.clone(), sql).await;

let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

let aggr_exprs = vec![
AggregateExpr {
func: AggregateFunc::MaxUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
},
AggregateExpr {
func: AggregateFunc::MinUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
},
];
let expected = TypedPlan {
schema: RelationType::new(vec![
ColumnType::new(CDT::float64_datatype(), true),
ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
])
.with_key(vec![1])
.into_named(vec![
Some(
"MAX(numbers_with_ts.number) - MIN(numbers_with_ts.number) / Float64(30)"
.to_string(),
),
Some("time_window".to_string()),
]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(
RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
ColumnType::new(ConcreteDataType::datetime_datatype(), false),
])
.into_named(vec![
Some("number".to_string()),
Some("ts".to_string()),
]),
)
.mfp(MapFilterProject::new(2).into_safe())
.unwrap(),
),

key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(2)
.map(vec![ScalarExpr::CallDf {
df_scalar_fn: DfScalarFunction::try_from_raw_fn(
RawDfScalarFn {
f: BytesMut::from(
b"\x08\x02\"I\x1aG\nE\x8a\x02?\x08\x03\x12+\n\x17interval-month-day-nano\x12\x10\0\xac#\xfc\x06\0\0\0\0\0\0\0\0\0\0\0\x1a\x06\x12\x04:\x02\x10\x02\x1a\x06\x12\x04:\x02\x10\x02\x98\x03\x03\"\n\x1a\x08\x12\x06\n\x04\x12\x02\x08\x01".as_ref(),
),
input_schema: RelationType::new(vec![ColumnType::new(
ConcreteDataType::interval_month_day_nano_datatype(),
true,
),ColumnType::new(
ConcreteDataType::timestamp_millisecond_datatype(),
true,
)])
.into_unnamed(),
extensions: FunctionExtensions {
anchor_to_name: BTreeMap::from([
(0, "subtract".to_string()),
(1, "divide".to_string()),
(2, "date_bin".to_string()),
(3, "max".to_string()),
(4, "min".to_string()),
]),
},
},
)
.await
.unwrap(),
exprs: vec![ScalarExpr::Literal(Value::Interval(Interval::from_month_day_nano(0, 0, 30000000000)), CDT::interval_month_day_nano_datatype()),
ScalarExpr::Column(1).cast(CDT::timestamp_millisecond_datatype())],
}])
.unwrap()
.project(vec![2])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(2)
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: aggr_exprs.clone(),
simple_aggrs: vec![AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1)],
distinct_aggrs: vec![],
}),
}
.with_types(
RelationType::new(vec![
ColumnType::new(
ConcreteDataType::timestamp_millisecond_datatype(),
true,
), // time_window
ColumnType::new(ConcreteDataType::uint32_datatype(), true), // max
ColumnType::new(ConcreteDataType::uint32_datatype(), true), // min
])
.with_key(vec![0])
.into_unnamed(),
),
),
mfp: MapFilterProject::new(3)
.map(vec![
ScalarExpr::Column(1)
.call_binary(ScalarExpr::Column(2), BinaryFunc::SubUInt32)
.cast(CDT::float32_datatype())
.cast(CDT::float64_datatype())
.call_binary(
ScalarExpr::Literal(Value::from(30.0f64), CDT::float64_datatype()),
BinaryFunc::DivFloat64,
),
ScalarExpr::Column(0),
])
.unwrap()
.project(vec![3, 4])
.unwrap(),
},
};

assert_eq!(flow_plan.unwrap(), expected);
}
}
