From 9d12ebc4654b1f3c91612d40caab180334ee62a6 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Tue, 21 May 2024 11:56:34 -0500 Subject: [PATCH] refactor(sql): extract aggregate handling out into common utility class --- ibis/backends/bigquery/compiler.py | 8 --- ibis/backends/clickhouse/compiler.py | 21 +++--- ibis/backends/datafusion/compiler.py | 10 +-- ibis/backends/druid/compiler.py | 10 +-- ibis/backends/duckdb/compiler.py | 10 +-- ibis/backends/exasol/compiler.py | 6 -- ibis/backends/flink/compiler.py | 61 +++++++++--------- ibis/backends/impala/compiler.py | 6 -- ibis/backends/mssql/compiler.py | 6 -- ibis/backends/mysql/compiler.py | 6 -- ibis/backends/oracle/compiler.py | 6 -- ibis/backends/postgres/compiler.py | 10 +-- ibis/backends/pyspark/compiler.py | 6 -- ibis/backends/snowflake/compiler.py | 7 -- ibis/backends/sql/compiler.py | 95 +++++++++++++++++++++------- ibis/backends/sqlite/compiler.py | 16 ++--- ibis/backends/trino/compiler.py | 11 ++-- 17 files changed, 139 insertions(+), 156 deletions(-) diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index 2f4160ae05cc..cdff8d97435c 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -122,14 +122,6 @@ class BigQueryCompiler(SQLGlotCompiler): ops.TimestampNow: "current_timestamp", } - def _aggregate(self, funcname: str, *args, where): - func = self.f[funcname] - - if where is not None: - args = tuple(self.if_(where, arg, NULL) for arg in args) - - return func(*args, dialect=self.dialect) - @staticmethod def _minimize_spec(start, end, spec): if ( diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index 5f1f4b0b2fde..c778d2c5f7ee 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -11,17 +11,29 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util -from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler +from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import ClickHouseType from ibis.backends.sql.dialects import ClickHouse +class ClickhouseAggGen(AggGen): + def aggregate(self, compiler, name, *args, where=None): + # Clickhouse aggregate functions all have filtering variants with a + # `If` suffix (e.g. `SumIf` instead of `Sum`). + if where is not None: + name += "If" + args += (where,) + return compiler.f[name](*args, dialect=compiler.dialect) + + class ClickHouseCompiler(SQLGlotCompiler): __slots__ = () dialect = ClickHouse type_mapper = ClickHouseType + agg = ClickhouseAggGen() + UNSUPPORTED_OPS = ( ops.RowID, ops.CumeDist, @@ -104,13 +116,6 @@ class ClickHouseCompiler(SQLGlotCompiler): ops.Unnest: "arrayJoin", } - def _aggregate(self, funcname: str, *args, where): - has_filter = where is not None - func = self.f[funcname + "If" * has_filter] - args += (where,) * has_filter - - return func(*args, dialect=self.dialect) - @staticmethod def _minimize_spec(start, end, spec): if ( diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 0652706c3dab..c186a252edfd 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -11,7 +11,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler +from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DataFusionType from ibis.backends.sql.dialects import DataFusion from ibis.common.temporal import IntervalUnit, TimestampUnit @@ -25,6 +25,8 @@ class DataFusionCompiler(SQLGlotCompiler): dialect = DataFusion type_mapper = DataFusionType + agg = AggGen(supports_filter=True) + UNSUPPORTED_OPS = ( ops.ArgMax, ops.ArgMin, @@ -73,12 +75,6 @@ class DataFusionCompiler(SQLGlotCompiler): ops.ArrayUnion: "array_union", } - def _aggregate(self, funcname: str, *args, where): - expr = self.f[funcname](*args) - if where is not None: - return sg.exp.Filter(this=expr, expression=sg.exp.Where(this=where)) - return expr - def _to_timestamp(self, value, target_dtype, literal=False): tz = ( f'Some("{timezone}")' diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py index ab4eb2020e4b..1e217ff2f84d 100644 --- a/ibis/backends/druid/compiler.py +++ b/ibis/backends/druid/compiler.py @@ -6,7 +6,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import NULL, SQLGlotCompiler +from ibis.backends.sql.compiler import NULL, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DruidType from ibis.backends.sql.dialects import Druid @@ -17,6 +17,8 @@ class DruidCompiler(SQLGlotCompiler): dialect = Druid type_mapper = DruidType + agg = AggGen(supports_filter=True) + LOWERED_OPS = {ops.Capitalize: None} UNSUPPORTED_OPS = ( @@ -80,12 +82,6 @@ class DruidCompiler(SQLGlotCompiler): ops.StringContains: "contains_string", } - def _aggregate(self, funcname: str, *args, where): - expr = self.f[funcname](*args) - if where is not None: - return sg.exp.Filter(this=expr, expression=sg.exp.Where(this=where)) - return expr - def visit_Modulus(self, op, *, left, right): return self.f.anon.mod(left, right) diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index f5162844607e..f8294ab0dbf6 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -11,7 +11,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler +from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DuckDBType _INTERVAL_SUFFIXES = { @@ -33,6 +33,8 @@ class DuckDBCompiler(SQLGlotCompiler): dialect = DuckDB type_mapper = DuckDBType + agg = AggGen(supports_filter=True) + LOWERED_OPS = { ops.Sample: None, ops.StringSlice: None, @@ -85,12 +87,6 @@ class DuckDBCompiler(SQLGlotCompiler): ops.GeoY: "st_y", } - def _aggregate(self, funcname: str, *args, where): - expr = self.f[funcname](*args) - if where is not None: - return sge.Filter(this=expr, expression=sge.Where(this=where)) - return expr - def visit_StructColumn(self, op, *, names, values): return sge.Struct.from_arg_list( [ diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py index 0940c80a182e..7a837e41b9c8 100644 --- a/ibis/backends/exasol/compiler.py +++ b/ibis/backends/exasol/compiler.py @@ -103,12 +103,6 @@ def _minimize_spec(start, end, spec): return None return spec - def _aggregate(self, funcname: str, *args, where): - func = self.f[funcname] - if where is not None: - args = tuple(self.if_(where, arg, NULL) for arg in args) - return func(*args) - @staticmethod def _gen_valid_name(name: str) -> str: """Exasol does not allow dots in quoted column names.""" diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py index 211ce3c9882a..1f894487788e 100644 --- a/ibis/backends/flink/compiler.py +++ b/ibis/backends/flink/compiler.py @@ -8,7 +8,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler +from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import FlinkType from ibis.backends.sql.dialects import Flink from ibis.backends.sql.rewrites import ( @@ -18,10 +18,41 @@ ) +class FlinkAggGen(AggGen): + def aggregate(self, compiler, name, *args, where=None): + func = compiler.f[name] + if where is not None: + # Flink does support FILTER, but it's broken for: + # + # 1. certain aggregates: std/var doesn't return the right result + # 2. certain kinds of predicates: x IN y doesn't filter the right + # values out + # 3. certain aggregates AND predicates STD(w) FILTER (WHERE x IN Y) + # returns an incorrect result + # + # One solution is to try `IF(predicate, arg, NULL)`. + # + # Unfortunately that won't work without casting the NULL to a + # specific type. + # + # At this point in the Ibis compiler we don't have any of the Ibis + # operation's type information because we thrown it away. In every + # other engine Ibis supports the type of a NULL literal is inferred + # by the engine. + # + # Using a CASE statement and leaving out the explicit NULL does the + # trick for Flink. + args = tuple(sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args) + return func(*args) + + class FlinkCompiler(SQLGlotCompiler): quoted = True dialect = Flink type_mapper = FlinkType + + agg = FlinkAggGen() + rewrites = ( exclude_unsupported_window_frame_from_row_number, exclude_unsupported_window_frame_from_ops, @@ -96,34 +127,6 @@ def POS_INF(self): def _generate_groups(groups): return groups - def _aggregate(self, funcname: str, *args, where): - func = self.f[funcname] - if where is not None: - # FILTER (WHERE ) is broken for one or both of: - # - # 1. certain aggregates: std/var doesn't return the right result - # 2. certain kinds of predicates: x IN y doesn't filter the right - # values out - # 3. certain aggregates AND predicates STD(w) FILTER (WHERE x IN Y) - # returns an incorrect result - # - # One solution is to try `IF(predicate, arg, NULL)`. - # - # Unfortunately that won't work without casting the NULL to a - # specific type. - # - # At this point in the Ibis compiler we don't have any of the Ibis - # operation's type information because we thrown it away. In every - # other engine Ibis supports the type of a NULL literal is inferred - # by the engine. - # - # Using a CASE statement and leaving out the explicit NULL does the - # trick for Flink. - # - # Le sigh. - args = tuple(sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args) - return func(*args) - @staticmethod def _minimize_spec(start, end, spec): if ( diff --git a/ibis/backends/impala/compiler.py b/ibis/backends/impala/compiler.py index 21fa2bee1d15..49a07f761606 100644 --- a/ibis/backends/impala/compiler.py +++ b/ibis/backends/impala/compiler.py @@ -75,12 +75,6 @@ class ImpalaCompiler(SQLGlotCompiler): ops.TypeOf: "typeof", } - def _aggregate(self, funcname: str, *args, where): - if where is not None: - args = tuple(self.if_(where, arg, NULL) for arg in args) - - return self.f[funcname](*args, dialect=self.dialect) - @staticmethod def _minimize_spec(start, end, spec): # start is None means unbounded preceding diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index 920d1a99b5e1..3f40ce1528b2 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -144,12 +144,6 @@ def POS_INF(self): def NEG_INF(self): return self.f.double("-Infinity") - def _aggregate(self, funcname: str, *args, where): - func = self.f[funcname] - if where is not None: - args = tuple(self.if_(where, arg, NULL) for arg in args) - return func(*args) - @staticmethod def _generate_groups(groups): return groups diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index d8bf40d56c02..02040fd636f4 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -107,12 +107,6 @@ def POS_INF(self): ops.Log2: "log2", } - def _aggregate(self, funcname: str, *args, where): - func = self.f[funcname] - if where is not None: - args = tuple(self.if_(where, arg, NULL) for arg in args) - return func(*args) - @staticmethod def _minimize_spec(start, end, spec): if ( diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py index e824ec1d06e3..71fe551af81c 100644 --- a/ibis/backends/oracle/compiler.py +++ b/ibis/backends/oracle/compiler.py @@ -95,12 +95,6 @@ class OracleCompiler(SQLGlotCompiler): ops.Hash: "ora_hash", } - def _aggregate(self, funcname: str, *args, where): - func = self.f[funcname] - if where is not None: - args = tuple(self.if_(where, arg) for arg in args) - return func(*args) - @staticmethod def _generate_groups(groups): return groups diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index 0a6af4b91b15..e4d843dd0c8b 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -11,7 +11,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz -from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler +from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import PostgresType from ibis.backends.sql.dialects import Postgres @@ -27,6 +27,8 @@ class PostgresCompiler(SQLGlotCompiler): dialect = Postgres type_mapper = PostgresType + agg = AggGen(supports_filter=True) + NAN = sge.Literal.number("'NaN'::double precision") POS_INF = sge.Literal.number("'Inf'::double precision") NEG_INF = sge.Literal.number("'-Inf'::double precision") @@ -96,12 +98,6 @@ class PostgresCompiler(SQLGlotCompiler): ops.TimeFromHMS: "make_time", } - def _aggregate(self, funcname: str, *args, where): - expr = self.f[funcname](*args) - if where is not None: - return sge.Filter(this=expr, expression=sge.Where(this=where)) - return expr - def visit_RandomUUID(self, op, **kwargs): return self.f.gen_random_uuid() diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index b7b3ca347471..a28faedde83c 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -86,12 +86,6 @@ class PySparkCompiler(SQLGlotCompiler): ops.UnwrapJSONBoolean: "unwrap_json_bool", } - def _aggregate(self, funcname: str, *args, where): - func = self.f[funcname] - if where is not None: - args = tuple(self.if_(where, arg, NULL) for arg in args) - return func(*args) - def visit_InSubquery(self, op, *, rel, needle): if op.needle.dtype.is_struct(): # construct the outer struct for pyspark diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py index c854ad25887c..eaa7e4af4d53 100644 --- a/ibis/backends/snowflake/compiler.py +++ b/ibis/backends/snowflake/compiler.py @@ -87,13 +87,6 @@ def __init__(self): super().__init__() self.f = SnowflakeFuncGen() - def _aggregate(self, funcname: str, *args, where): - if where is not None: - args = [self.if_(where, arg, NULL) for arg in args] - - func = self.f[funcname] - return func(*args) - @staticmethod def _minimize_spec(start, end, spec): if ( diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index fb0c5a974c8b..689e90184994 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -53,16 +53,77 @@ def get_leaf_classes(op): class AggGen: - __slots__ = ("aggfunc",) + """A descriptor for compiling aggregate functions. - def __init__(self, *, aggfunc: Callable) -> None: - self.aggfunc = aggfunc + Common cases can be handled by setting configuration flags, + special cases should override the `aggregate` method directly. - def __getattr__(self, name: str) -> partial: - return partial(self.aggfunc, name) + Parameters + ---------- + supports_filter + Whether the backend supports a FILTER clause in the aggregate. + Defaults to False. + """ - def __getitem__(self, key: str) -> partial: - return getattr(self, key) + class _Accessor: + """An internal type to handle getattr/getitem access.""" + + __slots__ = ("handler", "compiler") + + def __init__(self, handler: Callable, compiler: SQLGlotCompiler): + self.handler = handler + self.compiler = compiler + + def __getattr__(self, name: str) -> Callable: + return partial(self.handler, self.compiler, name) + + __getitem__ = __getattr__ + + __slots__ = ("supports_filter",) + + def __init__(self, *, supports_filter: bool = False): + self.supports_filter = supports_filter + + def __get__(self, instance, owner=None): + if instance is None: + return self + + return AggGen._Accessor(self.aggregate, instance) + + def aggregate( + self, + compiler: SQLGlotCompiler, + name: str, + *args: Any, + where: Any = None, + ): + """Compile the specified aggregate. + + Parameters + ---------- + compiler + The backend's compiler. + name + The aggregate name (e.g. `"sum"`). + args + Any arguments to pass to the aggregate. + where + An optional column filter to apply before performing the aggregate. + + """ + func = compiler.f[name] + + if where is None: + return func(*args) + + if self.supports_filter: + return sge.Filter( + this=func(*args), + expression=sge.Where(this=where), + ) + else: + args = tuple(compiler.if_(where, arg, NULL) for arg in args) + return func(*args) class VarGen: @@ -167,7 +228,10 @@ def wrapper(self, op, *, left, right): @public class SQLGlotCompiler(abc.ABC): - __slots__ = "agg", "f", "v" + __slots__ = "f", "v" + + agg = AggGen() + """A generator for handling aggregate functions""" rewrites: tuple[type[pats.Replace], ...] = ( empty_in_values_right_side, @@ -345,7 +409,6 @@ class SQLGlotCompiler(abc.ABC): lowered_ops: ClassVar[dict[type[ops.Node], pats.Replace]] = {} def __init__(self) -> None: - self.agg = AggGen(aggfunc=self._aggregate) self.f = FuncGen(copy=self.__class__.copy_func_args) self.v = VarGen() @@ -411,20 +474,6 @@ def dialect(self) -> str: def type_mapper(self) -> type[SqlglotType]: """The type mapper for the backend.""" - @abc.abstractmethod - def _aggregate(self, funcname, *args, where): - """Translate an aggregate function. - - Three flavors of filtering aggregate function inputs: - - 1. supports filter (duckdb, postgres, others) - e.g.: sum(x) filter (where predicate) - 2. use null to filter out - e.g.: sum(if(predicate, x, NULL)) - 3. clickhouse's ${func}If implementation, e.g.: - sumIf(predicate, x) - """ - # Concrete API def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If: diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py index 2e7e6b279c16..c2232a885669 100644 --- a/ibis/backends/sqlite/compiler.py +++ b/ibis/backends/sqlite/compiler.py @@ -9,7 +9,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import NULL, SQLGlotCompiler +from ibis.backends.sql.compiler import NULL, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import SQLiteType from ibis.backends.sql.dialects import SQLite from ibis.common.temporal import DateUnit, IntervalUnit @@ -22,6 +22,8 @@ class SQLiteCompiler(SQLGlotCompiler): dialect = SQLite type_mapper = SQLiteType + agg = AggGen(supports_filter=True) + NAN = NULL POS_INF = sge.Literal.number("1e999") NEG_INF = sge.Literal.number("-1e999") @@ -97,12 +99,6 @@ class SQLiteCompiler(SQLGlotCompiler): ops.Date: "date", } - def _aggregate(self, funcname: str, *args, where): - expr = self.f[funcname](*args) - if where is not None: - return sge.Filter(this=expr, expression=sge.Where(this=where)) - return expr - def visit_Log10(self, op, *, arg): return self.f.anon.log10(arg) @@ -222,7 +218,7 @@ def _visit_arg_reduction(self, func, op, *, arg, key, where): if op.where is not None: cond = sg.and_(cond, where) - agg = self._aggregate(func, key, where=cond) + agg = self.agg[func](key, where=cond) return self.f.anon.json_extract(self.f.json_array(arg, agg), "$[0]") def visit_UnwrapJSONString(self, op, *, arg): @@ -254,10 +250,10 @@ def visit_UnwrapJSONBoolean(self, op, *, arg): ) def visit_Variance(self, op, *, arg, how, where): - return self._aggregate(f"_ibis_var_{op.how}", arg, where=where) + return self.agg[f"_ibis_var_{op.how}"](arg, where=where) def visit_StandardDev(self, op, *, arg, how, where): - var = self._aggregate(f"_ibis_var_{op.how}", arg, where=where) + var = self.agg[f"_ibis_var_{op.how}"](arg, where=where) return self.f.sqrt(var) def visit_ApproxCountDistinct(self, op, *, arg, where): diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index 7c0c7634227e..d843bfa83e6e 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -10,7 +10,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler +from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import TrinoType from ibis.backends.sql.dialects import Trino from ibis.backends.sql.rewrites import exclude_unsupported_window_frame_from_ops @@ -21,6 +21,9 @@ class TrinoCompiler(SQLGlotCompiler): dialect = Trino type_mapper = TrinoType + + agg = AggGen(supports_filter=True) + rewrites = ( exclude_unsupported_window_frame_from_ops, *SQLGlotCompiler.rewrites, @@ -83,12 +86,6 @@ class TrinoCompiler(SQLGlotCompiler): ops.ExtractIsoYear: "year_of_week", } - def _aggregate(self, funcname: str, *args, where): - expr = self.f[funcname](*args) - if where is not None: - return sge.Filter(this=expr, expression=sge.Where(this=where)) - return expr - @staticmethod def _minimize_spec(start, end, spec): if (