Skip to content

Commit

Permalink
apply timestamp simplify rule before type coercion
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Nov 14, 2024
1 parent 7bd2eb7 commit 2b49890
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 35 deletions.
43 changes: 13 additions & 30 deletions wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
* under the License.
*/
use datafusion::arrow::datatypes::{DataType, TimeUnit};
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion::common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
};
use datafusion::common::ScalarValue::{
TimestampMicrosecond, TimestampMillisecond, TimestampSecond,
};
use datafusion::common::{DFSchema, DFSchemaRef, Result, ScalarValue};
use datafusion::config::ConfigOptions;
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_expr::expr_rewriter::NamePreserver;
use datafusion::logical_expr::simplify::SimplifyContext;
use datafusion::logical_expr::utils::merge_schema;
use datafusion::logical_expr::{cast, Cast, LogicalPlan, TryCast};
use datafusion::optimizer::optimizer::ApplyOrder;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::optimizer::{OptimizerConfig, OptimizerRule};
use datafusion::optimizer::AnalyzerRule;
use datafusion::prelude::Expr;
use datafusion::scalar::ScalarValue::TimestampNanosecond;
use std::sync::Arc;
Expand All @@ -46,37 +48,18 @@ impl TimestampSimplify {
}
}

impl OptimizerRule for TimestampSimplify {
fn name(&self) -> &str {
"simplify_cast_expressions"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::BottomUp)
impl AnalyzerRule for TimestampSimplify {
fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
Self::analyze_internal(plan).data()
}

fn supports_rewrite(&self) -> bool {
true
}

/// if supports_owned returns true, the Optimizer calls
/// [`Self::rewrite`] instead of [`Self::try_optimize`]
fn rewrite(
&self,
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let mut execution_props = ExecutionProps::new();
execution_props.query_execution_start_time = config.query_execution_start_time();
Self::optimize_internal(plan, &execution_props)
fn name(&self) -> &str {
"simplify_timestamp_expressions"
}
}

impl TimestampSimplify {
fn optimize_internal(
plan: LogicalPlan,
execution_props: &ExecutionProps,
) -> Result<Transformed<LogicalPlan>> {
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let schema = if !plan.inputs().is_empty() {
DFSchemaRef::new(merge_schema(&plan.inputs()))
} else if let LogicalPlan::TableScan(scan) = &plan {
Expand All @@ -97,8 +80,8 @@ impl TimestampSimplify {
} else {
Arc::new(DFSchema::empty())
};

let info = SimplifyContext::new(execution_props).with_schema(schema);
let execution_props = ExecutionProps::default();
let info = SimplifyContext::new(&execution_props).with_schema(schema);

// Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer)
// Just need to rewrite our own expressions
Expand Down
4 changes: 3 additions & 1 deletion wren-core/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ fn analyze_rule_for_unparsing(
Arc::new(InlineTableScan::new()),
// Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule].
Arc::new(ExpandWildcardRule::new()),
// TimestampSimplify should be placed before TypeCoercion because the simplified timestamp should
// be casted to the target type if needed
Arc::new(TimestampSimplify::new()),
// [Expr::Wildcard] should be expanded before [TypeCoercion]
Arc::new(TypeCoercion::new()),
// Disable it to avoid generate the alias name, `count(*)` because BigQuery doesn't allow
Expand Down Expand Up @@ -180,7 +183,6 @@ fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
Arc::new(SingleDistinctToGroupBy::new()),
// Disable SimplifyExpressions to avoid apply some function locally
// Arc::new(SimplifyExpressions::new()),
Arc::new(TimestampSimplify::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateGroupByConstant::new()),
Expand Down
50 changes: 46 additions & 4 deletions wren-core/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ mod test {
.build(),
)
.column(
ColumnBuilder::new("cast_timestamp", "timestamp")
ColumnBuilder::new("cast_timestamptz", "timestamptz")
.expression(r#"cast("出道時間" as timestamp with time zone)"#)
.build(),
)
Expand All @@ -976,7 +976,7 @@ mod test {
.build();

let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamp as timestamp) > timestamp '2011-01-01 21:00:00'"#;
let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamptz as timestamp) > timestamp '2011-01-01 21:00:00'"#;
let actual = transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
Expand All @@ -985,8 +985,8 @@ mod test {
)
.await?;
assert_eq!(actual,
"SELECT count(*) FROM (SELECT artist.cast_timestamp FROM (SELECT CAST(artist.\"出道時間\" AS TIMESTAMP WITH TIME ZONE) AS cast_timestamp \
FROM artist) AS artist) AS artist WHERE artist.cast_timestamp > CAST('2011-01-01 21:00:00' AS TIMESTAMP)");
"SELECT count(*) FROM (SELECT artist.cast_timestamptz FROM (SELECT CAST(artist.\"出道時間\" AS TIMESTAMP WITH TIME ZONE) AS cast_timestamptz \
FROM artist) AS artist) AS artist WHERE CAST(artist.cast_timestamptz AS TIMESTAMP) > CAST('2011-01-01 21:00:00' AS TIMESTAMP)");
Ok(())
}

Expand Down Expand Up @@ -1071,6 +1071,48 @@ mod test {
(SELECT timestamp_table.timestamp_col, timestamp_table.timestamptz_col FROM \
(SELECT timestamp_table.timestamp_col AS timestamp_col, timestamp_table.timestamptz_col AS timestamptz_col \
FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table");

let sql = r#"select timestamptz_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#;
let actual = transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
&[],
sql,
)
.await?;
// assert the simplified literal will be casted to the timestamp tz
assert_eq!(actual,
"SELECT timestamp_table.timestamptz_col > CAST(CAST('2011-01-01 18:00:00' AS TIMESTAMP) AS TIMESTAMP WITH TIME ZONE) \
FROM (SELECT timestamp_table.timestamptz_col FROM (SELECT timestamp_table.timestamptz_col AS timestamptz_col \
FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table");

let sql = r#"select timestamptz_col > '2011-01-01 18:00:00' from wren.test.timestamp_table"#;
let actual = transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
&[],
sql,
)
.await?;
// assert the string literal will be casted to the timestamp tz
assert_eq!(actual,
"SELECT timestamp_table.timestamptz_col > CAST('2011-01-01 18:00:00' AS TIMESTAMP WITH TIME ZONE) \
FROM (SELECT timestamp_table.timestamptz_col FROM (SELECT timestamp_table.timestamptz_col AS timestamptz_col \
FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table");

let sql = r#"select timestamp_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#;
let actual = transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
&[],
sql,
)
.await?;
// assert the simplified literal won't be casted to the timestamp tz
assert_eq!(actual,
"SELECT timestamp_table.timestamp_col > CAST('2011-01-01 18:00:00' AS TIMESTAMP) FROM \
(SELECT timestamp_table.timestamp_col FROM (SELECT timestamp_table.timestamp_col AS timestamp_col \
FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table");
}
Ok(())
}
Expand Down

0 comments on commit 2b49890

Please sign in to comment.