From 8d10a3e283297596d315fd88b2df75d824a21538 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Tue, 6 Aug 2024 11:42:10 -0500 Subject: [PATCH] feat(api): support `ignore_null` in `first`/`last` --- ibis/backends/sql/compilers/base.py | 2 - .../sql/compilers/bigquery/__init__.py | 16 ++-- ibis/backends/sql/compilers/clickhouse.py | 18 ++++- ibis/backends/sql/compilers/datafusion.py | 14 ++-- ibis/backends/sql/compilers/druid.py | 2 - ibis/backends/sql/compilers/duckdb.py | 14 ++-- ibis/backends/sql/compilers/exasol.py | 16 +++- ibis/backends/sql/compilers/flink.py | 16 +++- ibis/backends/sql/compilers/impala.py | 2 - ibis/backends/sql/compilers/mssql.py | 2 - ibis/backends/sql/compilers/mysql.py | 2 - ibis/backends/sql/compilers/oracle.py | 2 - ibis/backends/sql/compilers/postgres.py | 14 ++++ ibis/backends/sql/compilers/pyspark.py | 16 ++-- ibis/backends/sql/compilers/risingwave.py | 12 ++- ibis/backends/sql/compilers/snowflake.py | 14 ++-- ibis/backends/sql/compilers/sqlite.py | 10 ++- ibis/backends/sql/compilers/trino.py | 14 ++-- ibis/backends/sqlite/udf.py | 31 +++++--- ibis/backends/tests/test_aggregation.py | 76 ++++++++++++++++--- ibis/expr/operations/reductions.py | 2 + ibis/expr/types/generic.py | 20 ++++- 22 files changed, 238 insertions(+), 77 deletions(-) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 418268b8b98c1..74a39d3f9db19 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -330,7 +330,6 @@ class SQLGlotCompiler(abc.ABC): ops.Degrees: "degrees", ops.DenseRank: "dense_rank", ops.Exp: "exp", - ops.First: "first", FirstValue: "first_value", ops.GroupConcat: "group_concat", ops.IfElse: "if", @@ -338,7 +337,6 @@ class SQLGlotCompiler(abc.ABC): ops.IsNan: "isnan", ops.JSONGetItem: "json_extract", ops.LPad: "lpad", - ops.Last: "last", LastValue: "last_value", ops.Levenshtein: "levenshtein", ops.Ln: "ln", diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index b5818e35468c7..8c19940e8e5a3 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -690,26 +690,30 @@ def visit_TimestampRange(self, op, *, start, stop, step): self.f.generate_timestamp_array, start, stop, step, op.step.dtype ) - def visit_First(self, op, *, arg, where, order_by): + def visit_First(self, op, *, arg, where, order_by, ignore_null): if where is not None: arg = self.if_(where, arg, NULL) if order_by: arg = sge.Order(this=arg, expressions=order_by) - array = self.f.array_agg( - sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)), - ) + if ignore_null: + arg = sge.IgnoreNulls(this=arg) + + array = self.f.array_agg(sge.Limit(this=arg, expression=sge.convert(1))) return array[self.f.safe_offset(0)] - def visit_Last(self, op, *, arg, where, order_by): + def visit_Last(self, op, *, arg, where, order_by, ignore_null): if where is not None: arg = self.if_(where, arg, NULL) if order_by: arg = sge.Order(this=arg, expressions=order_by) - array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg))) + if ignore_null: + arg = sge.IgnoreNulls(this=arg) + + array = self.f.array_reverse(self.f.array_agg(arg)) return array[self.f.safe_offset(0)] def visit_ArrayFilter(self, op, *, arg, body, param): diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 00236f96df990..2800d05faea63 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -90,13 +90,11 @@ class ClickHouseCompiler(SQLGlotCompiler): ops.ExtractWeekOfYear: "toISOWeek", ops.ExtractYear: "toYear", ops.ExtractIsoYear: "toISOYear", - ops.First: "any", ops.IntegerRange: "range", ops.IsInf: "isInfinite", ops.IsNan: "isNaN", ops.IsNull: "isNull", ops.LStrip: "trimLeft", - ops.Last: "anyLast", ops.Ln: "log", ops.Log10: "log10", ops.MapKeys: "mapKeys", @@ -606,10 +604,24 @@ def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str: def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): if not ignore_null: raise com.UnsupportedOperationError( - "`ignore_null=False` is not supported by the pyspark backend" + "`ignore_null=False` is not supported by the clickhouse backend" ) return self.agg.groupArray(arg, where=where, order_by=order_by) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the clickhouse backend" + ) + return self.agg.any(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the clickhouse backend" + ) + return self.agg.anyLast(arg, where=where, order_by=order_by) + def visit_CountDistinctStar( self, op: ops.CountDistinctStar, *, where, **_: Any ) -> str: diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index acc6dfc224314..8c0e9a29329d6 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -425,14 +425,16 @@ def visit_StringConcat(self, op, *, arg): sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg) ) - def visit_First(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.first_value(arg, where=where, order_by=order_by) - def visit_Last(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last_value(arg, where=where, order_by=order_by) def visit_Aggregate(self, op, *, parent, groups, metrics): diff --git a/ibis/backends/sql/compilers/druid.py b/ibis/backends/sql/compilers/druid.py index 6479628d9acb8..29cb7dfec7ac3 100644 --- a/ibis/backends/sql/compilers/druid.py +++ b/ibis/backends/sql/compilers/druid.py @@ -42,11 +42,9 @@ class DruidCompiler(SQLGlotCompiler): ops.DateFromYMD, ops.DayOfWeekIndex, ops.DayOfWeekName, - ops.First, ops.IntervalFromInteger, ops.IsNan, ops.IsInf, - ops.Last, ops.Levenshtein, ops.Median, ops.MultiQuantile, diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index 7498ab1f02772..1d0ce89b7c546 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -510,14 +510,16 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement): arg, pattern, replacement, "g", dialect=self.dialect ) - def visit_First(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.first(arg, where=where, order_by=order_by) - def visit_Last(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last(arg, where=where, order_by=order_by) def visit_Quantile(self, op, *, arg, quantile, where): diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py index bdf690a7ff2bc..50eab6297f92b 100644 --- a/ibis/backends/sql/compilers/exasol.py +++ b/ibis/backends/sql/compilers/exasol.py @@ -87,8 +87,6 @@ class ExasolCompiler(SQLGlotCompiler): ops.Log10: "log10", ops.All: "min", ops.Any: "max", - ops.First: "first_value", - ops.Last: "last_value", } @staticmethod @@ -136,6 +134,20 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by): return sge.GroupConcat(this=arg, separator=sep) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the exasol backend" + ) + return self.agg.first_value(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the exasol backend" + ) + return self.agg.last_value(arg, where=where, order_by=order_by) + def visit_StartsWith(self, op, *, arg, start): return self.f.left(arg, self.f.length(start)).eq(start) diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py index fb1fad34cfffb..e31f076c86026 100644 --- a/ibis/backends/sql/compilers/flink.py +++ b/ibis/backends/sql/compilers/flink.py @@ -104,8 +104,6 @@ class FlinkCompiler(SQLGlotCompiler): ops.ArrayRemove: "array_remove", ops.ArrayUnion: "array_union", ops.ExtractDayOfYear: "dayofyear", - ops.First: "first_value", - ops.Last: "last_value", ops.MapKeys: "map_keys", ops.MapValues: "map_values", ops.Power: "power", @@ -307,6 +305,20 @@ def visit_ArraySlice(self, op, *, arg, start, stop): return self.f.array_slice(*args) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the flink backend" + ) + return self.agg.first_value(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the flink backend" + ) + return self.agg.last_value(arg, where=where, order_by=order_by) + def visit_Not(self, op, *, arg): return sg.not_(self.cast(arg, dt.boolean)) diff --git a/ibis/backends/sql/compilers/impala.py b/ibis/backends/sql/compilers/impala.py index bae861126d163..8a77ebb2bc94d 100644 --- a/ibis/backends/sql/compilers/impala.py +++ b/ibis/backends/sql/compilers/impala.py @@ -31,8 +31,6 @@ class ImpalaCompiler(SQLGlotCompiler): ops.Covariance, ops.DateDelta, ops.ExtractDayOfYear, - ops.First, - ops.Last, ops.Levenshtein, ops.Map, ops.Median, diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index d7b815252e9ed..e16e969abfc07 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -90,14 +90,12 @@ class MSSQLCompiler(SQLGlotCompiler): ops.DateDiff, ops.DateSub, ops.EndsWith, - ops.First, ops.IntervalAdd, ops.IntervalFromInteger, ops.IntervalMultiply, ops.IntervalSubtract, ops.IsInf, ops.IsNan, - ops.Last, ops.LPad, ops.Levenshtein, ops.Map, diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py index 56c6f799c89aa..d5367f9495a6c 100644 --- a/ibis/backends/sql/compilers/mysql.py +++ b/ibis/backends/sql/compilers/mysql.py @@ -71,8 +71,6 @@ def POS_INF(self): ops.ArrayFlatten, ops.ArrayMap, ops.Covariance, - ops.First, - ops.Last, ops.Levenshtein, ops.Median, ops.Mode, diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index 98bea4d2d90b4..c22a40d1dea62 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -55,8 +55,6 @@ class OracleCompiler(SQLGlotCompiler): ops.ArrayFlatten, ops.ArrayMap, ops.ArrayStringJoin, - ops.First, - ops.Last, ops.Mode, ops.MultiQuantile, ops.RegexSplit, diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index f145ac7791111..dcdf2b5c0e239 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -361,6 +361,20 @@ def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the postgres backend" + ) + return self.agg.first(arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the postgres backend" + ) + return self.agg.last(arg, where=where, order_by=order_by) + def visit_Log2(self, op, *, arg): return self.cast( self.f.log( diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index 9a4c292d5d114..f7900a4f3add6 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -240,11 +240,17 @@ def visit_FirstValue(self, op, *, arg): def visit_LastValue(self, op, *, arg): return sge.IgnoreNulls(this=self.f.last(arg)) - def visit_First(self, op, *, arg, where, order_by): - return sge.IgnoreNulls(this=self.agg.first(arg, where=where, order_by=order_by)) - - def visit_Last(self, op, *, arg, where, order_by): - return sge.IgnoreNulls(this=self.agg.last(arg, where=where, order_by=order_by)) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + out = self.agg.first(arg, where=where, order_by=order_by) + if ignore_null: + out = sge.IgnoreNulls(this=out) + return out + + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + out = self.agg.last(arg, where=where, order_by=order_by) + if ignore_null: + out = sge.IgnoreNulls(this=out) + return out def visit_Arbitrary(self, op, *, arg, where): # For Spark>=3.4 we could use any_value here diff --git a/ibis/backends/sql/compilers/risingwave.py b/ibis/backends/sql/compilers/risingwave.py index 35f741c174997..856de11da7f2e 100644 --- a/ibis/backends/sql/compilers/risingwave.py +++ b/ibis/backends/sql/compilers/risingwave.py @@ -35,14 +35,22 @@ class RisingWaveCompiler(PostgresCompiler): def visit_DateNow(self, op): return self.cast(sge.CurrentTimestamp(), dt.date) - def visit_First(self, op, *, arg, where, order_by): + def visit_First(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the risingwave backend" + ) if not order_by: raise com.UnsupportedOperationError( "RisingWave requires an `order_by` be specified in `first`" ) return self.agg.first_value(arg, where=where, order_by=order_by) - def visit_Last(self, op, *, arg, where, order_by): + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the risingwave backend" + ) if not order_by: raise com.UnsupportedOperationError( "RisingWave requires an `order_by` be specified in `last`" diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index 75444d2f588a7..7c0d29428f61d 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -455,7 +455,7 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short]) - def _array_collect(self, *, arg, where, order_by, ignore_null=True): + def _array_collect(self, *, arg, where, order_by, ignore_null): if not ignore_null: raise com.UnsupportedOperationError( "`ignore_null=False` is not supported by the snowflake backend" @@ -476,12 +476,16 @@ def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): arg=arg, where=where, order_by=order_by, ignore_null=ignore_null ) - def visit_First(self, op, *, arg, where, order_by): - out = self._array_collect(arg=arg, where=where, order_by=order_by) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + out = self._array_collect( + arg=arg, where=where, order_by=order_by, ignore_null=ignore_null + ) return self.f.get(out, 0) - def visit_Last(self, op, *, arg, where, order_by): - out = self._array_collect(arg=arg, where=where, order_by=order_by) + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + out = self._array_collect( + arg=arg, where=where, order_by=order_by, ignore_null=ignore_null + ) return self.f.get(out, self.f.array_size(out) - 1) def visit_GroupConcat(self, op, *, arg, where, sep, order_by): diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index 9f0d4bf7c1dcc..99d729fab6a1a 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -90,8 +90,6 @@ class SQLiteCompiler(SQLGlotCompiler): ops.BitOr: "_ibis_bit_or", ops.BitAnd: "_ibis_bit_and", ops.BitXor: "_ibis_bit_xor", - ops.First: "_ibis_first", - ops.Last: "_ibis_last", ops.Mode: "_ibis_mode", ops.Time: "time", ops.Date: "date", @@ -249,6 +247,14 @@ def visit_UnwrapJSONBoolean(self, op, *, arg): NULL, ) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + func = "_ibis_first_ignore_null" if ignore_null else "_ibis_first" + return self.agg[func](arg, where=where, order_by=order_by) + + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + func = "_ibis_last_ignore_null" if ignore_null else "_ibis_last" + return self.agg[func](arg, where=where, order_by=order_by) + def visit_Variance(self, op, *, arg, how, where): return self.agg[f"_ibis_var_{op.how}"](arg, where=where) diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index f35653c79712a..25cf429d359d2 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -373,16 +373,18 @@ def visit_StringAscii(self, op, *, arg): def visit_ArrayStringJoin(self, op, *, sep, arg): return self.f.array_join(arg, sep) - def visit_First(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_First(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.f.element_at( self.agg.array_agg(arg, where=where, order_by=order_by), 1 ) - def visit_Last(self, op, *, arg, where, order_by): - cond = arg.is_(sg.not_(NULL, copy=False)) - where = cond if where is None else sge.And(this=cond, expression=where) + def visit_Last(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.f.element_at( self.agg.array_agg(arg, where=where, order_by=order_by), -1 ) diff --git a/ibis/backends/sqlite/udf.py b/ibis/backends/sqlite/udf.py index 49c775e9e8e7d..b2fa16fcca69a 100644 --- a/ibis/backends/sqlite/udf.py +++ b/ibis/backends/sqlite/udf.py @@ -1,6 +1,5 @@ from __future__ import annotations -import abc import functools import inspect import math @@ -31,6 +30,7 @@ class _UDF(NamedTuple): _SQLITE_UDF_REGISTRY = {} _SQLITE_UDAF_REGISTRY = {} +UNSET = object() def ignore_nulls(f): @@ -441,31 +441,44 @@ def __init__(self): super().__init__(operator.xor) -class _ibis_first_last(abc.ABC): - def __init__(self) -> None: +class _ibis_first_last: + def __init__(self): self.value = None - @abc.abstractmethod - def step(self, value): ... - - def finalize(self) -> int | None: + def finalize(self): return self.value @udaf -class _ibis_first(_ibis_first_last): +class _ibis_first_ignore_null(_ibis_first_last): def step(self, value): if self.value is None: self.value = value @udaf -class _ibis_last(_ibis_first_last): +class _ibis_first(_ibis_first_last): + def __init__(self): + self.value = UNSET + + def step(self, value): + if self.value is UNSET: + self.value = value + + +@udaf +class _ibis_last_ignore_null(_ibis_first_last): def step(self, value): if value is not None: self.value = value +@udaf +class _ibis_last(_ibis_first_last): + def step(self, value): + self.value = value + + def register_all(con): """Register all udf and udaf with the connection. diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index d89c28d4c3dad..94c106673ba1e 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -616,13 +616,37 @@ def test_reduction_ops( ["datafusion"], raises=Exception, reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + strict=False, ) ], ), True, ], ) -def test_first_last(backend, alltypes, method, filtered): +@pytest.mark.parametrize( + "ignore_null", + [ + True, + param( + False, + marks=[ + pytest.mark.notimpl( + [ + "clickhouse", + "exasol", + "flink", + "postgres", + "risingwave", + "snowflake", + ], + raises=com.UnsupportedOperationError, + reason="`ignore_null=False` is not supported", + ) + ], + ), + ], +) +def test_first_last(backend, alltypes, method, filtered, ignore_null): # `first` and `last` effectively choose an arbitrary value when no # additional order is specified. *Most* backends will result in the # first/last element in a column being selected (at least when operating on @@ -640,9 +664,13 @@ def test_first_last(backend, alltypes, method, filtered): t = alltypes.mutate(new=new) - expr = getattr(t.new, method)(where=where) + expr = getattr(t.new, method)(where=where, ignore_null=ignore_null) res = expr.execute() - assert res == 30 + if ignore_null: + assert res == 30 + else: + # no ordering, so technically could be any element + assert res == 30 or pd.isna(res) @pytest.mark.notimpl( @@ -672,23 +700,53 @@ def test_first_last(backend, alltypes, method, filtered): ["datafusion"], raises=Exception, reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + strict=False, ) ], ), True, ], ) -def test_first_last_ordered(backend, alltypes, method, filtered): +@pytest.mark.parametrize( + "ignore_null", + [ + True, + param( + False, + marks=[ + pytest.mark.notimpl( + [ + "clickhouse", + "exasol", + "flink", + "postgres", + "risingwave", + "snowflake", + ], + raises=com.UnsupportedOperationError, + reason="`ignore_null=False` is not supported", + ) + ], + ), + ], +) +def test_first_last_ordered(backend, alltypes, method, filtered, ignore_null): t = alltypes.mutate(new=alltypes.int_col.nullif(0).nullif(9)) - where = None - sol = 1 if method == "last" else 8 if filtered: - where = _.int_col != sol + where = _.int_col != (1 if method == "last" else 8) sol = 2 if method == "last" else 7 + else: + where = None + sol = 1 if method == "last" else 8 - expr = getattr(t.new, method)(where=where, order_by=t.int_col.desc()) + expr = getattr(t.new, method)( + where=where, order_by=t.int_col.desc(), ignore_null=ignore_null + ) res = expr.execute() - assert res == sol + if ignore_null: + assert res == sol + else: + assert pd.isna(res) @pytest.mark.notimpl( diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index f5b1128a8694f..01cbee4a0c38c 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -81,6 +81,7 @@ class First(Filterable, Reduction): arg: Column[dt.Any] order_by: VarTuple[SortKey] = () + ignore_null: bool = True dtype = rlz.dtype_like("arg") @@ -91,6 +92,7 @@ class Last(Filterable, Reduction): arg: Column[dt.Any] order_by: VarTuple[SortKey] = () + ignore_null: bool = True dtype = rlz.dtype_like("arg") diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 9cb19b73c5b19..17fab9ed8f798 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -2116,7 +2116,10 @@ def value_counts(self) -> ir.Table: return self.as_table().group_by(name).aggregate(metric) def first( - self, where: ir.BooleanValue | None = None, order_by: Any = None + self, + where: ir.BooleanValue | None = None, + order_by: Any = None, + ignore_null: bool = True, ) -> Value: """Return the first value of a column. @@ -2129,6 +2132,9 @@ def first( An ordering key (or keys) to use to order the rows before aggregating. If not provided, the meaning of `first` is undefined and will be backend specific. + ignore_null + Whether to ignore null values when performing this aggregation. Set + to `False` to include nulls in the result. Examples -------- @@ -2159,9 +2165,15 @@ def first( self, where=self._bind_to_parent_table(where), order_by=self._bind_order_by(order_by), + ignore_null=ignore_null, ).to_expr() - def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Value: + def last( + self, + where: ir.BooleanValue | None = None, + order_by: Any = None, + ignore_null: bool = True, + ) -> Value: """Return the last value of a column. Parameters @@ -2173,6 +2185,9 @@ def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Va An ordering key (or keys) to use to order the rows before aggregating. If not provided, the meaning of `last` is undefined and will be backend specific. + ignore_null + Whether to ignore null values when performing this aggregation. Set + to `False` to include nulls in the result. Examples -------- @@ -2203,6 +2218,7 @@ def last(self, where: ir.BooleanValue | None = None, order_by: Any = None) -> Va self, where=self._bind_to_parent_table(where), order_by=self._bind_order_by(order_by), + ignore_null=ignore_null, ).to_expr() def rank(self) -> ir.IntegerColumn: