diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index f3c1cdaf4d126..5e7a794537011 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -319,6 +319,17 @@ def visit(cls, op: ops.StandardDev, arg, where, how): ddof = {"pop": 0, "sample": 1}[how] return cls.agg(lambda x: x.std(ddof=ddof), arg, where) + @classmethod + def visit(cls, op: ops.ArrayCollect, arg, where, order_by, ignore_null): + if order_by: + raise UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + "not supported for this backend" + ) + return cls.agg( + (lambda x: x.dropna().tolist() if ignore_null else x.tolist()), arg, where + ) + @classmethod def visit(cls, op: ops.Correlation, left, right, where, how): if where is None: diff --git a/ibis/backends/pandas/kernels.py b/ibis/backends/pandas/kernels.py index 61dc278c5562f..c6feb6c3e8862 100644 --- a/ibis/backends/pandas/kernels.py +++ b/ibis/backends/pandas/kernels.py @@ -291,7 +291,6 @@ def last(arg): ops.Arbitrary: first, ops.CountDistinct: lambda x: x.nunique(), ops.ApproxCountDistinct: lambda x: x.nunique(), - ops.ArrayCollect: lambda x: x.dropna().tolist(), } diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 6a9fe4a3baf9a..d0ac2163f7963 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -991,7 +991,7 @@ def array_column(op, **kw): def array_collect(op, in_group_by=False, **kw): arg = translate(op.arg, **kw) - predicate = arg.is_not_null() + predicate = arg.is_not_null() if op.ignore_null else True if op.where is not None: predicate &= translate(op.where, **kw) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 6e201b52fa719..418268b8b98c1 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -313,7 +313,6 @@ class SQLGlotCompiler(abc.ABC): ops.ApproxCountDistinct: "approx_distinct", ops.ArgMax: "max_by", ops.ArgMin: "min_by", - ops.ArrayCollect: "array_agg", ops.ArrayContains: "array_contains", ops.ArrayFlatten: "flatten", ops.ArrayLength: "array_size", diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index e0c225d453232..b5818e35468c7 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -440,10 +440,11 @@ def visit_StringToTimestamp(self, op, *, arg, format_str): return self.f.parse_timestamp(format_str, arg, timezone) return self.f.parse_datetime(format_str, arg) - def visit_ArrayCollect(self, op, *, arg, where, order_by): - return sge.IgnoreNulls( - this=self.agg.array_agg(arg, where=where, order_by=order_by) - ) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + out = self.agg.array_agg(arg, where=where, order_by=order_by) + if ignore_null: + out = sge.IgnoreNulls(this=out) + return out def _neg_idx_to_pos(self, arg, idx): return self.if_(idx < 0, self.f.array_length(arg) + idx, idx) diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index ec94f7f4ebc0a..00236f96df990 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -61,7 +61,6 @@ class ClickHouseCompiler(SQLGlotCompiler): ops.Arbitrary: "any", ops.ArgMax: "argMax", ops.ArgMin: "argMin", - ops.ArrayCollect: "groupArray", ops.ArrayContains: "has", ops.ArrayFlatten: "arrayFlatten", ops.ArrayIntersect: "arrayIntersect", @@ -604,6 +603,13 @@ def visit_ArrayUnion(self, op, *, left, right): def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str: return self.f.arrayZip(*arg) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the pyspark backend" + ) + return self.agg.groupArray(arg, where=where, order_by=order_by) + def visit_CountDistinctStar( self, op: ops.CountDistinctStar, *, where, **_: Any ) -> str: diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index e3c4ba65478d4..acc6dfc224314 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -14,7 +14,6 @@ from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DataFusionType from ibis.backends.sql.dialects import DataFusion -from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect from ibis.common.temporal import IntervalUnit, TimestampUnit from ibis.expr.operations.udf import InputType @@ -25,11 +24,6 @@ class DataFusionCompiler(SQLGlotCompiler): dialect = DataFusion type_mapper = DataFusionType - rewrites = ( - exclude_nulls_from_array_collect, - *SQLGlotCompiler.rewrites, - ) - agg = AggGen(supports_filter=True, supports_order_by=True) UNSUPPORTED_OPS = ( @@ -331,6 +325,12 @@ def visit_ArrayRepeat(self, op, *, arg, times): def visit_ArrayPosition(self, op, *, arg, other): return self.f.coalesce(self.f.array_position(arg, other), 0) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_Covariance(self, op, *, left, right, how, where): x = op.left if x.dtype.is_boolean(): diff --git a/ibis/backends/sql/compilers/druid.py b/ibis/backends/sql/compilers/druid.py index d1571363e14e6..6479628d9acb8 100644 --- a/ibis/backends/sql/compilers/druid.py +++ b/ibis/backends/sql/compilers/druid.py @@ -27,7 +27,6 @@ class DruidCompiler(SQLGlotCompiler): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index 1724885ec0bbd..7498ab1f02772 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -14,7 +14,6 @@ from ibis import util from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DuckDBType -from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect from ibis.util import gen_name if TYPE_CHECKING: @@ -43,11 +42,6 @@ class DuckDBCompiler(SQLGlotCompiler): agg = AggGen(supports_filter=True, supports_order_by=True) - rewrites = ( - exclude_nulls_from_array_collect, - *SQLGlotCompiler.rewrites, - ) - LOWERED_OPS = { ops.Sample: None, ops.StringSlice: None, @@ -154,6 +148,12 @@ def visit_ArrayDistinct(self, op, *, arg): ), ) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_ArrayIndex(self, op, *, arg, index): return self.f.list_extract(arg, index + self.cast(index >= 0, op.index.dtype)) diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py index 87f5aaa543d33..bdf690a7ff2bc 100644 --- a/ibis/backends/sql/compilers/exasol.py +++ b/ibis/backends/sql/compilers/exasol.py @@ -35,7 +35,6 @@ class ExasolCompiler(SQLGlotCompiler): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py index 4e4fb15864152..fb1fad34cfffb 100644 --- a/ibis/backends/sql/compilers/flink.py +++ b/ibis/backends/sql/compilers/flink.py @@ -71,7 +71,6 @@ class FlinkCompiler(SQLGlotCompiler): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.ArrayFlatten, ops.ArraySort, ops.ArrayStringJoin, diff --git a/ibis/backends/sql/compilers/impala.py b/ibis/backends/sql/compilers/impala.py index f73a38751d08e..bae861126d163 100644 --- a/ibis/backends/sql/compilers/impala.py +++ b/ibis/backends/sql/compilers/impala.py @@ -26,7 +26,6 @@ class ImpalaCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.ArrayPosition, ops.Array, ops.Covariance, diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index 0c6fe9a567ac5..d7b815252e9ed 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -75,7 +75,6 @@ class MSSQLCompiler(SQLGlotCompiler): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.Array, ops.ArrayDistinct, ops.ArrayFlatten, diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py index 5244e2642b528..56c6f799c89aa 100644 --- a/ibis/backends/sql/compilers/mysql.py +++ b/ibis/backends/sql/compilers/mysql.py @@ -67,7 +67,6 @@ def POS_INF(self): ops.ApproxMedian, ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.Array, ops.ArrayFlatten, ops.ArrayMap, diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index ee98dda5e8420..98bea4d2d90b4 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -51,7 +51,6 @@ class OracleCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.ArgMax, ops.ArgMin, - ops.ArrayCollect, ops.Array, ops.ArrayFlatten, ops.ArrayMap, diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index 9d47044aa835c..f145ac7791111 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -17,7 +17,6 @@ from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import PostgresType from ibis.backends.sql.dialects import Postgres -from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect from ibis.common.exceptions import InvalidDecoratorError from ibis.util import gen_name @@ -43,8 +42,6 @@ class PostgresCompiler(SQLGlotCompiler): dialect = Postgres type_mapper = PostgresType - rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites) - agg = AggGen(supports_filter=True, supports_order_by=True) NAN = sge.Literal.number("'NaN'::double precision") @@ -358,6 +355,12 @@ def visit_ArrayIntersect(self, op, *, left, right): ) ) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_Log2(self, op, *, arg): return self.cast( self.f.log( diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index ee6060e4bfccb..9a4c292d5d114 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -397,6 +397,13 @@ def visit_ArrayContains(self, op, *, arg, other): def visit_ArrayStringJoin(self, op, *, arg, sep): return self.f.concat_ws(sep, arg) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the pyspark backend" + ) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise com.UnsupportedOperationError( diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index c6ab7451ddbd6..75444d2f588a7 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -455,7 +455,12 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short]) - def _array_collect(self, *, arg, where, order_by): + def _array_collect(self, *, arg, where, order_by, ignore_null=True): + if not ignore_null: + raise com.UnsupportedOperationError( + "`ignore_null=False` is not supported by the snowflake backend" + ) + if where is not None: arg = self.if_(where, arg, NULL) @@ -466,8 +471,10 @@ def _array_collect(self, *, arg, where, order_by): return out - def visit_ArrayCollect(self, op, *, arg, where, order_by): - return self._array_collect(arg=arg, where=where, order_by=order_by) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + return self._array_collect( + arg=arg, where=where, order_by=order_by, ignore_null=ignore_null + ) def visit_First(self, op, *, arg, where, order_by): out = self._array_collect(arg=arg, where=where, order_by=order_by) diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index 88f86c5bb979d..9f0d4bf7c1dcc 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -42,7 +42,6 @@ class SQLiteCompiler(SQLGlotCompiler): ops.Array, ops.ArrayConcat, ops.ArrayStringJoin, - ops.ArrayCollect, ops.ArrayContains, ops.ArrayFlatten, ops.ArrayLength, diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 90ee114c893c9..f35653c79712a 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -22,7 +22,6 @@ from ibis.backends.sql.datatypes import TrinoType from ibis.backends.sql.dialects import Trino from ibis.backends.sql.rewrites import ( - exclude_nulls_from_array_collect, exclude_unsupported_window_frame_from_ops, ) from ibis.util import gen_name @@ -37,7 +36,6 @@ class TrinoCompiler(SQLGlotCompiler): agg = AggGen(supports_filter=True, supports_order_by=True) rewrites = ( - exclude_nulls_from_array_collect, exclude_unsupported_window_frame_from_ops, *SQLGlotCompiler.rewrites, ) @@ -178,6 +176,12 @@ def visit_ArrayContains(self, op, *, arg, other): NULL, ) + def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null): + if ignore_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) + return self.agg.array_agg(arg, where=where, order_by=order_by) + def visit_JSONGetItem(self, op, *, arg, index): fmt = "%d" if op.index.dtype.is_integer() else '"%s"' return self.f.json_extract(arg, self.f.format(f"$[{fmt}]", index)) diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py index 5110523093054..7e01999144ecb 100644 --- a/ibis/backends/sql/rewrites.py +++ b/ibis/backends/sql/rewrites.py @@ -386,14 +386,6 @@ def exclude_unsupported_window_frame_from_ops(_, **kwargs): return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)) -@replace(p.ArrayCollect) -def exclude_nulls_from_array_collect(_, **kwargs): - where = ops.NotNull(_.arg) - if _.where is not None: - where = ops.And(where, _.where) - return _.copy(where=where) - - # Rewrite rules for lowering a high-level operation into one composed of more # primitive operations. diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 992251c6b90fa..15d7cbcb6018b 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -529,28 +529,6 @@ def mean_and_std(v): lambda t, where: len(t[where]), id="count_star", ), - param( - lambda t, where: t.string_col.nullif("3").collect(where=where), - lambda t, where: t.string_col[t.string_col != "3"][where].tolist(), - id="collect", - marks=[ - pytest.mark.notimpl( - ["impala", "mysql", "sqlite", "mssql", "druid", "oracle", "exasol"], - raises=com.OperationNotDefinedError, - ), - pytest.mark.notimpl( - ["dask"], - raises=(AttributeError, TypeError), - reason=( - "For 'is_in' case: 'Series' object has no attribute 'arraycollect'" - "For 'no_cond' case: TypeError: Object " - " is not " - "callable or a string" - ), - ), - pytest.mark.notyet(["flink"], raises=com.OperationNotDefinedError), - ], - ), ], ) @pytest.mark.parametrize( @@ -1397,6 +1375,59 @@ def test_collect_ordered(alltypes, df, filtered): assert result == expected +@pytest.mark.notimpl( + ["druid", "exasol", "flink", "impala", "mssql", "mysql", "oracle", "sqlite"], + raises=com.OperationNotDefinedError, +) +@pytest.mark.notimpl( + ["dask"], raises=AttributeError, reason="Dask doesn't implement tolist()" +) +@pytest.mark.parametrize( + "filtered", + [ + param( + True, + marks=[ + pytest.mark.notyet( + ["datafusion"], + raises=Exception, + reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail", + ) + ], + ), + False, + ], +) +@pytest.mark.parametrize( + "ignore_null", + [ + True, + param( + False, + marks=[ + pytest.mark.notimpl( + ["clickhouse", "pyspark", "snowflake"], + raises=com.UnsupportedOperationError, + reason="`ignore_null=False` is not supported", + ) + ], + ), + ], +) +def test_collect(alltypes, df, filtered, ignore_null): + ibis_cond = (_.id % 13 == 0) if filtered else None + pd_cond = (df.id % 13 == 0) if filtered else slice(None) + res = ( + alltypes.string_col.nullif("3") + .collect(where=ibis_cond, ignore_null=ignore_null) + .length() + .execute() + ) + vals = df.string_col[(df.string_col != "3")] if ignore_null else df.string_col + sol = len(vals[pd_cond]) + assert res == sol + + @pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError) def test_topk_op(alltypes, df): # TopK expression will order rows by "count" but each backend diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 0e093f536b458..f5b1128a8694f 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -368,6 +368,7 @@ class ArrayCollect(Filterable, Reduction): arg: Column order_by: VarTuple[SortKey] = () + ignore_null: bool = True @attribute def dtype(self): diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 473b74fc5d4c9..9cb19b73c5b19 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1018,7 +1018,10 @@ def cases( return builder.else_(default).end() def collect( - self, where: ir.BooleanValue | None = None, order_by: Any = None + self, + where: ir.BooleanValue | None = None, + order_by: Any = None, + ignore_null: bool = True, ) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. @@ -1033,6 +1036,9 @@ def collect( An ordering key (or keys) to use to order the rows before aggregating. If not provided, the order of the items in the result is undefined and backend specific. + ignore_null + Whether to ignore null values when performing this aggregation. Set + to `False` to include nulls in the result. Returns ------- @@ -1093,6 +1099,7 @@ def collect( self, where=self._bind_to_parent_table(where), order_by=self._bind_order_by(order_by), + ignore_null=ignore_null, ).to_expr() def identical_to(self, other: Value) -> ir.BooleanValue: