Remove redundant Aggregate when DISTINCT & GROUP BY are in the same query (#11781)
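
For context, the redundancy this commit eliminates is visible in the sqllogictest plans changed by this very diff. For EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 GROUP BY c3 LIMIT 5, the logical plan goes from two stacked identical aggregates to one:

Before:
01)Limit: skip=0, fetch=5
02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
04)------TableScan: aggregate_test_100 projection=[c3]

After:
01)Limit: skip=0, fetch=5
02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
03)----TableScan: aggregate_test_100 projection=[c3]

Since DISTINCT c3 and GROUP BY c3 produce exactly the same grouping, the outer Aggregate is a no-op and can be removed.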

* Delete docs.yaml

* initialize eliminate_aggregate.rs rule

* remove redundant prints

* Add multiple group by expression handling.

* rename eliminate_aggregate.rs to eliminate_distinct.rs

implement as rewrite function

* remove logic for DISTINCT ON, since the GROUP BY expressions must already exist in the projection

* format code

* add eliminate_distinct rule to tests

* simplify function
add additional tests for cases where the distinct is not removed

* fix child issue

* format

* fix docs

* remove eliminate_distinct rule and make it a part of replace_distinct_aggregate

* Update datafusion/common/src/functional_dependencies.rs

Co-authored-by: Mehmet Ozan Kabak <[email protected]>

* add comment and fix variable call

* fix test cases to match the optimized plan

* format code

* simplify comments

Co-authored-by: Mehmet Ozan Kabak <[email protected]>

* do not replace redundant distincts with aggregate

---------

Co-authored-by: metesynnada <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
Co-authored-by: Mert Akkaya <[email protected]>
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
6 people authored Aug 4, 2024
1 parent a4d41d6 commit c8e5996
Showing 4 changed files with 119 additions and 30 deletions.
39 changes: 24 additions & 15 deletions datafusion/common/src/functional_dependencies.rs
@@ -524,22 +524,31 @@ pub fn aggregate_functional_dependencies(
         }
     }
 
-    // If we have a single GROUP BY key, we can guarantee uniqueness after
+    // When we have a GROUP BY key, we can guarantee uniqueness after
     // aggregation:
-    if group_by_expr_names.len() == 1 {
-        // If `source_indices` contains 0, delete this functional dependency,
-        // as it will be added anyway with mode `Dependency::Single`:
-        aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0));
-        // Add a new functional dependency associated with the whole table:
-        aggregate_func_dependencies.push(
-            // Use the nullable property of the GROUP BY expression:
-            FunctionalDependence::new(
-                vec![0],
-                target_indices,
-                aggr_fields[0].is_nullable(),
-            )
-            .with_mode(Dependency::Single),
-        );
+    if !group_by_expr_names.is_empty() {
+        let count = group_by_expr_names.len();
+        let source_indices = (0..count).collect::<Vec<_>>();
+        let nullable = source_indices
+            .iter()
+            .any(|idx| aggr_fields[*idx].is_nullable());
+        // If the GROUP BY expressions do not already act as a determinant:
+        if !aggregate_func_dependencies.iter().any(|item| {
+            // If `item.source_indices` is a subset of the GROUP BY
+            // expressions, we shouldn't add a new dependency, since
+            // `item.source_indices` defines this relation already.
+
+            // The following simple comparison works because GROUP BY
+            // expressions come here as a prefix.
+            item.source_indices.iter().all(|idx| idx < &count)
+        }) {
+            // Add a new functional dependency associated with the whole table:
+            aggregate_func_dependencies.push(
+                // Use the nullable property of the GROUP BY expressions:
+                FunctionalDependence::new(source_indices, target_indices, nullable)
+                    .with_mode(Dependency::Single),
+            );
+        }
     }
     FunctionalDependencies::new(aggregate_func_dependencies)
 }
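
To make the new determinant check concrete, here is a minimal standalone sketch of the same prefix-subset test. It is not DataFusion code: group_by_already_determined and existing_sources are hypothetical stand-ins for the iteration over aggregate_func_dependencies and each item.source_indices, and count plays the same role as in the hunk above.

// GROUP BY expressions occupy output indices 0..count (they arrive as a
// prefix), so any dependency whose source indices all fall below `count`
// is a subset of the GROUP BY columns and already determines the table.
fn group_by_already_determined(existing_sources: &[Vec<usize>], count: usize) -> bool {
    existing_sources
        .iter()
        .any(|sources| sources.iter().all(|idx| *idx < count))
}

fn main() {
    // A dependency on column 0 alone is covered by a GROUP BY on columns 0 and 1:
    assert!(group_by_already_determined(&[vec![0]], 2));
    // A dependency that also needs column 2 is not:
    assert!(!group_by_already_determined(&[vec![0, 2]], 2));
}

When the check returns true, adding another whole-table dependency would be redundant, which is exactly why the hunk skips the push in that case.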
90 changes: 90 additions & 0 deletions datafusion/optimizer/src/replace_distinct_aggregate.rs
@@ -77,6 +77,21 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
         match plan {
             LogicalPlan::Distinct(Distinct::All(input)) => {
                 let group_expr = expand_wildcard(input.schema(), &input, None)?;
+
+                let field_count = input.schema().fields().len();
+                for dep in input.schema().functional_dependencies().iter() {
+                    // If the DISTINCT is exactly the same as a previous
+                    // GROUP BY, we can simply remove it:
+                    if dep.source_indices[..field_count]
+                        .iter()
+                        .enumerate()
+                        .all(|(idx, f_idx)| idx == *f_idx)
+                    {
+                        return Ok(Transformed::yes(input.as_ref().clone()));
+                    }
+                }
+
+                // Replace with aggregation:
                 let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
                     input,
                     group_expr,
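
The loop above is the heart of the fix: DISTINCT over the input is redundant when some functional dependency's source indices start with the identity prefix 0, 1, ..., field_count - 1. Below is a standalone sketch of that test, assuming plain slices stand in for the schema types; distinct_is_redundant is a hypothetical name, and a length guard is added here to keep the sketch total, whereas the hunk indexes the prefix directly.

// The output columns are already unique when the dependency's sources
// begin with the identity prefix over all `field_count` output columns.
fn distinct_is_redundant(source_indices: &[usize], field_count: usize) -> bool {
    source_indices.len() >= field_count
        && source_indices[..field_count]
            .iter()
            .enumerate()
            .all(|(idx, f_idx)| idx == *f_idx)
}

fn main() {
    // A GROUP BY on both output columns makes every row distinct already:
    assert!(distinct_is_redundant(&[0, 1], 2));
    // A dependency starting at column 1 says nothing about column 0:
    assert!(!distinct_is_redundant(&[1, 2], 2));
}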
@@ -165,3 +180,78 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
         Some(BottomUp)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
+    use crate::test::*;
+
+    use datafusion_common::Result;
+    use datafusion_expr::{
+        col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
+    };
+    use datafusion_functions_aggregate::sum::sum;
+
+    fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
+        assert_optimized_plan_eq(
+            Arc::new(ReplaceDistinctWithAggregate::new()),
+            plan.clone(),
+            expected,
+        )
+    }
+
+    #[test]
+    fn eliminate_redundant_distinct_simple() -> Result<()> {
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("c")], Vec::<Expr>::new())?
+            .project(vec![col("c")])?
+            .distinct()?
+            .build()?;
+
+        let expected = "Projection: test.c\n  Aggregate: groupBy=[[test.c]], aggr=[[]]\n    TableScan: test";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn eliminate_redundant_distinct_pair() -> Result<()> {
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("a"), col("b")], Vec::<Expr>::new())?
+            .project(vec![col("a"), col("b")])?
+            .distinct()?
+            .build()?;
+
+        let expected =
+            "Projection: test.a, test.b\n  Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n    TableScan: test";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn do_not_eliminate_distinct() -> Result<()> {
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b")])?
+            .distinct()?
+            .build()?;
+
+        let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n  Projection: test.a, test.b\n    TableScan: test";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("a"), col("b"), col("c")], vec![sum(col("c"))])?
+            .project(vec![col("a"), col("b")])?
+            .distinct()?
+            .build()?;
+
+        let expected =
+            "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n  Projection: test.a, test.b\n    Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n      TableScan: test";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+}
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -39,7 +39,7 @@ use hashbrown::HashSet;
 /// single distinct to group by optimizer rule
 /// ```text
 /// Before:
-/// SELECT a, count(DINSTINCT b), sum(c)
+/// SELECT a, count(DISTINCT b), sum(c)
 /// FROM t
 /// GROUP BY a
 ///
18 changes: 4 additions & 14 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -4536,19 +4536,14 @@ EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
 logical_plan
 01)Limit: skip=0, fetch=5
 02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
-03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
-04)------TableScan: aggregate_test_100 projection=[c3]
+03)----TableScan: aggregate_test_100 projection=[c3]
 physical_plan
 01)GlobalLimitExec: skip=0, fetch=5
 02)--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5]
 03)----CoalescePartitionsExec
 04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5]
 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-06)----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5]
-07)------------CoalescePartitionsExec
-08)--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5]
-09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
+06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
 
 query I
 SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
@@ -4699,19 +4694,14 @@ EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
 logical_plan
 01)Limit: skip=0, fetch=5
 02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
-03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
-04)------TableScan: aggregate_test_100 projection=[c3]
+03)----TableScan: aggregate_test_100 projection=[c3]
 physical_plan
 01)GlobalLimitExec: skip=0, fetch=5
 02)--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[]
 03)----CoalescePartitionsExec
 04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[]
 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-06)----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[]
-07)------------CoalescePartitionsExec
-08)--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[]
-09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
+06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
 
 statement ok
 set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true;
