Skip to content

Commit

Permalink
fix: exclude null values from first and last aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Jun 11, 2024
1 parent 83db19d commit 1bd4fbe
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 37 deletions.
14 changes: 11 additions & 3 deletions ibis/backends/dask/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,25 @@
ops.DateAdd: lambda row: row["left"] + row["right"],
}


def maybe_pandas_reduction(func):
def inner(df):
return df.reduction(func) if isinstance(df, dd.Series) else func(df)

return inner


reductions = {
**pandas_kernels.reductions,
ops.Mode: lambda x: x.mode().loc[0],
ops.ApproxMedian: lambda x: x.median_approximate(),
ops.BitAnd: lambda x: x.reduction(np.bitwise_and.reduce),
ops.BitOr: lambda x: x.reduction(np.bitwise_or.reduce),
ops.BitXor: lambda x: x.reduction(np.bitwise_xor.reduce),
ops.Arbitrary: lambda x: x.reduction(pandas_kernels.first),
# Window functions are calculated locally using pandas
ops.Last: lambda x: x.compute().iloc[-1] if isinstance(x, dd.Series) else x.iat[-1],
ops.First: lambda x: x.loc[0] if isinstance(x, dd.Series) else x.iat[0],
ops.Arbitrary: lambda x: x.reduction(pandas_kernels.arbitrary),
ops.Last: maybe_pandas_reduction(pandas_kernels.last),
ops.First: maybe_pandas_reduction(pandas_kernels.first),
}

serieswise = {
Expand Down
11 changes: 9 additions & 2 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.backends.sql.rewrites import (
exclude_nulls_from_array_collect,
exclude_nulls_from_first_last,
)
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType
Expand All @@ -26,7 +29,11 @@ class DataFusionCompiler(SQLGlotCompiler):
dialect = DataFusion
type_mapper = DataFusionType

rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites)
rewrites = (
exclude_nulls_from_array_collect,
exclude_nulls_from_first_last,
*SQLGlotCompiler.rewrites,
)

agg = AggGen(supports_filter=True)

Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.backends.sql.rewrites import (
exclude_nulls_from_array_collect,
exclude_nulls_from_first_last,
)

_INTERVAL_SUFFIXES = {
"ms": "milliseconds",
Expand All @@ -38,6 +41,7 @@ class DuckDBCompiler(SQLGlotCompiler):

rewrites = (
exclude_nulls_from_array_collect,
exclude_nulls_from_first_last,
*SQLGlotCompiler.rewrites,
)

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,9 @@ class ExasolCompiler(SQLGlotCompiler):
ops.DateFromYMD,
ops.DayOfWeekIndex,
ops.ElementWiseVectorizedUDF,
ops.First,
ops.IntervalFromInteger,
ops.IsInf,
ops.IsNan,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
Expand Down Expand Up @@ -90,6 +88,8 @@ class ExasolCompiler(SQLGlotCompiler):
ops.Log10: "log10",
ops.All: "min",
ops.Any: "max",
ops.First: "first_value",
ops.Last: "last_value",
}

@staticmethod
Expand Down
16 changes: 11 additions & 5 deletions ibis/backends/pandas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,18 @@ def round_serieswise(arg, digits):
return np.round(arg, digits).astype("float64")


def arbitrary(arg):
# Arbitrary excludes null values unless they're all null
def first(arg):
# first excludes null values unless they're all null
arg = arg.dropna()
return arg.iat[0] if len(arg) else None


def last(arg):
# last excludes null values unless they're all null
arg = arg.dropna()
return arg.iat[-1] if len(arg) else None


reductions = {
ops.Min: lambda x: x.min(),
ops.Max: lambda x: x.max(),
Expand All @@ -274,9 +280,9 @@ def arbitrary(arg):
ops.BitAnd: lambda x: np.bitwise_and.reduce(x.values),
ops.BitOr: lambda x: np.bitwise_or.reduce(x.values),
ops.BitXor: lambda x: np.bitwise_xor.reduce(x.values),
ops.Last: lambda x: x.iat[-1],
ops.First: lambda x: x.iat[0],
ops.Arbitrary: arbitrary,
ops.Last: last,
ops.First: first,
ops.Arbitrary: first,
ops.CountDistinct: lambda x: x.nunique(),
ops.ApproxCountDistinct: lambda x: x.nunique(),
ops.ArrayCollect: lambda x: x.dropna().tolist(),
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,14 @@ def exclude_nulls_from_array_collect(_, **kwargs):
return _.copy(where=where)


@replace(p.First | p.Last)
def exclude_nulls_from_first_last(_, **kwargs):
where = ops.NotNull(_.arg)
if _.where is not None:
where = ops.And(where, _.where)
return _.copy(where=where)


# Rewrite rules for lowering a high-level operation into one composed of more
# primitive operations.

Expand Down
54 changes: 30 additions & 24 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,30 +486,6 @@ def mean_and_std(v):
pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError),
],
),
param(
lambda t, where: t.double_col.first(where=where),
lambda t, where: t.double_col[where].iloc[0],
id="first",
marks=[
pytest.mark.notimpl(
["druid", "impala", "mssql", "mysql", "oracle"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError),
],
),
param(
lambda t, where: t.double_col.last(where=where),
lambda t, where: t.double_col[where].iloc[-1],
id="last",
marks=[
pytest.mark.notimpl(
["druid", "impala", "mssql", "mysql", "oracle"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError),
],
),
param(
lambda t, where: t.bigint_col.bit_and(where=where),
lambda t, where: np.bitwise_and.reduce(t.bigint_col[where].values),
Expand Down Expand Up @@ -641,6 +617,36 @@ def test_reduction_ops(
np.testing.assert_array_equal(result, expected)


@pytest.mark.notimpl(
["druid", "impala", "mssql", "mysql", "oracle"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError)
@pytest.mark.parametrize("method", ["first", "last"])
@pytest.mark.parametrize("filtered", [False, True])
def test_first_last(backend, alltypes, method, filtered):
# `first` and `last` effectively choose an arbitrary value when no
# additional order is specified. *Most* backends will result in the
# first/last element in a column being selected (at least when operating on
# a leaf table), but that's really not guaranteed. These operations need an
# order to be meaningful.
#
# To sanely test this we create a column that is a mix of nulls and a
# single value (or a single value after filtering is applied).
if filtered:
new = alltypes.int_col.cases([(3, 30), (4, 40)])
where = _.int_col == 3
else:
new = (alltypes.int_col == 3).ifelse(30, None)
where = None

t = alltypes.mutate(new=new)

expr = getattr(t.new, method)(where=where)
res = expr.execute()
assert res == 30


@pytest.mark.notimpl(
[
"impala",
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import (
exclude_nulls_from_array_collect,
exclude_nulls_from_first_last,
exclude_unsupported_window_frame_from_ops,
)

Expand All @@ -29,6 +30,7 @@ class TrinoCompiler(SQLGlotCompiler):

rewrites = (
exclude_nulls_from_array_collect,
exclude_nulls_from_first_last,
exclude_unsupported_window_frame_from_ops,
*SQLGlotCompiler.rewrites,
)
Expand Down

0 comments on commit 1bd4fbe

Please sign in to comment.