From cd1ecab1bcf43891ce8a5b11b943db639cfe8cf2 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Tue, 6 Aug 2024 11:42:10 -0500 Subject: [PATCH] feat(api): support `ignore_null` in `first`/`last` --- ibis/backends/dask/executor.py | 36 ++++++ ibis/backends/dask/kernels.py | 12 +- ibis/backends/pandas/executor.py | 34 +++++- ibis/backends/pandas/kernels.py | 13 +-- ibis/backends/polars/compiler.py | 5 +- ibis/backends/sql/compilers/base.py | 2 - .../sql/compilers/bigquery/__init__.py | 35 ++++-- ibis/backends/sql/compilers/clickhouse.py | 22 +++- ibis/backends/sql/compilers/datafusion.py | 18 +-- ibis/backends/sql/compilers/druid.py | 2 - ibis/backends/sql/compilers/duckdb.py | 18 +-- ibis/backends/sql/compilers/exasol.py | 16 ++- ibis/backends/sql/compilers/flink.py | 16 ++- ibis/backends/sql/compilers/impala.py | 2 - ibis/backends/sql/compilers/mssql.py | 2 - ibis/backends/sql/compilers/mysql.py | 2 - ibis/backends/sql/compilers/oracle.py | 2 - ibis/backends/sql/compilers/postgres.py | 18 ++- ibis/backends/sql/compilers/pyspark.py | 30 +++-- ibis/backends/sql/compilers/risingwave.py | 12 +- ibis/backends/sql/compilers/snowflake.py | 22 ++-- ibis/backends/sql/compilers/sqlite.py | 10 +- ibis/backends/sql/compilers/trino.py | 18 +-- ibis/backends/sqlite/udf.py | 27 +++-- ibis/backends/tests/test_aggregation.py | 110 +++++++++++++++--- ibis/expr/operations/reductions.py | 4 +- ibis/expr/types/generic.py | 30 +++-- 27 files changed, 384 insertions(+), 134 deletions(-) diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py index 12d975d79966..db5b78d97fe1 100644 --- a/ibis/backends/dask/executor.py +++ b/ibis/backends/dask/executor.py @@ -203,6 +203,42 @@ def agg(df): return agg + @classmethod + def visit(cls, op: ops.First, arg, where, order_by, include_null): + if order_by: + raise UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) + + def first(df): + def inner(arg): + if not include_null: + arg = arg.dropna() + return arg.iat[0] if len(arg) else None + + return df.reduction(inner) if isinstance(df, dd.Series) else inner(df) + + return cls.agg(first, arg, where) + + @classmethod + def visit(cls, op: ops.Last, arg, where, order_by, include_null): + if order_by: + raise UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) + + def last(df): + def inner(arg): + if not include_null: + arg = arg.dropna() + return arg.iat[-1] if len(arg) else None + + return df.reduction(inner) if isinstance(df, dd.Series) else inner(df) + + return cls.agg(last, arg, where) + @classmethod def visit(cls, op: ops.Correlation, left, right, where, how): if how == "pop": diff --git a/ibis/backends/dask/kernels.py b/ibis/backends/dask/kernels.py index 12a1a782ab01..f323cb205845 100644 --- a/ibis/backends/dask/kernels.py +++ b/ibis/backends/dask/kernels.py @@ -17,13 +17,6 @@ } -def maybe_pandas_reduction(func): - def inner(df): - return df.reduction(func) if isinstance(df, dd.Series) else func(df) - - return inner - - reductions = { **pandas_kernels.reductions, ops.Mode: lambda x: x.mode().loc[0], @@ -31,10 +24,7 @@ def inner(df): ops.BitAnd: lambda x: x.reduction(np.bitwise_and.reduce), ops.BitOr: lambda x: x.reduction(np.bitwise_or.reduce), ops.BitXor: lambda x: x.reduction(np.bitwise_xor.reduce), - ops.Arbitrary: lambda x: x.reduction(pandas_kernels.first), - # Window functions are calculated locally using pandas - ops.Last: maybe_pandas_reduction(pandas_kernels.last), - ops.First: maybe_pandas_reduction(pandas_kernels.first), + ops.Arbitrary: lambda x: x.reduction(pandas_kernels.arbitrary), } serieswise = { diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index 5e7a79453701..4762132866c3 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -320,16 +320,46 @@ def visit(cls, op: ops.StandardDev, arg, where, how): return cls.agg(lambda x: x.std(ddof=ddof), arg, where) @classmethod - def visit(cls, op: ops.ArrayCollect, arg, where, order_by, ignore_null): + def visit(cls, op: ops.ArrayCollect, arg, where, order_by, include_null): if order_by: raise UnsupportedOperationError( "ordering of order-sensitive aggregations via `order_by` is " "not supported for this backend" ) return cls.agg( - (lambda x: x.dropna().tolist() if ignore_null else x.tolist()), arg, where + (lambda x: x.tolist() if include_null else x.dropna().tolist()), arg, where ) + @classmethod + def visit(cls, op: ops.First, arg, where, order_by, include_null): + if order_by: + raise UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) + + def first(arg): + if not include_null: + arg = arg.dropna() + return arg.iat[0] if len(arg) else None + + return cls.agg(first, arg, where) + + @classmethod + def visit(cls, op: ops.Last, arg, where, order_by, include_null): + if order_by: + raise UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) + + def last(arg): + if not include_null: + arg = arg.dropna() + return arg.iat[-1] if len(arg) else None + + return cls.agg(last, arg, where) + @classmethod def visit(cls, op: ops.Correlation, left, right, where, how): if where is None: diff --git a/ibis/backends/pandas/kernels.py b/ibis/backends/pandas/kernels.py index c6feb6c3e886..0d2bc1db1de8 100644 --- a/ibis/backends/pandas/kernels.py +++ b/ibis/backends/pandas/kernels.py @@ -260,18 +260,11 @@ def round_serieswise(arg, digits): return np.round(arg, digits).astype("float64") -def first(arg): - # first excludes null values unless they're all null +def arbitrary(arg): arg = arg.dropna() return arg.iat[0] if len(arg) else None -def last(arg): - # last excludes null values unless they're all null - arg = arg.dropna() - return arg.iat[-1] if len(arg) else None - - reductions = { ops.Min: lambda x: x.min(), ops.Max: lambda x: x.max(), @@ -286,9 +279,7 @@ def last(arg): ops.BitAnd: lambda x: np.bitwise_and.reduce(x.values), ops.BitOr: lambda x: np.bitwise_or.reduce(x.values), ops.BitXor: lambda x: np.bitwise_xor.reduce(x.values), - ops.Last: last, - ops.First: first, - ops.Arbitrary: first, + ops.Arbitrary: arbitrary, ops.CountDistinct: lambda x: x.nunique(), ops.ApproxCountDistinct: lambda x: x.nunique(), } diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index d0ac2163f796..1ecdc38b4c22 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -741,8 +741,7 @@ def execute_reduction(op, **kw): def execute_first_last(op, **kw): arg = translate(op.arg, **kw) - # polars doesn't ignore nulls by default for these methods - predicate = arg.is_not_null() + predicate = True if getattr(op, "include_null", False) else arg.is_not_null() if op.where is not None: predicate &= translate(op.where, **kw) @@ -991,7 +990,7 @@ def array_column(op, **kw): def array_collect(op, in_group_by=False, **kw): arg = translate(op.arg, **kw) - predicate = arg.is_not_null() if op.ignore_null else True + predicate = True if op.include_null else arg.is_not_null() if op.where is not None: predicate &= translate(op.where, **kw) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 418268b8b98c..74a39d3f9db1 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -330,7 +330,6 @@ class SQLGlotCompiler(abc.ABC): ops.Degrees: "degrees", ops.DenseRank: "dense_rank", ops.Exp: "exp", - ops.First: "first", FirstValue: "first_value", ops.GroupConcat: "group_concat", ops.IfElse: "if", @@ -338,7 +337,6 @@ class SQLGlotCompiler(abc.ABC): ops.IsNan: "isnan", ops.JSONGetItem: "json_extract", ops.LPad: "lpad", - ops.Last: "last", LastValue: "last_value", ops.Levenshtein: "levenshtein", ops.Ln: "ln", diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index b5818e35468c..96bfc46cafae 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -440,9 +440,14 @@ def visit_StringToTimestamp(self, op, *, arg, format_str): return self.f.parse_timestamp(format_str, arg, timezone) return self.f.parse_datetime(format_str, arg) - def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + if where is not None and include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by bigquery" + ) out = self.agg.array_agg(arg, where=where, order_by=order_by) - if ignore_null: + if not include_null: out = sge.IgnoreNulls(this=out) return out @@ -690,26 +695,40 @@ def visit_TimestampRange(self, op, *, start, stop, step): self.f.generate_timestamp_array, start, stop, step, op.step.dtype ) - def visit_First(self, op, *, arg, where, order_by): + def visit_First(self, op, *, arg, where, order_by, include_null): if where is not None: arg = self.if_(where, arg, NULL) + if include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by bigquery" + ) if order_by: arg = sge.Order(this=arg, expressions=order_by) - array = self.f.array_agg( - sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)), - ) + if not include_null: + arg = sge.IgnoreNulls(this=arg) + + array = self.f.array_agg(sge.Limit(this=arg, expression=sge.convert(1))) return array[self.f.safe_offset(0)] - def visit_Last(self, op, *, arg, where, order_by): + def visit_Last(self, op, *, arg, where, order_by, include_null): if where is not None: arg = self.if_(where, arg, NULL) + if include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by bigquery" + ) if order_by: arg = sge.Order(this=arg, expressions=order_by) - array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg))) + if not include_null: + arg = sge.IgnoreNulls(this=arg) + + array = self.f.array_reverse(self.f.array_agg(arg)) return array[self.f.safe_offset(0)] def visit_ArrayFilter(self, op, *, arg, body, param): diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 00236f96df99..b4a6dd671898 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -90,13 +90,11 @@ class ClickHouseCompiler(SQLGlotCompiler): ops.ExtractWeekOfYear: "toISOWeek", ops.ExtractYear: "toYear", ops.ExtractIsoYear: "toISOYear", - ops.First: "any", ops.IntegerRange: "range", ops.IsInf: "isInfinite", ops.IsNan: "isNaN", ops.IsNull: "isNull", ops.LStrip: "trimLeft", - ops.Last: "anyLast", ops.Ln: "log", ops.Log10: "log10", ops.MapKeys: "mapKeys", @@ -603,13 +601,27 @@ def visit_ArrayUnion(self, op, *, left, right): def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str: return self.f.arrayZip(*arg) - def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): - if not ignore_null: + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + if include_null: raise com.UnsupportedOperationError( - "`ignore_null=False` is not supported by the pyspark backend" + "`include_null=True` is not supported by the clickhouse backend" ) return self.agg.groupArray(arg, where=where, order_by=order_by) + def visit_First(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the clickhouse backend" + ) + return self.agg.any(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the clickhouse backend" + ) + return self.agg.anyLast(arg, where=where, order_by=order_by) + def visit_CountDistinctStar( self, op: ops.CountDistinctStar, *, where, **_: Any ) -> str: diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index acc6dfc22431..4e8acd2fad29 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -325,8 +325,8 @@ def visit_ArrayRepeat(self, op, *, arg, times): def visit_ArrayPosition(self, op, *, arg, other): return self.f.coalesce(self.f.array_position(arg, other), 0) - def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): - if ignore_null: + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.array_agg(arg, where=where, order_by=order_by) @@ -425,14 +425,16 @@ def visit_StringConcat(self, op, *, arg): sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg) ) - def visit_First(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_First(self, op, *, arg, where, order_by, include_null): + if not include_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.first_value(arg, where=where, order_by=order_by) - def visit_Last(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_Last(self, op, *, arg, where, order_by, include_null): + if not include_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last_value(arg, where=where, order_by=order_by) def visit_Aggregate(self, op, *, parent, groups, metrics): diff --git a/ibis/backends/sql/compilers/druid.py b/ibis/backends/sql/compilers/druid.py index 6479628d9acb..29cb7dfec7ac 100644 --- a/ibis/backends/sql/compilers/druid.py +++ b/ibis/backends/sql/compilers/druid.py @@ -42,11 +42,9 @@ class DruidCompiler(SQLGlotCompiler): ops.DateFromYMD, ops.DayOfWeekIndex, ops.DayOfWeekName, - ops.First, ops.IntervalFromInteger, ops.IsNan, ops.IsInf, - ops.Last, ops.Levenshtein, ops.Median, ops.MultiQuantile, diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index 7498ab1f0277..af5a0757abe1 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -148,8 +148,8 @@ def visit_ArrayDistinct(self, op, *, arg): ), ) - def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): - if ignore_null: + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.array_agg(arg, where=where, order_by=order_by) @@ -510,14 +510,16 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement): arg, pattern, replacement, "g", dialect=self.dialect ) - def visit_First(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_First(self, op, *, arg, where, order_by, include_null): + if not include_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.first(arg, where=where, order_by=order_by) - def visit_Last(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_Last(self, op, *, arg, where, order_by, include_null): + if not include_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last(arg, where=where, order_by=order_by) def visit_Quantile(self, op, *, arg, quantile, where): diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py index bdf690a7ff2b..70d9486f36d3 100644 --- a/ibis/backends/sql/compilers/exasol.py +++ b/ibis/backends/sql/compilers/exasol.py @@ -87,8 +87,6 @@ class ExasolCompiler(SQLGlotCompiler): ops.Log10: "log10", ops.All: "min", ops.Any: "max", - ops.First: "first_value", - ops.Last: "last_value", } @staticmethod @@ -136,6 +134,20 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by): return sge.GroupConcat(this=arg, separator=sep) + def visit_First(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the exasol backend" + ) + return self.agg.first_value(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the exasol backend" + ) + return self.agg.last_value(arg, where=where, order_by=order_by) + def visit_StartsWith(self, op, *, arg, start): return self.f.left(arg, self.f.length(start)).eq(start) diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py index fb1fad34cfff..b00c248a7f8a 100644 --- a/ibis/backends/sql/compilers/flink.py +++ b/ibis/backends/sql/compilers/flink.py @@ -104,8 +104,6 @@ class FlinkCompiler(SQLGlotCompiler): ops.ArrayRemove: "array_remove", ops.ArrayUnion: "array_union", ops.ExtractDayOfYear: "dayofyear", - ops.First: "first_value", - ops.Last: "last_value", ops.MapKeys: "map_keys", ops.MapValues: "map_values", ops.Power: "power", @@ -307,6 +305,20 @@ def visit_ArraySlice(self, op, *, arg, start, stop): return self.f.array_slice(*args) + def visit_First(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the flink backend" + ) + return self.agg.first_value(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the flink backend" + ) + return self.agg.last_value(arg, where=where, order_by=order_by) + def visit_Not(self, op, *, arg): return sg.not_(self.cast(arg, dt.boolean)) diff --git a/ibis/backends/sql/compilers/impala.py b/ibis/backends/sql/compilers/impala.py index bae861126d16..8a77ebb2bc94 100644 --- a/ibis/backends/sql/compilers/impala.py +++ b/ibis/backends/sql/compilers/impala.py @@ -31,8 +31,6 @@ class ImpalaCompiler(SQLGlotCompiler): ops.Covariance, ops.DateDelta, ops.ExtractDayOfYear, - ops.First, - ops.Last, ops.Levenshtein, ops.Map, ops.Median, diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index d7b815252e9e..e16e969abfc0 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -90,14 +90,12 @@ class MSSQLCompiler(SQLGlotCompiler): ops.DateDiff, ops.DateSub, ops.EndsWith, - ops.First, ops.IntervalAdd, ops.IntervalFromInteger, ops.IntervalMultiply, ops.IntervalSubtract, ops.IsInf, ops.IsNan, - ops.Last, ops.LPad, ops.Levenshtein, ops.Map, diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py index 56c6f799c89a..d5367f9495a6 100644 --- a/ibis/backends/sql/compilers/mysql.py +++ b/ibis/backends/sql/compilers/mysql.py @@ -71,8 +71,6 @@ def POS_INF(self): ops.ArrayFlatten, ops.ArrayMap, ops.Covariance, - ops.First, - ops.Last, ops.Levenshtein, ops.Median, ops.Mode, diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index 98bea4d2d90b..c22a40d1dea6 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -55,8 +55,6 @@ class OracleCompiler(SQLGlotCompiler): ops.ArrayFlatten, ops.ArrayMap, ops.ArrayStringJoin, - ops.First, - ops.Last, ops.Mode, ops.MultiQuantile, ops.RegexSplit, diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index f145ac779111..6bdd3f0fa388 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -355,12 +355,26 @@ def visit_ArrayIntersect(self, op, *, left, right): ) ) - def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): - if ignore_null: + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_First(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the postgres backend" + ) + return self.agg.first(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the postgres backend" + ) + return self.agg.last(arg, where=where, order_by=order_by) + def visit_Log2(self, op, *, arg): return self.cast( self.f.log( diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index 9a4c292d5d11..5b7588ea2762 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -240,11 +240,27 @@ def visit_FirstValue(self, op, *, arg): def visit_LastValue(self, op, *, arg): return sge.IgnoreNulls(this=self.f.last(arg)) - def visit_First(self, op, *, arg, where, order_by): - return sge.IgnoreNulls(this=self.agg.first(arg, where=where, order_by=order_by)) + def visit_First(self, op, *, arg, where, order_by, include_null): + if where is not None and include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by pyspark" + ) + out = self.agg.first(arg, where=where, order_by=order_by) + if not include_null: + out = sge.IgnoreNulls(this=out) + return out - def visit_Last(self, op, *, arg, where, order_by): - return sge.IgnoreNulls(this=self.agg.last(arg, where=where, order_by=order_by)) + def visit_Last(self, op, *, arg, where, order_by, include_null): + if where is not None and include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by pyspark" + ) + out = self.agg.last(arg, where=where, order_by=order_by) + if not include_null: + out = sge.IgnoreNulls(this=out) + return out def visit_Arbitrary(self, op, *, arg, where): # For Spark>=3.4 we could use any_value here @@ -397,10 +413,10 @@ def visit_ArrayContains(self, op, *, arg, other): def visit_ArrayStringJoin(self, op, *, arg, sep): return self.f.concat_ws(sep, arg) - def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): - if not ignore_null: + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + if include_null: raise com.UnsupportedOperationError( - "`ignore_null=False` is not supported by the pyspark backend" + "`include_null=True` is not supported by the pyspark backend" ) return self.agg.array_agg(arg, where=where, order_by=order_by) diff --git a/ibis/backends/sql/compilers/risingwave.py b/ibis/backends/sql/compilers/risingwave.py index 35f741c17499..61f89d4153ef 100644 --- a/ibis/backends/sql/compilers/risingwave.py +++ b/ibis/backends/sql/compilers/risingwave.py @@ -35,14 +35,22 @@ class RisingWaveCompiler(PostgresCompiler): def visit_DateNow(self, op): return self.cast(sge.CurrentTimestamp(), dt.date) - def visit_First(self, op, *, arg, where, order_by): + def visit_First(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the risingwave backend" + ) if not order_by: raise com.UnsupportedOperationError( "RisingWave requires an `order_by` be specified in `first`" ) return self.agg.first_value(arg, where=where, order_by=order_by) - def visit_Last(self, op, *, arg, where, order_by): + def visit_Last(self, op, *, arg, where, order_by, include_null): + if include_null: + raise com.UnsupportedOperationError( + "`include_null=True` is not supported by the risingwave backend" + ) if not order_by: raise com.UnsupportedOperationError( "RisingWave requires an `order_by` be specified in `last`" diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index f6fcad7b26d4..e1929bd57d35 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -455,10 +455,10 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short]) - def _array_collect(self, *, arg, where, order_by, ignore_null=True): - if not ignore_null: + def _array_collect(self, *, arg, where, order_by, include_null): + if include_null: raise com.UnsupportedOperationError( - "`ignore_null=False` is not supported by the snowflake backend" + "`include_null=True` is not supported by the snowflake backend" ) if where is not None: @@ -471,17 +471,21 @@ def _array_collect(self, *, arg, where, order_by, ignore_null=True): return out - def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): return self._array_collect( - arg=arg, where=where, order_by=order_by, ignore_null=ignore_null + arg=arg, where=where, order_by=order_by, include_null=include_null ) - def visit_First(self, op, *, arg, where, order_by): - out = self._array_collect(arg=arg, where=where, order_by=order_by) + def visit_First(self, op, *, arg, where, order_by, include_null): + out = self._array_collect( + arg=arg, where=where, order_by=order_by, include_null=include_null + ) return self.f.get(out, 0) - def visit_Last(self, op, *, arg, where, order_by): - out = self._array_collect(arg=arg, where=where, order_by=order_by) + def visit_Last(self, op, *, arg, where, order_by, include_null): + out = self._array_collect( + arg=arg, where=where, order_by=order_by, include_null=include_null + ) return self.f.get(out, self.f.array_size(out) - 1) def visit_GroupConcat(self, op, *, arg, where, sep, order_by): diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index 9f0d4bf7c1dc..0d4db351c3b1 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -90,8 +90,6 @@ class SQLiteCompiler(SQLGlotCompiler): ops.BitOr: "_ibis_bit_or", ops.BitAnd: "_ibis_bit_and", ops.BitXor: "_ibis_bit_xor", - ops.First: "_ibis_first", - ops.Last: "_ibis_last", ops.Mode: "_ibis_mode", ops.Time: "time", ops.Date: "date", @@ -249,6 +247,14 @@ def visit_UnwrapJSONBoolean(self, op, *, arg): NULL, ) + def visit_First(self, op, *, arg, where, order_by, include_null): + func = "_ibis_first_include_null" if include_null else "_ibis_first" + return self.agg[func](arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, include_null): + func = "_ibis_last_include_null" if include_null else "_ibis_last" + return self.agg[func](arg, where=where, order_by=order_by) + def visit_Variance(self, op, *, arg, how, where): return self.agg[f"_ibis_var_{op.how}"](arg, where=where) diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index f35653c79712..4a19b9a37436 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -176,8 +176,8 @@ def visit_ArrayContains(self, op, *, arg, other): NULL, ) - def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): - if ignore_null: + def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + if not include_null: cond = arg.is_(sg.not_(NULL, copy=False)) where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.array_agg(arg, where=where, order_by=order_by) @@ -373,16 +373,18 @@ def visit_StringAscii(self, op, *, arg): def visit_ArrayStringJoin(self, op, *, sep, arg): return self.f.array_join(arg, sep) - def visit_First(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_First(self, op, *, arg, where, order_by, include_null): + if not include_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.f.element_at( self.agg.array_agg(arg, where=where, order_by=order_by), 1 ) - def visit_Last(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_Last(self, op, *, arg, where, order_by, include_null): + if not include_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.f.element_at( self.agg.array_agg(arg, where=where, order_by=order_by), -1 ) diff --git a/ibis/backends/sqlite/udf.py b/ibis/backends/sqlite/udf.py index 49c775e9e8e7..9e3fdff258b3 100644 --- a/ibis/backends/sqlite/udf.py +++ b/ibis/backends/sqlite/udf.py @@ -1,6 +1,5 @@ from __future__ import annotations -import abc import functools import inspect import math @@ -31,6 +30,7 @@ class _UDF(NamedTuple): _SQLITE_UDF_REGISTRY = {} _SQLITE_UDAF_REGISTRY = {} +UNSET = object() def ignore_nulls(f): @@ -441,14 +441,11 @@ def __init__(self): super().__init__(operator.xor) -class _ibis_first_last(abc.ABC): - def __init__(self) -> None: +class _ibis_first_last: + def __init__(self): self.value = None - @abc.abstractmethod - def step(self, value): ... - - def finalize(self) -> int | None: + def finalize(self): return self.value @@ -459,6 +456,16 @@ def step(self, value): self.value = value +@udaf +class _ibis_first_include_null(_ibis_first_last): + def __init__(self): + self.value = UNSET + + def step(self, value): + if self.value is UNSET: + self.value = value + + @udaf class _ibis_last(_ibis_first_last): def step(self, value): @@ -466,6 +473,12 @@ def step(self, value): self.value = value +@udaf +class _ibis_last_include_null(_ibis_first_last): + def step(self, value): + self.value = value + + def register_all(con): """Register all udf and udaf with the connection. diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 15d7cbcb6018..ecf631ae623d 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -616,13 +616,43 @@ def test_reduction_ops( ["datafusion"], raises=Exception, reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + strict=False, ) ], ), True, ], ) -def test_first_last(backend, alltypes, method, filtered): +@pytest.mark.parametrize( + "include_null", + [ + False, + param( + True, + marks=[ + pytest.mark.notimpl( + [ + "clickhouse", + "exasol", + "flink", + "postgres", + "risingwave", + "snowflake", + ], + raises=com.UnsupportedOperationError, + reason="`include_null=True` is not supported", + ), + pytest.mark.notimpl( + ["bigquery", "pyspark"], + raises=com.UnsupportedOperationError, + reason="Can't mix `where` and `include_null=True`", + strict=False, + ), + ], + ), + ], +) +def test_first_last(backend, alltypes, method, filtered, include_null): # `first` and `last` effectively choose an arbitrary value when no # additional order is specified. *Most* backends will result in the # first/last element in a column being selected (at least when operating on @@ -640,9 +670,13 @@ def test_first_last(backend, alltypes, method, filtered): t = alltypes.mutate(new=new) - expr = getattr(t.new, method)(where=where) + expr = getattr(t.new, method)(where=where, include_null=include_null) res = expr.execute() - assert res == 30 + if include_null: + # no ordering, so technically could be either 30 or NULL + assert res == 30 or pd.isna(res) + else: + assert res == 30 @pytest.mark.notimpl( @@ -672,23 +706,59 @@ def test_first_last(backend, alltypes, method, filtered): ["datafusion"], raises=Exception, reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + strict=False, ) ], ), True, ], ) -def test_first_last_ordered(backend, alltypes, method, filtered): +@pytest.mark.parametrize( + "include_null", + [ + False, + param( + True, + marks=[ + pytest.mark.notimpl( + [ + "clickhouse", + "exasol", + "flink", + "postgres", + "risingwave", + "snowflake", + ], + raises=com.UnsupportedOperationError, + reason="`include_null=True` is not supported", + ), + pytest.mark.notimpl( + ["bigquery", "pyspark"], + raises=com.UnsupportedOperationError, + reason="Can't mix `where` and `include_null=True`", + strict=False, + ), + ], + ), + ], +) +def test_first_last_ordered(backend, alltypes, method, filtered, include_null): t = alltypes.mutate(new=alltypes.int_col.nullif(0).nullif(9)) - where = None - sol = 1 if method == "last" else 8 if filtered: - where = _.int_col != sol + where = _.int_col != (1 if method == "last" else 8) sol = 2 if method == "last" else 7 + else: + where = None + sol = 1 if method == "last" else 8 - expr = getattr(t.new, method)(where=where, order_by=t.int_col.desc()) + expr = getattr(t.new, method)( + where=where, order_by=t.int_col.desc(), include_null=include_null + ) res = expr.execute() - assert res == sol + if include_null: + assert pd.isna(res) + else: + assert res == sol @pytest.mark.notimpl( @@ -1399,31 +1469,37 @@ def test_collect_ordered(alltypes, df, filtered): ], ) @pytest.mark.parametrize( - "ignore_null", + "include_null", [ - True, + False, param( - False, + True, marks=[ pytest.mark.notimpl( ["clickhouse", "pyspark", "snowflake"], raises=com.UnsupportedOperationError, - reason="`ignore_null=False` is not supported", - ) + reason="`include_null=True` is not supported", + ), + pytest.mark.notimpl( + ["bigquery"], + raises=com.UnsupportedOperationError, + reason="Can't mix `where` and `include_null=True`", + strict=False, + ), ], ), ], ) -def test_collect(alltypes, df, filtered, ignore_null): +def test_collect(alltypes, df, filtered, include_null): ibis_cond = (_.id % 13 == 0) if filtered else None pd_cond = (df.id % 13 == 0) if filtered else slice(None) res = ( alltypes.string_col.nullif("3") - .collect(where=ibis_cond, ignore_null=ignore_null) + .collect(where=ibis_cond, include_null=include_null) .length() .execute() ) - vals = df.string_col[(df.string_col != "3")] if ignore_null else df.string_col + vals = df.string_col if include_null else df.string_col[(df.string_col != "3")] sol = len(vals[pd_cond]) assert res == sol diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index f5b1128a8694..9161ee563564 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -81,6 +81,7 @@ class First(Filterable, Reduction): arg: Column[dt.Any] order_by: VarTuple[SortKey] = () + include_null: bool = False dtype = rlz.dtype_like("arg") @@ -91,6 +92,7 @@ class Last(Filterable, Reduction): arg: Column[dt.Any] order_by: VarTuple[SortKey] = () + include_null: bool = False dtype = rlz.dtype_like("arg") @@ -368,7 +370,7 @@ class ArrayCollect(Filterable, Reduction): arg: Column order_by: VarTuple[SortKey] = () - ignore_null: bool = True + include_null: bool = False @attribute def dtype(self): diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 9cb19b73c5b1..826cb2c74c92 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1021,7 +1021,7 @@ def collect( self, where: ir.BooleanValue | None = None, order_by: Any = None, - ignore_null: bool = True, + include_null: bool = False, ) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. @@ -1036,9 +1036,9 @@ def collect( An ordering key (or keys) to use to order the rows before aggregating. If not provided, the order of the items in the result is undefined and backend specific. - ignore_null - Whether to ignore null values when performing this aggregation. Set - to `False` to include nulls in the result. + include_null + Whether to include null values when performing this aggregation. Set + to `True` to include nulls in the result. Returns ------- @@ -1099,7 +1099,7 @@ def collect( self, where=self._bind_to_parent_table(where), order_by=self._bind_order_by(order_by), - ignore_null=ignore_null, + include_null=include_null, ).to_expr() def identical_to(self, other: Value) -> ir.BooleanValue: @@ -2116,7 +2116,10 @@ def value_counts(self) -> ir.Table: return self.as_table().group_by(name).aggregate(metric) def first( - self, where: ir.BooleanValue | None = None, order_by: Any = None + self, + where: ir.BooleanValue | None = None, + order_by: Any = None, + include_null: bool = False, ) -> Value: """Return the first value of a column. @@ -2129,6 +2132,9 @@ def first( An ordering key (or keys) to use to order the rows before aggregating. If not provided, the meaning of `first` is undefined and will be backend specific. + include_null + Whether to include null values when performing this aggregation. Set + to `True` to include nulls in the result. Examples -------- @@ -2159,9 +2165,15 @@ def first( self, where=self._bind_to_parent_table(where), order_by=self._bind_order_by(order_by), + include_null=include_null, ).to_expr() - def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Value: + def last( + self, + where: ir.BooleanValue | None = None, + order_by: Any = None, + include_null: bool = False, + ) -> Value: """Return the last value of a column. Parameters @@ -2173,6 +2185,9 @@ def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Va An ordering key (or keys) to use to order the rows before aggregating. If not provided, the meaning of `last` is undefined and will be backend specific. + include_null + Whether to include null values when performing this aggregation. Set + to `True` to include nulls in the result. Examples -------- @@ -2203,6 +2218,7 @@ def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Va self, where=self._bind_to_parent_table(where), order_by=self._bind_order_by(order_by), + include_null=include_null, ).to_expr() def rank(self) -> ir.IntegerColumn: