diff --git a/ibis/backends/dask/kernels.py b/ibis/backends/dask/kernels.py
index 62d7ed2f8c6ce..12a1a782ab01d 100644
--- a/ibis/backends/dask/kernels.py
+++ b/ibis/backends/dask/kernels.py
@@ -16,6 +16,14 @@
     ops.DateAdd: lambda row: row["left"] + row["right"],
 }
 
+
+def maybe_pandas_reduction(func):
+    def inner(df):
+        return df.reduction(func) if isinstance(df, dd.Series) else func(df)
+
+    return inner
+
+
 reductions = {
     **pandas_kernels.reductions,
     ops.Mode: lambda x: x.mode().loc[0],
@@ -23,10 +31,10 @@
     ops.BitAnd: lambda x: x.reduction(np.bitwise_and.reduce),
     ops.BitOr: lambda x: x.reduction(np.bitwise_or.reduce),
     ops.BitXor: lambda x: x.reduction(np.bitwise_xor.reduce),
+    ops.Arbitrary: lambda x: x.reduction(pandas_kernels.first),
     # Window functions are calculated locally using pandas
-    ops.Last: lambda x: x.compute().iloc[-1] if isinstance(x, dd.Series) else x.iat[-1],
-    ops.First: lambda x: x.loc[0] if isinstance(x, dd.Series) else x.iat[0],
-    ops.Arbitrary: lambda x: x.reduction(pandas_kernels.arbitrary),
+    ops.Last: maybe_pandas_reduction(pandas_kernels.last),
+    ops.First: maybe_pandas_reduction(pandas_kernels.first),
 }
 
 serieswise = {
diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py
index ceaebed4684d0..68e9d69c0df57 100644
--- a/ibis/backends/datafusion/compiler.py
+++ b/ibis/backends/datafusion/compiler.py
@@ -14,7 +14,10 @@
 from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
 from ibis.backends.sql.datatypes import DataFusionType
 from ibis.backends.sql.dialects import DataFusion
-from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
+from ibis.backends.sql.rewrites import (
+    exclude_nulls_from_array_collect,
+    exclude_nulls_from_first_last,
+)
 from ibis.common.temporal import IntervalUnit, TimestampUnit
 from ibis.expr.operations.udf import InputType
 from ibis.formats.pyarrow import PyArrowType
@@ -26,7 +29,11 @@ class DataFusionCompiler(SQLGlotCompiler):
 
     dialect = DataFusion
     type_mapper = DataFusionType
-    rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites)
+    rewrites = (
+        exclude_nulls_from_array_collect,
+        exclude_nulls_from_first_last,
+        *SQLGlotCompiler.rewrites,
+    )
 
     agg = AggGen(supports_filter=True)
 
diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py
index a387d4c342243..f4256db7bc59d 100644
--- a/ibis/backends/duckdb/compiler.py
+++ b/ibis/backends/duckdb/compiler.py
@@ -13,7 +13,10 @@
 import ibis.expr.operations as ops
 from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
 from ibis.backends.sql.datatypes import DuckDBType
-from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
+from ibis.backends.sql.rewrites import (
+    exclude_nulls_from_array_collect,
+    exclude_nulls_from_first_last,
+)
 
 _INTERVAL_SUFFIXES = {
     "ms": "milliseconds",
@@ -38,6 +41,7 @@ class DuckDBCompiler(SQLGlotCompiler):
 
     rewrites = (
         exclude_nulls_from_array_collect,
+        exclude_nulls_from_first_last,
         *SQLGlotCompiler.rewrites,
     )
 
diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py
index 7a837e41b9c8f..c752f9f9068d1 100644
--- a/ibis/backends/exasol/compiler.py
+++ b/ibis/backends/exasol/compiler.py
@@ -53,11 +53,9 @@ class ExasolCompiler(SQLGlotCompiler):
         ops.DateFromYMD,
         ops.DayOfWeekIndex,
         ops.ElementWiseVectorizedUDF,
-        ops.First,
         ops.IntervalFromInteger,
         ops.IsInf,
         ops.IsNan,
-        ops.Last,
         ops.Levenshtein,
         ops.Median,
         ops.MultiQuantile,
@@ -90,6 +88,8 @@ class ExasolCompiler(SQLGlotCompiler):
         ops.Log10: "log10",
         ops.All: "min",
         ops.Any: "max",
+        ops.First: "first_value",
+        ops.Last: "last_value",
     }
 
     @staticmethod
diff --git a/ibis/backends/pandas/kernels.py b/ibis/backends/pandas/kernels.py
index 90afc814078a4..da650d1211c3a 100644
--- a/ibis/backends/pandas/kernels.py
+++ b/ibis/backends/pandas/kernels.py
@@ -254,12 +254,18 @@ def round_serieswise(arg, digits):
     return np.round(arg, digits).astype("float64")
 
 
-def arbitrary(arg):
-    # Arbitrary excludes null values unless they're all null
+def first(arg):
+    # first excludes null values unless they're all null
     arg = arg.dropna()
     return arg.iat[0] if len(arg) else None
 
 
+def last(arg):
+    # last excludes null values unless they're all null
+    arg = arg.dropna()
+    return arg.iat[-1] if len(arg) else None
+
+
 reductions = {
     ops.Min: lambda x: x.min(),
     ops.Max: lambda x: x.max(),
@@ -274,9 +280,9 @@ def arbitrary(arg):
     ops.BitAnd: lambda x: np.bitwise_and.reduce(x.values),
     ops.BitOr: lambda x: np.bitwise_or.reduce(x.values),
     ops.BitXor: lambda x: np.bitwise_xor.reduce(x.values),
-    ops.Last: lambda x: x.iat[-1],
-    ops.First: lambda x: x.iat[0],
-    ops.Arbitrary: arbitrary,
+    ops.Last: last,
+    ops.First: first,
+    ops.Arbitrary: first,
     ops.CountDistinct: lambda x: x.nunique(),
     ops.ApproxCountDistinct: lambda x: x.nunique(),
     ops.ArrayCollect: lambda x: x.dropna().tolist(),
diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py
index 19307cd3a129b..0dff771d3bc80 100644
--- a/ibis/backends/sql/rewrites.py
+++ b/ibis/backends/sql/rewrites.py
@@ -382,6 +382,14 @@ def exclude_nulls_from_array_collect(_, **kwargs):
     return _.copy(where=where)
 
 
+@replace(p.First | p.Last)
+def exclude_nulls_from_first_last(_, **kwargs):
+    where = ops.NotNull(_.arg)
+    if _.where is not None:
+        where = ops.And(where, _.where)
+    return _.copy(where=where)
+
+
 # Rewrite rules for lowering a high-level operation into one composed of more
 # primitive operations.
diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py
index 4a48bdb2e77df..28dbacc647716 100644
--- a/ibis/backends/tests/test_aggregation.py
+++ b/ibis/backends/tests/test_aggregation.py
@@ -486,30 +486,6 @@ def mean_and_std(v):
                 pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError),
             ],
         ),
-        param(
-            lambda t, where: t.double_col.first(where=where),
-            lambda t, where: t.double_col[where].iloc[0],
-            id="first",
-            marks=[
-                pytest.mark.notimpl(
-                    ["druid", "impala", "mssql", "mysql", "oracle"],
-                    raises=com.OperationNotDefinedError,
-                ),
-                pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError),
-            ],
-        ),
-        param(
-            lambda t, where: t.double_col.last(where=where),
-            lambda t, where: t.double_col[where].iloc[-1],
-            id="last",
-            marks=[
-                pytest.mark.notimpl(
-                    ["druid", "impala", "mssql", "mysql", "oracle"],
-                    raises=com.OperationNotDefinedError,
-                ),
-                pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError),
-            ],
-        ),
         param(
             lambda t, where: t.bigint_col.bit_and(where=where),
             lambda t, where: np.bitwise_and.reduce(t.bigint_col[where].values),
@@ -641,6 +617,36 @@ def test_reduction_ops(
     np.testing.assert_array_equal(result, expected)
 
 
+@pytest.mark.notimpl(
+    ["druid", "impala", "mssql", "mysql", "oracle"],
+    raises=com.OperationNotDefinedError,
+)
+@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError)
+@pytest.mark.parametrize("method", ["first", "last"])
+@pytest.mark.parametrize("filtered", [False, True])
+def test_first_last(backend, alltypes, method, filtered):
+    # `first` and `last` effectively choose an arbitrary value when no
+    # additional order is specified. *Most* backends will result in the
+    # first/last element in a column being selected (at least when operating on
+    # a leaf table), but that's really not guaranteed. These operations need an
+    # order to be meaningful.
+    #
+    # To sanely test this we create a column that is a mix of nulls and a
+    # single value (or a single value after filtering is applied).
+    if filtered:
+        new = alltypes.int_col.cases([(3, 30), (4, 40)])
+        where = _.int_col == 3
+    else:
+        new = (alltypes.int_col == 3).ifelse(30, None)
+        where = None
+
+    t = alltypes.mutate(new=new)
+
+    expr = getattr(t.new, method)(where=where)
+    res = expr.execute()
+    assert res == 30
+
+
 @pytest.mark.notimpl(
     [
         "impala",
diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py
index ac6f1c69484f0..4a460d4a0227c 100644
--- a/ibis/backends/trino/compiler.py
+++ b/ibis/backends/trino/compiler.py
@@ -15,6 +15,7 @@
 from ibis.backends.sql.dialects import Trino
 from ibis.backends.sql.rewrites import (
     exclude_nulls_from_array_collect,
+    exclude_nulls_from_first_last,
     exclude_unsupported_window_frame_from_ops,
 )
 
@@ -29,6 +30,7 @@ class TrinoCompiler(SQLGlotCompiler):
 
     rewrites = (
         exclude_nulls_from_array_collect,
+        exclude_nulls_from_first_last,
         exclude_unsupported_window_frame_from_ops,
         *SQLGlotCompiler.rewrites,
     )
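
Taken together, the kernel changes and the new rewrite give first/last consistent null-excluding semantics across backends, which is what test_first_last pins down. A minimal usage sketch of that behaviour, assuming the default in-process DuckDB backend for `.execute()` and made-up data:

    import ibis

    t = ibis.memtable({"x": [None, 30.0, None, None]})
    print(t.x.first().execute())  # 30.0 -- nulls are skipped, not returned
    print(t.x.last().execute())   # 30.0
    # a user-supplied filter composes with the implicit NOT NULL predicate
    print(t.x.first(where=t.x > 10).execute())  # 30.0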