Skip to content

Commit

Permalink
Make Logical Plans more readable by removing extra aliases (apache#10832
Browse files Browse the repository at this point in the history
)

* logical plan: remove unnecessary aliases

* revert EnterMark

* fix docs and benchmarks

* revert id_array change

* add alias counter

* fix alias counter bug

* fix slt test

* fix benchmark results

* revert alias/unalias changes

* remove TODO

* minor fix

* fix benchmark
  • Loading branch information
MohamedAbdeen21 authored and findepi committed Jul 16, 2024
1 parent 12941e4 commit df015d6
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 30 deletions.
73 changes: 54 additions & 19 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl CommonSubexprEliminate {
fn rewrite_exprs_list(
&self,
exprs_list: &[&[Expr]],
arrays_list: &[&[Vec<(usize, String)>]],
arrays_list: &[&[IdArray]],
expr_stats: &ExprStats,
common_exprs: &mut CommonExprs,
) -> Result<Vec<Vec<Expr>>> {
Expand Down Expand Up @@ -159,7 +159,7 @@ impl CommonSubexprEliminate {
fn rewrite_expr(
&self,
exprs_list: &[&[Expr]],
arrays_list: &[&[Vec<(usize, String)>]],
arrays_list: &[&[IdArray]],
input: &LogicalPlan,
expr_stats: &ExprStats,
config: &dyn OptimizerConfig,
Expand Down Expand Up @@ -480,7 +480,7 @@ fn to_arrays(
input_schema: DFSchemaRef,
expr_stats: &mut ExprStats,
expr_mask: ExprMask,
) -> Result<Vec<Vec<(usize, String)>>> {
) -> Result<Vec<IdArray>> {
expr.iter()
.map(|e| {
let mut id_array = vec![];
Expand Down Expand Up @@ -739,7 +739,7 @@ fn expr_identifier(expr: &Expr, sub_expr_identifier: Identifier) -> Identifier {
fn expr_to_identifier(
expr: &Expr,
expr_stats: &mut ExprStats,
id_array: &mut Vec<(usize, Identifier)>,
id_array: &mut IdArray,
input_schema: DFSchemaRef,
expr_mask: ExprMask,
) -> Result<()> {
Expand Down Expand Up @@ -769,15 +769,28 @@ struct CommonSubexprRewriter<'a> {
common_exprs: &'a mut CommonExprs,
// preorder index, starts from 0.
down_index: usize,
// how many aliases have we seen so far
alias_counter: usize,
}

impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
type Node = Expr;

fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
if matches!(expr, Expr::Alias(_)) {
self.alias_counter -= 1
}
Ok(Transformed::no(expr))
}

fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
if matches!(expr, Expr::Alias(_)) {
self.alias_counter += 1;
}

if expr.short_circuits() || expr.is_volatile()? {
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}
Expand All @@ -801,15 +814,16 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {

let expr_name = expr.display_name()?;
self.common_exprs.insert(expr_id.clone(), expr);
// Alias this `Column` expr to it original "expr name",
// `projection_push_down` optimizer use "expr name" to eliminate useless
// projections.
// TODO: do we really need to alias here?
Ok(Transformed::new(
col(expr_id).alias(expr_name),
true,
TreeNodeRecursion::Jump,
))

// alias the expressions without an `Alias` ancestor node
let rewritten = if self.alias_counter > 0 {
col(expr_id)
} else {
self.alias_counter += 1;
col(expr_id).alias(expr_name)
};

Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump))
} else {
Ok(Transformed::no(expr))
}
Expand All @@ -829,6 +843,7 @@ fn replace_common_expr(
id_array,
common_exprs,
down_index: 0,
alias_counter: 0,
})
.data()
}
Expand Down Expand Up @@ -962,6 +977,26 @@ mod test {
Ok(())
}

#[test]
fn nested_aliases() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
(col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
col("a") + col("b"),
])?
.build()?;

let expected = "Projection: {test.a + test.b|{test.b}|{test.a}} - test.c AS alias1 * {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b, {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b\
\n Projection: test.a + test.b AS {test.a + test.b|{test.b}|{test.a}}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);

Ok(())
}

#[test]
fn aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down Expand Up @@ -1006,7 +1041,7 @@ mod test {
)?
.build()?;

let expected = "Projection: {AVG(test.a)|{test.a}} AS AVG(test.a) AS col1, {AVG(test.a)|{test.a}} AS AVG(test.a) AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col4, {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\
let expected = "Projection: {AVG(test.a)|{test.a}} AS col1, {AVG(test.a)|{test.a}} AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS col4, {my_agg(test.a)|{test.a}} AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}, AVG(test.b) AS col3, AVG(test.c) AS {AVG(test.c)}, my_agg(test.b) AS col6, my_agg(test.c) AS {my_agg(test.c)}]]\
\n TableScan: test";

Expand Down Expand Up @@ -1042,7 +1077,7 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test";
let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);

Expand All @@ -1057,7 +1092,7 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\
let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\
\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
\n TableScan: test";

Expand All @@ -1078,8 +1113,8 @@ mod test {
)?
.build()?;

let expected = "Projection: UInt32(1) + test.a, UInt32(1) + {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) AS col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a) AS col3, UInt32(1) - {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a) AS col4, {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a)\
\n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}]]\
let expected = "Projection: UInt32(1) + test.a, UInt32(1) + {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col3, UInt32(1) - {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col4, {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS my_agg(UInt32(1) + test.a)\
\n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)}]]\
\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
\n TableScan: test";

Expand Down Expand Up @@ -1126,7 +1161,7 @@ mod test {
])?
.build()?;

let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS second\
let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS second\
\n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\
\n TableScan: test";

Expand Down
18 changes: 9 additions & 9 deletions datafusion/sqllogictest/test_files/group_by.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4538,7 +4538,7 @@ CREATE EXTERNAL TABLE timestamp_table (
c2 INT,
)
STORED AS CSV
LOCATION 'test_files/scratch/group_by/timestamp_table'
LOCATION 'test_files/scratch/group_by/timestamp_table'
OPTIONS ('format.has_header' 'true');

# Group By using date_trunc
Expand Down Expand Up @@ -4611,7 +4611,7 @@ DROP TABLE timestamp_table;

# Table with an int column and Dict<Int8> column:
statement ok
CREATE TABLE int8_dict AS VALUES
CREATE TABLE int8_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int8, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int8, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int8, Utf8)')),
Expand Down Expand Up @@ -4649,7 +4649,7 @@ DROP TABLE int8_dict;

# Table with an int column and Dict<Int16> column:
statement ok
CREATE TABLE int16_dict AS VALUES
CREATE TABLE int16_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int16, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int16, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int16, Utf8)')),
Expand Down Expand Up @@ -4687,7 +4687,7 @@ DROP TABLE int16_dict;

# Table with an int column and Dict<Int32> column:
statement ok
CREATE TABLE int32_dict AS VALUES
CREATE TABLE int32_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int32, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int32, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int32, Utf8)')),
Expand Down Expand Up @@ -4725,7 +4725,7 @@ DROP TABLE int32_dict;

# Table with an int column and Dict<Int64> column:
statement ok
CREATE TABLE int64_dict AS VALUES
CREATE TABLE int64_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int64, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int64, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int64, Utf8)')),
Expand Down Expand Up @@ -4763,7 +4763,7 @@ DROP TABLE int64_dict;

# Table with an int column and Dict<UInt8> column:
statement ok
CREATE TABLE uint8_dict AS VALUES
CREATE TABLE uint8_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt8, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt8, Utf8)')),
Expand Down Expand Up @@ -4801,7 +4801,7 @@ DROP TABLE uint8_dict;

# Table with an int column and Dict<UInt16> column:
statement ok
CREATE TABLE uint16_dict AS VALUES
CREATE TABLE uint16_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt16, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt16, Utf8)')),
Expand Down Expand Up @@ -4839,7 +4839,7 @@ DROP TABLE uint16_dict;

# Table with an int column and Dict<UInt32> column:
statement ok
CREATE TABLE uint32_dict AS VALUES
CREATE TABLE uint32_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt32, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt32, Utf8)')),
Expand Down Expand Up @@ -4877,7 +4877,7 @@ DROP TABLE uint32_dict;

# Table with an int column and Dict<UInt64> column:
statement ok
CREATE TABLE uint64_dict AS VALUES
CREATE TABLE uint64_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt64, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt64, Utf8)')),
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/tpch/q1.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ explain select
logical_plan
01)Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST
02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order
03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]]
03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]]
04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS {lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus
05)--------Filter: lineitem.l_shipdate <= Date32("1998-09-02")
06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")]
Expand Down Expand Up @@ -80,7 +80,7 @@ group by
l_linestatus
order by
l_returnflag,
l_linestatus;
l_linestatus;
----
A F 3774200 5320753880.69 5054096266.6828 5256751331.449234 25.537587 36002.123829 0.050144 147790
N F 95257 133737795.84 127132372.6512 132286291.229445 25.300664 35521.326916 0.049394 3765
Expand Down

0 comments on commit df015d6

Please sign in to comment.