diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py
index 12d975d799667..4fb6405397e8d 100644
--- a/ibis/backends/dask/executor.py
+++ b/ibis/backends/dask/executor.py
@@ -203,6 +203,42 @@ def agg(df):

         return agg

+    @classmethod
+    def visit(cls, op: ops.First, arg, where, order_by, ignore_null):
+        if order_by:
+            raise UnsupportedOperationError(
+                "ordering of order-sensitive aggregations via `order_by` is "
+                "not supported for this backend"
+            )
+
+        def first(df):
+            def inner(arg):
+                if ignore_null:
+                    arg = arg.dropna()
+                return arg.iat[0] if len(arg) else None
+
+            return df.reduction(inner) if isinstance(df, dd.Series) else inner(df)
+
+        return cls.agg(first, arg, where)
+
+    @classmethod
+    def visit(cls, op: ops.Last, arg, where, order_by, ignore_null):
+        if order_by:
+            raise UnsupportedOperationError(
+                "ordering of order-sensitive aggregations via `order_by` is "
+                "not supported for this backend"
+            )
+
+        def last(df):
+            def inner(arg):
+                if ignore_null:
+                    arg = arg.dropna()
+                return arg.iat[-1] if len(arg) else None
+
+            return df.reduction(inner) if isinstance(df, dd.Series) else inner(df)
+
+        return cls.agg(last, arg, where)
+
     @classmethod
     def visit(cls, op: ops.Correlation, left, right, where, how):
         if how == "pop":
diff --git a/ibis/backends/dask/kernels.py b/ibis/backends/dask/kernels.py
index 12a1a782ab01d..f323cb2058456 100644
--- a/ibis/backends/dask/kernels.py
+++ b/ibis/backends/dask/kernels.py
@@ -17,13 +17,6 @@
 }


-def maybe_pandas_reduction(func):
-    def inner(df):
-        return df.reduction(func) if isinstance(df, dd.Series) else func(df)
-
-    return inner
-
-
 reductions = {
     **pandas_kernels.reductions,
     ops.Mode: lambda x: x.mode().loc[0],
@@ -31,10 +24,7 @@ def inner(df):
     ops.BitAnd: lambda x: x.reduction(np.bitwise_and.reduce),
     ops.BitOr: lambda x: x.reduction(np.bitwise_or.reduce),
     ops.BitXor: lambda x: x.reduction(np.bitwise_xor.reduce),
-    ops.Arbitrary: lambda x: x.reduction(pandas_kernels.first),
-    # Window functions are calculated locally using pandas
-    ops.Last: maybe_pandas_reduction(pandas_kernels.last),
-    ops.First: maybe_pandas_reduction(pandas_kernels.first),
+    ops.Arbitrary: lambda x: x.reduction(pandas_kernels.arbitrary),
 }

 serieswise = {
diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py
index 5e7a794537011..8bfaccf9dbcbc 100644
--- a/ibis/backends/pandas/executor.py
+++ b/ibis/backends/pandas/executor.py
@@ -330,6 +330,36 @@ def visit(cls, op: ops.ArrayCollect, arg, where, order_by, ignore_null):
             (lambda x: x.dropna().tolist() if ignore_null else x.tolist()), arg, where
         )

+    @classmethod
+    def visit(cls, op: ops.First, arg, where, order_by, ignore_null):
+        if order_by:
+            raise UnsupportedOperationError(
+                "ordering of order-sensitive aggregations via `order_by` is "
+                "not supported for this backend"
+            )
+
+        def first(arg):
+            if ignore_null:
+                arg = arg.dropna()
+            return arg.iat[0] if len(arg) else None
+
+        return cls.agg(first, arg, where)
+
+    @classmethod
+    def visit(cls, op: ops.Last, arg, where, order_by, ignore_null):
+        if order_by:
+            raise UnsupportedOperationError(
+                "ordering of order-sensitive aggregations via `order_by` is "
+                "not supported for this backend"
+            )
+
+        def last(arg):
+            if ignore_null:
+                arg = arg.dropna()
+            return arg.iat[-1] if len(arg) else None
+
+        return cls.agg(last, arg, where)
+
     @classmethod
     def visit(cls, op: ops.Correlation, left, right, where, how):
         if where is None:
diff --git a/ibis/backends/pandas/kernels.py b/ibis/backends/pandas/kernels.py
index c6feb6c3e8862..0d2bc1db1de82 100644
--- a/ibis/backends/pandas/kernels.py
+++ b/ibis/backends/pandas/kernels.py
@@ -260,18 +260,11 @@ def round_serieswise(arg, digits):
     return np.round(arg, digits).astype("float64")


-def first(arg):
-    # first excludes null values unless they're all null
+def arbitrary(arg):
     arg = arg.dropna()
     return arg.iat[0] if len(arg) else None


-def last(arg):
-    # last excludes null values unless they're all null
-    arg = arg.dropna()
-    return arg.iat[-1] if len(arg) else None
-
-
 reductions = {
     ops.Min: lambda x: x.min(),
     ops.Max: lambda x: x.max(),
@@ -286,9 +279,7 @@ def last(arg):
     ops.BitAnd: lambda x: np.bitwise_and.reduce(x.values),
     ops.BitOr: lambda x: np.bitwise_or.reduce(x.values),
     ops.BitXor: lambda x: np.bitwise_xor.reduce(x.values),
-    ops.Last: last,
-    ops.First: first,
-    ops.Arbitrary: first,
+    ops.Arbitrary: arbitrary,
     ops.CountDistinct: lambda x: x.nunique(),
     ops.ApproxCountDistinct: lambda x: x.nunique(),
 }
diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py
index d0ac2163f7963..4a9c7b9bd303c 100644
--- a/ibis/backends/polars/compiler.py
+++ b/ibis/backends/polars/compiler.py
@@ -741,8 +741,7 @@ def execute_reduction(op, **kw):
 def execute_first_last(op, **kw):
     arg = translate(op.arg, **kw)

-    # polars doesn't ignore nulls by default for these methods
-    predicate = arg.is_not_null()
+    predicate = arg.is_not_null() if getattr(op, "ignore_null", True) else True

     if op.where is not None:
         predicate &= translate(op.where, **kw)
diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py
index 418268b8b98c1..74a39d3f9db19 100644
--- a/ibis/backends/sql/compilers/base.py
+++ b/ibis/backends/sql/compilers/base.py
@@ -330,7 +330,6 @@ class SQLGlotCompiler(abc.ABC):
         ops.Degrees: "degrees",
         ops.DenseRank: "dense_rank",
         ops.Exp: "exp",
-        ops.First: "first",
         FirstValue: "first_value",
         ops.GroupConcat: "group_concat",
         ops.IfElse: "if",
@@ -338,7 +337,6 @@ class SQLGlotCompiler(abc.ABC):
         ops.IsNan: "isnan",
         ops.JSONGetItem: "json_extract",
         ops.LPad: "lpad",
-        ops.Last: "last",
         LastValue: "last_value",
         ops.Levenshtein: "levenshtein",
         ops.Ln: "ln",
diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py
index b5818e35468c7..adf264699706e 100644
--- a/ibis/backends/sql/compilers/bigquery/__init__.py
+++ b/ibis/backends/sql/compilers/bigquery/__init__.py
@@ -441,6 +441,11 @@ def visit_StringToTimestamp(self, op, *, arg, format_str):
         return self.f.parse_datetime(format_str, arg)

     def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
+        if where is not None and not ignore_null:
+            raise com.UnsupportedOperationError(
+                "Combining `ignore_null=False` and `where` is not supported "
+                "by bigquery"
+            )
         out = self.agg.array_agg(arg, where=where, order_by=order_by)
         if ignore_null:
             out = sge.IgnoreNulls(this=out)
@@ -690,26 +695,40 @@ def visit_TimestampRange(self, op, *, start, stop, step):
             self.f.generate_timestamp_array, start, stop, step, op.step.dtype
         )

-    def visit_First(self, op, *, arg, where, order_by):
+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
         if where is not None:
             arg = self.if_(where, arg, NULL)
+            if not ignore_null:
+                raise com.UnsupportedOperationError(
+                    "Combining `ignore_null=False` and `where` is not supported "
+                    "by bigquery"
+                )

         if order_by:
             arg = sge.Order(this=arg, expressions=order_by)

-        array = self.f.array_agg(
-            sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)),
-        )
+        if ignore_null:
+            arg = sge.IgnoreNulls(this=arg)
+
+        array = self.f.array_agg(sge.Limit(this=arg, expression=sge.convert(1)))
         return array[self.f.safe_offset(0)]

-    def visit_Last(self, op, *, arg, where, order_by):
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
         if where is not None:
             arg = self.if_(where, arg, NULL)
+            if not ignore_null:
+                raise com.UnsupportedOperationError(
+                    "Combining `ignore_null=False` and `where` is not supported "
+                    "by bigquery"
+                )

         if order_by:
             arg = sge.Order(this=arg, expressions=order_by)

-        array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg)))
+        if ignore_null:
+            arg = sge.IgnoreNulls(this=arg)
+
+        array = self.f.array_reverse(self.f.array_agg(arg))
         return array[self.f.safe_offset(0)]

     def visit_ArrayFilter(self, op, *, arg, body, param):
diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py
index 00236f96df990..2800d05faea63 100644
--- a/ibis/backends/sql/compilers/clickhouse.py
+++ b/ibis/backends/sql/compilers/clickhouse.py
@@ -90,13 +90,11 @@ class ClickHouseCompiler(SQLGlotCompiler):
         ops.ExtractWeekOfYear: "toISOWeek",
         ops.ExtractYear: "toYear",
         ops.ExtractIsoYear: "toISOYear",
-        ops.First: "any",
         ops.IntegerRange: "range",
         ops.IsInf: "isInfinite",
         ops.IsNan: "isNaN",
         ops.IsNull: "isNull",
         ops.LStrip: "trimLeft",
-        ops.Last: "anyLast",
         ops.Ln: "log",
         ops.Log10: "log10",
         ops.MapKeys: "mapKeys",
@@ -606,10 +604,24 @@ def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
     def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
         if not ignore_null:
             raise com.UnsupportedOperationError(
-                "`ignore_null=False` is not supported by the pyspark backend"
+                "`ignore_null=False` is not supported by the clickhouse backend"
             )
         return self.agg.groupArray(arg, where=where, order_by=order_by)

+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the clickhouse backend"
+            )
+        return self.agg.any(arg, where=where, order_by=order_by)
+
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the clickhouse backend"
+            )
+        return self.agg.anyLast(arg, where=where, order_by=order_by)
+
     def visit_CountDistinctStar(
         self, op: ops.CountDistinctStar, *, where, **_: Any
     ) -> str:
diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py
index acc6dfc224314..8c0e9a29329d6 100644
--- a/ibis/backends/sql/compilers/datafusion.py
+++ b/ibis/backends/sql/compilers/datafusion.py
@@ -425,14 +425,16 @@ def visit_StringConcat(self, op, *, arg):
             sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
         )

-    def visit_First(self, op, *, arg, where, order_by):
-        cond = arg.is_(sg.not_(NULL, copy=False))
-        where = cond if where is None else sge.And(this=cond, expression=where)
+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if ignore_null:
+            cond = arg.is_(sg.not_(NULL, copy=False))
+            where = cond if where is None else sge.And(this=cond, expression=where)
         return self.agg.first_value(arg, where=where, order_by=order_by)

-    def visit_Last(self, op, *, arg, where, order_by):
-        cond = arg.is_(sg.not_(NULL, copy=False))
-        where = cond if where is None else sge.And(this=cond, expression=where)
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if ignore_null:
+            cond = arg.is_(sg.not_(NULL, copy=False))
+            where = cond if where is None else sge.And(this=cond, expression=where)
         return self.agg.last_value(arg, where=where, order_by=order_by)

     def visit_Aggregate(self, op, *, parent, groups, metrics):
diff --git a/ibis/backends/sql/compilers/druid.py b/ibis/backends/sql/compilers/druid.py
index 6479628d9acb8..29cb7dfec7ac3 100644
--- a/ibis/backends/sql/compilers/druid.py
+++ b/ibis/backends/sql/compilers/druid.py
@@ -42,11 +42,9 @@ class DruidCompiler(SQLGlotCompiler):
         ops.DateFromYMD,
         ops.DayOfWeekIndex,
         ops.DayOfWeekName,
-        ops.First,
         ops.IntervalFromInteger,
         ops.IsNan,
         ops.IsInf,
-        ops.Last,
         ops.Levenshtein,
         ops.Median,
         ops.MultiQuantile,
diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py
index 7498ab1f02772..1d0ce89b7c546 100644
--- a/ibis/backends/sql/compilers/duckdb.py
+++ b/ibis/backends/sql/compilers/duckdb.py
@@ -510,14 +510,16 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement):
             arg, pattern, replacement, "g", dialect=self.dialect
         )

-    def visit_First(self, op, *, arg, where, order_by):
-        cond = arg.is_(sg.not_(NULL, copy=False))
-        where = cond if where is None else sge.And(this=cond, expression=where)
+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if ignore_null:
+            cond = arg.is_(sg.not_(NULL, copy=False))
+            where = cond if where is None else sge.And(this=cond, expression=where)
         return self.agg.first(arg, where=where, order_by=order_by)

-    def visit_Last(self, op, *, arg, where, order_by):
-        cond = arg.is_(sg.not_(NULL, copy=False))
-        where = cond if where is None else sge.And(this=cond, expression=where)
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if ignore_null:
+            cond = arg.is_(sg.not_(NULL, copy=False))
+            where = cond if where is None else sge.And(this=cond, expression=where)
         return self.agg.last(arg, where=where, order_by=order_by)

     def visit_Quantile(self, op, *, arg, quantile, where):
diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py
index bdf690a7ff2bc..50eab6297f92b 100644
--- a/ibis/backends/sql/compilers/exasol.py
+++ b/ibis/backends/sql/compilers/exasol.py
@@ -87,8 +87,6 @@ class ExasolCompiler(SQLGlotCompiler):
         ops.Log10: "log10",
         ops.All: "min",
         ops.Any: "max",
-        ops.First: "first_value",
-        ops.Last: "last_value",
     }

     @staticmethod
@@ -136,6 +134,20 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):

         return sge.GroupConcat(this=arg, separator=sep)

+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the exasol backend"
+            )
+        return self.agg.first_value(arg, where=where, order_by=order_by)
+
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the exasol backend"
+            )
+        return self.agg.last_value(arg, where=where, order_by=order_by)
+
     def visit_StartsWith(self, op, *, arg, start):
         return self.f.left(arg, self.f.length(start)).eq(start)
diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py
index fb1fad34cfffb..e31f076c86026 100644
--- a/ibis/backends/sql/compilers/flink.py
+++ b/ibis/backends/sql/compilers/flink.py
@@ -104,8 +104,6 @@ class FlinkCompiler(SQLGlotCompiler):
         ops.ArrayRemove: "array_remove",
         ops.ArrayUnion: "array_union",
         ops.ExtractDayOfYear: "dayofyear",
-        ops.First: "first_value",
-        ops.Last: "last_value",
         ops.MapKeys: "map_keys",
         ops.MapValues: "map_values",
         ops.Power: "power",
@@ -307,6 +305,20 @@ def visit_ArraySlice(self, op, *, arg, start, stop):

         return self.f.array_slice(*args)

+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the flink backend"
+            )
+        return self.agg.first_value(arg, where=where, order_by=order_by)
+
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the flink backend"
+            )
+        return self.agg.last_value(arg, where=where, order_by=order_by)
+
     def visit_Not(self, op, *, arg):
         return sg.not_(self.cast(arg, dt.boolean))
diff --git a/ibis/backends/sql/compilers/impala.py b/ibis/backends/sql/compilers/impala.py
index bae861126d163..8a77ebb2bc94d 100644
--- a/ibis/backends/sql/compilers/impala.py
+++ b/ibis/backends/sql/compilers/impala.py
@@ -31,8 +31,6 @@ class ImpalaCompiler(SQLGlotCompiler):
         ops.Covariance,
         ops.DateDelta,
         ops.ExtractDayOfYear,
-        ops.First,
-        ops.Last,
         ops.Levenshtein,
         ops.Map,
         ops.Median,
diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py
index d7b815252e9ed..e16e969abfc07 100644
--- a/ibis/backends/sql/compilers/mssql.py
+++ b/ibis/backends/sql/compilers/mssql.py
@@ -90,14 +90,12 @@ class MSSQLCompiler(SQLGlotCompiler):
         ops.DateDiff,
         ops.DateSub,
         ops.EndsWith,
-        ops.First,
         ops.IntervalAdd,
         ops.IntervalFromInteger,
         ops.IntervalMultiply,
         ops.IntervalSubtract,
         ops.IsInf,
         ops.IsNan,
-        ops.Last,
         ops.LPad,
         ops.Levenshtein,
         ops.Map,
diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py
index 56c6f799c89aa..d5367f9495a6c 100644
--- a/ibis/backends/sql/compilers/mysql.py
+++ b/ibis/backends/sql/compilers/mysql.py
@@ -71,8 +71,6 @@ def POS_INF(self):
         ops.ArrayFlatten,
         ops.ArrayMap,
         ops.Covariance,
-        ops.First,
-        ops.Last,
         ops.Levenshtein,
         ops.Median,
         ops.Mode,
diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py
index 98bea4d2d90b4..c22a40d1dea62 100644
--- a/ibis/backends/sql/compilers/oracle.py
+++ b/ibis/backends/sql/compilers/oracle.py
@@ -55,8 +55,6 @@ class OracleCompiler(SQLGlotCompiler):
         ops.ArrayFlatten,
         ops.ArrayMap,
         ops.ArrayStringJoin,
-        ops.First,
-        ops.Last,
         ops.Mode,
         ops.MultiQuantile,
         ops.RegexSplit,
diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py
index f145ac7791111..dcdf2b5c0e239 100644
--- a/ibis/backends/sql/compilers/postgres.py
+++ b/ibis/backends/sql/compilers/postgres.py
@@ -361,6 +361,20 @@ def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
             where = cond if where is None else sge.And(this=cond, expression=where)
         return self.agg.array_agg(arg, where=where, order_by=order_by)

+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the postgres backend"
+            )
+        return self.agg.first(arg, where=where, order_by=order_by)
+
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the postgres backend"
+            )
+        return self.agg.last(arg, where=where, order_by=order_by)
+
     def visit_Log2(self, op, *, arg):
         return self.cast(
             self.f.log(
diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py
index 9a4c292d5d114..b153e0a64e4aa 100644
--- a/ibis/backends/sql/compilers/pyspark.py
+++ b/ibis/backends/sql/compilers/pyspark.py
@@ -240,11 +240,27 @@ def visit_FirstValue(self, op, *, arg):
     def visit_LastValue(self, op, *, arg):
         return sge.IgnoreNulls(this=self.f.last(arg))

-    def visit_First(self, op, *, arg, where, order_by):
-        return sge.IgnoreNulls(this=self.agg.first(arg, where=where, order_by=order_by))
+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if where is not None and not ignore_null:
+            raise com.UnsupportedOperationError(
+                "Combining `ignore_null=False` and `where` is not supported "
+                "by pyspark"
+            )
+        out = self.agg.first(arg, where=where, order_by=order_by)
+        if ignore_null:
+            out = sge.IgnoreNulls(this=out)
+        return out

-    def visit_Last(self, op, *, arg, where, order_by):
-        return sge.IgnoreNulls(this=self.agg.last(arg, where=where, order_by=order_by))
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if where is not None and not ignore_null:
+            raise com.UnsupportedOperationError(
+                "Combining `ignore_null=False` and `where` is not supported "
+                "by pyspark"
+            )
+        out = self.agg.last(arg, where=where, order_by=order_by)
+        if ignore_null:
+            out = sge.IgnoreNulls(this=out)
+        return out

     def visit_Arbitrary(self, op, *, arg, where):
         # For Spark>=3.4 we could use any_value here
diff --git a/ibis/backends/sql/compilers/risingwave.py b/ibis/backends/sql/compilers/risingwave.py
index 35f741c174997..856de11da7f2e 100644
--- a/ibis/backends/sql/compilers/risingwave.py
+++ b/ibis/backends/sql/compilers/risingwave.py
@@ -35,14 +35,22 @@ class RisingWaveCompiler(PostgresCompiler):
     def visit_DateNow(self, op):
         return self.cast(sge.CurrentTimestamp(), dt.date)

-    def visit_First(self, op, *, arg, where, order_by):
+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the risingwave backend"
+            )
         if not order_by:
             raise com.UnsupportedOperationError(
                 "RisingWave requires an `order_by` be specified in `first`"
             )
         return self.agg.first_value(arg, where=where, order_by=order_by)

-    def visit_Last(self, op, *, arg, where, order_by):
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if not ignore_null:
+            raise com.UnsupportedOperationError(
+                "`ignore_null=False` is not supported by the risingwave backend"
+            )
         if not order_by:
             raise com.UnsupportedOperationError(
                 "RisingWave requires an `order_by` be specified in `last`"
diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py
index 75444d2f588a7..7c0d29428f61d 100644
--- a/ibis/backends/sql/compilers/snowflake.py
+++ b/ibis/backends/sql/compilers/snowflake.py
@@ -455,7 +455,7 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
         timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9}
         return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short])

-    def _array_collect(self, *, arg, where, order_by, ignore_null=True):
+    def _array_collect(self, *, arg, where, order_by, ignore_null):
         if not ignore_null:
             raise com.UnsupportedOperationError(
                 "`ignore_null=False` is not supported by the snowflake backend"
@@ -476,12 +476,16 @@ def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
             arg=arg, where=where, order_by=order_by, ignore_null=ignore_null
         )

-    def visit_First(self, op, *, arg, where, order_by):
-        out = self._array_collect(arg=arg, where=where, order_by=order_by)
+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        out = self._array_collect(
+            arg=arg, where=where, order_by=order_by, ignore_null=ignore_null
+        )
         return self.f.get(out, 0)

-    def visit_Last(self, op, *, arg, where, order_by):
-        out = self._array_collect(arg=arg, where=where, order_by=order_by)
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        out = self._array_collect(
+            arg=arg, where=where, order_by=order_by, ignore_null=ignore_null
+        )
         return self.f.get(out, self.f.array_size(out) - 1)

     def visit_GroupConcat(self, op, *, arg, where, sep, order_by):
diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py
index 9f0d4bf7c1dcc..fb4edc39f8219 100644
--- a/ibis/backends/sql/compilers/sqlite.py
+++ b/ibis/backends/sql/compilers/sqlite.py
@@ -67,7 +67,7 @@ class SQLiteCompiler(SQLGlotCompiler):
     )

     SIMPLE_OPS = {
-        ops.Arbitrary: "_ibis_first",
+        ops.Arbitrary: "_ibis_first_ignore_null",
         ops.RegexReplace: "_ibis_regex_replace",
         ops.RegexExtract: "_ibis_regex_extract",
         ops.RegexSearch: "_ibis_regex_search",
@@ -90,8 +90,6 @@ class SQLiteCompiler(SQLGlotCompiler):
         ops.BitOr: "_ibis_bit_or",
         ops.BitAnd: "_ibis_bit_and",
         ops.BitXor: "_ibis_bit_xor",
-        ops.First: "_ibis_first",
-        ops.Last: "_ibis_last",
         ops.Mode: "_ibis_mode",
         ops.Time: "time",
         ops.Date: "date",
@@ -249,6 +247,14 @@ def visit_UnwrapJSONBoolean(self, op, *, arg):
             NULL,
         )

+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        func = "_ibis_first_ignore_null" if ignore_null else "_ibis_first"
+        return self.agg[func](arg, where=where, order_by=order_by)
+
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        func = "_ibis_last_ignore_null" if ignore_null else "_ibis_last"
+        return self.agg[func](arg, where=where, order_by=order_by)
+
     def visit_Variance(self, op, *, arg, how, where):
         return self.agg[f"_ibis_var_{op.how}"](arg, where=where)
diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py
index f35653c79712a..25cf429d359d2 100644
--- a/ibis/backends/sql/compilers/trino.py
+++ b/ibis/backends/sql/compilers/trino.py
@@ -373,16 +373,18 @@ def visit_StringAscii(self, op, *, arg):
     def visit_ArrayStringJoin(self, op, *, sep, arg):
         return self.f.array_join(arg, sep)

-    def visit_First(self, op, *, arg, where, order_by):
-        cond = arg.is_(sg.not_(NULL, copy=False))
-        where = cond if where is None else sge.And(this=cond, expression=where)
+    def visit_First(self, op, *, arg, where, order_by, ignore_null):
+        if ignore_null:
+            cond = arg.is_(sg.not_(NULL, copy=False))
+            where = cond if where is None else sge.And(this=cond, expression=where)
         return self.f.element_at(
             self.agg.array_agg(arg, where=where, order_by=order_by), 1
         )

-    def visit_Last(self, op, *, arg, where, order_by):
-        cond = arg.is_(sg.not_(NULL, copy=False))
-        where = cond if where is None else sge.And(this=cond, expression=where)
+    def visit_Last(self, op, *, arg, where, order_by, ignore_null):
+        if ignore_null:
+            cond = arg.is_(sg.not_(NULL, copy=False))
+            where = cond if where is None else sge.And(this=cond, expression=where)
         return self.f.element_at(
             self.agg.array_agg(arg, where=where, order_by=order_by), -1
         )
diff --git a/ibis/backends/sqlite/udf.py b/ibis/backends/sqlite/udf.py
index 49c775e9e8e7d..b2fa16fcca69a 100644
--- a/ibis/backends/sqlite/udf.py
+++ b/ibis/backends/sqlite/udf.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import abc
 import functools
 import inspect
 import math
@@ -31,6 +30,7 @@ class _UDF(NamedTuple):

 _SQLITE_UDF_REGISTRY = {}
 _SQLITE_UDAF_REGISTRY = {}
+UNSET = object()


 def ignore_nulls(f):
@@ -441,31 +441,44 @@ def __init__(self):
         super().__init__(operator.xor)


-class _ibis_first_last(abc.ABC):
-    def __init__(self) -> None:
+class _ibis_first_last:
+    def __init__(self):
         self.value = None

-    @abc.abstractmethod
-    def step(self, value): ...
-
-    def finalize(self) -> int | None:
+    def finalize(self):
         return self.value


 @udaf
-class _ibis_first(_ibis_first_last):
+class _ibis_first_ignore_null(_ibis_first_last):
     def step(self, value):
         if self.value is None:
             self.value = value


 @udaf
-class _ibis_last(_ibis_first_last):
+class _ibis_first(_ibis_first_last):
+    def __init__(self):
+        self.value = UNSET
+
+    def step(self, value):
+        if self.value is UNSET:
+            self.value = value
+
+
+@udaf
+class _ibis_last_ignore_null(_ibis_first_last):
     def step(self, value):
         if value is not None:
             self.value = value


+@udaf
+class _ibis_last(_ibis_first_last):
+    def step(self, value):
+        self.value = value
+
+
 def register_all(con):
     """Register all udf and udaf with the connection.
diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py
index 15d7cbcb6018b..6353d68a72101 100644
--- a/ibis/backends/tests/test_aggregation.py
+++ b/ibis/backends/tests/test_aggregation.py
@@ -616,13 +616,43 @@ def test_reduction_ops(
                     ["datafusion"],
                     raises=Exception,
                     reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail",
+                    strict=False,
                 )
             ],
         ),
         True,
     ],
 )
-def test_first_last(backend, alltypes, method, filtered):
+@pytest.mark.parametrize(
+    "ignore_null",
+    [
+        True,
+        param(
+            False,
+            marks=[
+                pytest.mark.notimpl(
+                    [
+                        "clickhouse",
+                        "exasol",
+                        "flink",
+                        "postgres",
+                        "risingwave",
+                        "snowflake",
+                    ],
+                    raises=com.UnsupportedOperationError,
+                    reason="`ignore_null=False` is not supported",
+                ),
+                pytest.mark.notimpl(
+                    ["bigquery", "pyspark"],
+                    raises=com.UnsupportedOperationError,
+                    reason="Can't mix `where` and `ignore_null=False`",
+                    strict=False,
+                ),
+            ],
+        ),
+    ],
+)
+def test_first_last(backend, alltypes, method, filtered, ignore_null):
     # `first` and `last` effectively choose an arbitrary value when no
     # additional order is specified. *Most* backends will result in the
     # first/last element in a column being selected (at least when operating on
@@ -640,9 +670,13 @@ def test_first_last(backend, alltypes, method, filtered):

     t = alltypes.mutate(new=new)

-    expr = getattr(t.new, method)(where=where)
+    expr = getattr(t.new, method)(where=where, ignore_null=ignore_null)
     res = expr.execute()
-    assert res == 30
+    if ignore_null:
+        assert res == 30
+    else:
+        # no ordering, so technically could be any element
+        assert res == 30 or pd.isna(res)


 @pytest.mark.notimpl(
@@ -672,23 +706,59 @@ def test_first_last(backend, alltypes, method, filtered):
                     ["datafusion"],
                     raises=Exception,
                     reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail",
+                    strict=False,
                 )
             ],
         ),
         True,
     ],
 )
-def test_first_last_ordered(backend, alltypes, method, filtered):
+@pytest.mark.parametrize(
+    "ignore_null",
+    [
+        True,
+        param(
+            False,
+            marks=[
+                pytest.mark.notimpl(
+                    [
+                        "clickhouse",
+                        "exasol",
+                        "flink",
+                        "postgres",
+                        "risingwave",
+                        "snowflake",
+                    ],
+                    raises=com.UnsupportedOperationError,
+                    reason="`ignore_null=False` is not supported",
+                ),
+                pytest.mark.notimpl(
+                    ["bigquery", "pyspark"],
+                    raises=com.UnsupportedOperationError,
+                    reason="Can't mix `where` and `ignore_null=False`",
+                    strict=False,
+                ),
+            ],
+        ),
+    ],
+)
+def test_first_last_ordered(backend, alltypes, method, filtered, ignore_null):
     t = alltypes.mutate(new=alltypes.int_col.nullif(0).nullif(9))

-    where = None
-    sol = 1 if method == "last" else 8
     if filtered:
-        where = _.int_col != sol
+        where = _.int_col != (1 if method == "last" else 8)
         sol = 2 if method == "last" else 7
+    else:
+        where = None
+        sol = 1 if method == "last" else 8

-    expr = getattr(t.new, method)(where=where, order_by=t.int_col.desc())
+    expr = getattr(t.new, method)(
+        where=where, order_by=t.int_col.desc(), ignore_null=ignore_null
+    )
     res = expr.execute()
-    assert res == sol
+    if ignore_null:
+        assert res == sol
+    else:
+        assert pd.isna(res)


 @pytest.mark.notimpl(
@@ -1409,7 +1479,13 @@ def test_collect_ordered(alltypes, df, filtered):
                     ["clickhouse", "pyspark", "snowflake"],
                     raises=com.UnsupportedOperationError,
                     reason="`ignore_null=False` is not supported",
-                )
+                ),
+                pytest.mark.notimpl(
+                    ["bigquery"],
+                    raises=com.UnsupportedOperationError,
+                    reason="Can't mix `where` and `ignore_null=False`",
+                    strict=False,
+                ),
             ],
         ),
     ],
diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py
index f5b1128a8694f..01cbee4a0c38c 100644
--- a/ibis/expr/operations/reductions.py
+++ b/ibis/expr/operations/reductions.py
@@ -81,6 +81,7 @@ class First(Filterable, Reduction):

     arg: Column[dt.Any]
     order_by: VarTuple[SortKey] = ()
+    ignore_null: bool = True

     dtype = rlz.dtype_like("arg")

@@ -91,6 +92,7 @@ class Last(Filterable, Reduction):

     arg: Column[dt.Any]
     order_by: VarTuple[SortKey] = ()
+    ignore_null: bool = True

     dtype = rlz.dtype_like("arg")

diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py
index 9cb19b73c5b19..17fab9ed8f798 100644
--- a/ibis/expr/types/generic.py
+++ b/ibis/expr/types/generic.py
@@ -2116,7 +2116,10 @@ def value_counts(self) -> ir.Table:
         return self.as_table().group_by(name).aggregate(metric)

     def first(
-        self, where: ir.BooleanValue | None = None, order_by: Any = None
+        self,
+        where: ir.BooleanValue | None = None,
+        order_by: Any = None,
+        ignore_null: bool = True,
     ) -> Value:
         """Return the first value of a column.

@@ -2129,6 +2132,9 @@ def first(
             An ordering key (or keys) to use to order the rows before
             aggregating. If not provided, the meaning of `first` is undefined
             and will be backend specific.
+        ignore_null
+            Whether to ignore null values when performing this aggregation. Set
+            to `False` to include nulls in the result.

         Examples
         --------
@@ -2159,9 +2165,15 @@ def first(
             self,
             where=self._bind_to_parent_table(where),
             order_by=self._bind_order_by(order_by),
+            ignore_null=ignore_null,
         ).to_expr()

-    def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Value:
+    def last(
+        self,
+        where: ir.BooleanValue | None = None,
+        order_by: Any = None,
+        ignore_null: bool = True,
+    ) -> Value:
         """Return the last value of a column.

         Parameters
         ----------
@@ -2173,6 +2185,9 @@ def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Va
             An ordering key (or keys) to use to order the rows before
             aggregating. If not provided, the meaning of `last` is undefined
             and will be backend specific.
+        ignore_null
+            Whether to ignore null values when performing this aggregation. Set
+            to `False` to include nulls in the result.

         Examples
         --------
@@ -2203,6 +2218,7 @@ def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Va
             self,
             where=self._bind_to_parent_table(where),
             order_by=self._bind_order_by(order_by),
+            ignore_null=ignore_null,
         ).to_expr()

     def rank(self) -> ir.IntegerColumn: