Skip to content

Commit

Permalink
feat: support COUNT() (#11229)
Browse files Browse the repository at this point in the history
* feat: add count empty rewrite

* feat: make count support zero args

* docs: add apache license

* tests: make count() valid

* tests: more tests

* refactor: sketch `AggregateFunctionPlanner`

* refactor: cleanup `AggregateFunctionPlanner`

* feat: add back rule

* Revert "feat: add back rule"

This reverts commit 2c4fc0a.

* Revert "refactor: cleanup `AggregateFunctionPlanner`"

This reverts commit 4550dbd.

* Revert "refactor: sketch `AggregateFunctionPlanner`"

This reverts commit 658671e.

* Apply suggestions from code review

Co-authored-by: Andrew Lamb <[email protected]>

* refactor: PR feedback

* style: fix indent

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
tshauck and alamb authored Jul 17, 2024
1 parent c95556d commit 0021356
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 9 deletions.
8 changes: 6 additions & 2 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
EmitTo, GroupsAccumulator, Signature, Volatility,
};
use datafusion_expr::{Expr, ReversedUDAF};
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::{
aggregate::count_distinct::{
Expand Down Expand Up @@ -95,7 +95,11 @@ impl Default for Count {
impl Count {
/// Creates a `Count` whose `signature` uses `Volatility::Immutable`.
///
/// NOTE(review): this listing is a diff rendered without +/- markers, so both
/// the `variadic_any` initializer and the `one_of` initializer appear below.
/// Presumably only the `one_of` form exists after the commit — confirm against
/// the repository.
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::one_of(
// TypeSignature::Any(0) is required to handle `Count()` with no args
// The signature therefore lists two alternatives: the variadic form and
// the zero-argument form.
vec![TypeSignature::VariadicAny, TypeSignature::Any(0)],
Volatility::Immutable,
),
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct CountWildcardRule {}

impl CountWildcardRule {
/// Creates a new `CountWildcardRule`; the struct literal takes no fields.
///
/// NOTE(review): this listing is a diff rendered without +/- markers; the
/// `CountWildcardRule {}` line appears to be the pre-change spelling and
/// `Self {}` its replacement — confirm against the repository.
pub fn new() -> Self {
CountWildcardRule {}
Self {}
}
}

Expand All @@ -59,14 +59,14 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
..
} if udf.name() == "count" && args.len() == 1 && is_wildcard(&args[0]))
} if udf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
}

/// Returns `true` when `window_function` is an aggregate UDF named `"count"`
/// that is called either with exactly one argument accepted by `is_wildcard`
/// or with no arguments at all.
///
/// NOTE(review): this listing is a diff rendered without +/- markers; the first
/// `if` guard below appears to be the pre-change line and the second its
/// replacement that additionally accepts empty `args` — confirm against the
/// repository.
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
let args = &window_function.args;
matches!(window_function.fun,
WindowFunctionDefinition::AggregateUDF(ref udaf)
if udaf.name() == "count" && args.len() == 1 && is_wildcard(&args[0]))
// Inside the parentheses `&&` binds tighter than `||`: the guard accepts
// (one argument that is a wildcard) or (no arguments).
if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
Expand Down
105 changes: 105 additions & 0 deletions datafusion/sqllogictest/test_files/count_star_rule.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Setup: t1 receives three rows (two with a = 1, one with a = 2).
statement ok
CREATE TABLE t1 (a INTEGER, b INTEGER, c INTEGER);

statement ok
INSERT INTO t1 VALUES
(1, 2, 3),
(1, 5, 6),
(2, 3, 5);

# t2 is created without any rows; no query in this file reads it, and it is
# only dropped again at the end.
statement ok
CREATE TABLE t2 (a INTEGER, b INTEGER, c INTEGER);

# COUNT() with no arguments over a one-row subquery: the logical plan shows the
# aggregate as count(Int64(1)) aliased back to count(), and the physical plan
# is a projection of the literal 1 over a placeholder row.
query TT
EXPLAIN SELECT COUNT() FROM (SELECT 1 AS a, 2 AS b) AS t;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count()]]
02)--SubqueryAlias: t
03)----EmptyRelation
physical_plan
01)ProjectionExec: expr=[1 as count()]
02)--PlaceholderRowExec

# COUNT() with no arguments combined with GROUP BY: the logical plan again
# shows count(Int64(1)) aliased to count(), and the table scan projects only
# column a.
query TT
EXPLAIN SELECT t1.a, COUNT() FROM t1 GROUP BY t1.a;
----
logical_plan
01)Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]]
02)--TableScan: t1 projection=[a]
physical_plan
01)AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()]
02)--CoalesceBatchesExec: target_batch_size=8192
03)----RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()]
06)----------MemoryExec: partitions=1, partition_sizes=[1]

# COUNT() used both in the select list and in HAVING: the plan contains a
# single count(Int64(1)) aggregate that the filter and the projection both
# reference as count().
query TT
EXPLAIN SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 0;
----
logical_plan
01)Projection: t1.a, count() AS cnt
02)--Filter: count() > Int64(0)
03)----Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]]
04)------TableScan: t1 projection=[a]
physical_plan
01)ProjectionExec: expr=[a@0 as a, count()@1 as cnt]
02)--CoalesceBatchesExec: target_batch_size=8192
03)----FilterExec: count()@1 > 0
04)------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()]
05)--------CoalesceBatchesExec: target_batch_size=8192
06)----------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
08)--------------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()]
09)----------------MemoryExec: partitions=1, partition_sizes=[1]

# Result check: of the rows inserted into t1, only a = 1 occurs more than once,
# so HAVING COUNT() > 1 keeps that single group with a count of 2.
query II
SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 1;
----
1 2

# COUNT() with no arguments as a window function: the WindowAggr node shows
# count(Int64(1)) aliased back to the count() window expression.
query TT
EXPLAIN SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1;
----
logical_plan
01)Projection: t1.a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count_a
02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
03)----TableScan: t1 projection=[a]
physical_plan
01)ProjectionExec: expr=[a@0 as a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as count_a]
02)--WindowAggExec: wdw=[count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]
04)------CoalesceBatchesExec: target_batch_size=8192
05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
06)----------MemoryExec: partitions=1, partition_sizes=[1]

# Result check: every row is returned with the size of its partition
# (two rows for a = 1, one row for a = 2); ORDER BY a fixes the output order.
query II
SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1 ORDER BY a;
----
1 2
1 2
2 1

# Cleanup: drop both tables created at the top of this file.
statement ok
DROP TABLE t1;

statement ok
DROP TABLE t2;
4 changes: 0 additions & 4 deletions datafusion/sqllogictest/test_files/errors.slt
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,6 @@ SELECT power(1, 2, 3);
# Wrong window/aggregate function signature
#

# AggregateFunction with wrong number of arguments
query error
select count();

# AggregateFunction with wrong number of arguments
query error
select avg(c1, c12) from aggregate_test_100;
Expand Down

0 comments on commit 0021356

Please sign in to comment.