Skip to content

Commit

Permalink
feat(api): support ignore_null in collect
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Aug 6, 2024
1 parent 84a786d commit 21f3c1a
Show file tree
Hide file tree
Showing 24 changed files with 130 additions and 68 deletions.
11 changes: 11 additions & 0 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ def visit(cls, op: ops.StandardDev, arg, where, how):
ddof = {"pop": 0, "sample": 1}[how]
return cls.agg(lambda x: x.std(ddof=ddof), arg, where)

@classmethod
def visit(cls, op: ops.ArrayCollect, arg, where, order_by, ignore_null):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
return cls.agg(
(lambda x: x.dropna().tolist() if ignore_null else x.tolist()), arg, where
)

@classmethod
def visit(cls, op: ops.Correlation, left, right, where, how):
if where is None:
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/pandas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ def last(arg):
ops.Arbitrary: first,
ops.CountDistinct: lambda x: x.nunique(),
ops.ApproxCountDistinct: lambda x: x.nunique(),
ops.ArrayCollect: lambda x: x.dropna().tolist(),
}


Expand Down Expand Up @@ -382,7 +381,7 @@ def wrapper(*args, **kwargs):
ops.IfElse: lambda df: df["true_expr"].where(
df["bool_expr"], other=df["false_null_expr"]
),
ops.NullIf: lambda df: df["arg"].where(df["arg"] != df["null_if_expr"]),
ops.NullIf: lambda df: df["arg"].where(df["arg"] != df["null_if_expr"], None),
ops.Repeat: lambda df: df["arg"] * df["times"],
}

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def array_column(op, **kw):
def array_collect(op, in_group_by=False, **kw):
arg = translate(op.arg, **kw)

predicate = arg.is_not_null()
predicate = arg.is_not_null() if op.ignore_null else True
if op.where is not None:
predicate &= translate(op.where, **kw)

Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ class SQLGlotCompiler(abc.ABC):
ops.ApproxCountDistinct: "approx_distinct",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayCollect: "array_agg",
ops.ArrayContains: "array_contains",
ops.ArrayFlatten: "flatten",
ops.ArrayLength: "array_size",
Expand Down
9 changes: 5 additions & 4 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,11 @@ def visit_StringToTimestamp(self, op, *, arg, format_str):
return self.f.parse_timestamp(format_str, arg, timezone)
return self.f.parse_datetime(format_str, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by):
return sge.IgnoreNulls(
this=self.agg.array_agg(arg, where=where, order_by=order_by)
)
def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
out = self.agg.array_agg(arg, where=where, order_by=order_by)
if ignore_null:
out = sge.IgnoreNulls(this=out)
return out

def _neg_idx_to_pos(self, arg, idx):
return self.if_(idx < 0, self.f.array_length(arg) + idx, idx)
Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.Arbitrary: "any",
ops.ArgMax: "argMax",
ops.ArgMin: "argMin",
ops.ArrayCollect: "groupArray",
ops.ArrayContains: "has",
ops.ArrayFlatten: "arrayFlatten",
ops.ArrayIntersect: "arrayIntersect",
Expand Down Expand Up @@ -604,6 +603,13 @@ def visit_ArrayUnion(self, op, *, left, right):
def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
return self.f.arrayZip(*arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if not ignore_null:
raise com.UnsupportedOperationError(
"`ignore_null=False` is not supported by the pyspark backend"
)
return self.agg.groupArray(arg, where=where, order_by=order_by)

def visit_CountDistinctStar(
self, op: ops.CountDistinctStar, *, where, **_: Any
) -> str:
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType

Expand All @@ -25,11 +24,6 @@ class DataFusionCompiler(SQLGlotCompiler):
dialect = DataFusion
type_mapper = DataFusionType

rewrites = (
exclude_nulls_from_array_collect,
*SQLGlotCompiler.rewrites,
)

agg = AggGen(supports_filter=True, supports_order_by=True)

UNSUPPORTED_OPS = (
Expand Down Expand Up @@ -331,6 +325,12 @@ def visit_ArrayRepeat(self, op, *, arg, times):
def visit_ArrayPosition(self, op, *, arg, other):
return self.f.coalesce(self.f.array_position(arg, other), 0)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_Covariance(self, op, *, left, right, how, where):
x = op.left
if x.dtype.is_boolean():
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class DruidCompiler(SQLGlotCompiler):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.util import gen_name

if TYPE_CHECKING:
Expand Down Expand Up @@ -43,11 +42,6 @@ class DuckDBCompiler(SQLGlotCompiler):

agg = AggGen(supports_filter=True, supports_order_by=True)

rewrites = (
exclude_nulls_from_array_collect,
*SQLGlotCompiler.rewrites,
)

LOWERED_OPS = {
ops.Sample: None,
ops.StringSlice: None,
Expand Down Expand Up @@ -154,6 +148,12 @@ def visit_ArrayDistinct(self, op, *, arg):
),
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_ArrayIndex(self, op, *, arg, index):
return self.f.list_extract(arg, index + self.cast(index >= 0, op.index.dtype))

Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class ExasolCompiler(SQLGlotCompiler):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/flink.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayFlatten,
ops.ArraySort,
ops.ArrayStringJoin,
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class ImpalaCompiler(SQLGlotCompiler):
UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayPosition,
ops.Array,
ops.Covariance,
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.Array,
ops.ArrayDistinct,
ops.ArrayFlatten,
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def POS_INF(self):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.Array,
ops.ArrayFlatten,
ops.ArrayMap,
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class OracleCompiler(SQLGlotCompiler):
UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.Array,
ops.ArrayFlatten,
ops.ArrayMap,
Expand Down
9 changes: 6 additions & 3 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.common.exceptions import InvalidDecoratorError
from ibis.util import gen_name

Expand All @@ -43,8 +42,6 @@ class PostgresCompiler(SQLGlotCompiler):
dialect = Postgres
type_mapper = PostgresType

rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites)

agg = AggGen(supports_filter=True, supports_order_by=True)

NAN = sge.Literal.number("'NaN'::double precision")
Expand Down Expand Up @@ -358,6 +355,12 @@ def visit_ArrayIntersect(self, op, *, left, right):
)
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_Log2(self, op, *, arg):
return self.cast(
self.f.log(
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,13 @@ def visit_ArrayContains(self, op, *, arg, other):
def visit_ArrayStringJoin(self, op, *, arg, sep):
return self.f.concat_ws(sep, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if not ignore_null:
raise com.UnsupportedOperationError(
"`ignore_null=False` is not supported by the pyspark backend"
)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_StringFind(self, op, *, arg, substr, start, end):
if end is not None:
raise com.UnsupportedOperationError(
Expand Down
13 changes: 10 additions & 3 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,12 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9}
return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short])

def _array_collect(self, *, arg, where, order_by):
def _array_collect(self, *, arg, where, order_by, ignore_null=True):
if not ignore_null:
raise com.UnsupportedOperationError(
"`ignore_null=False` is not supported by the snowflake backend"
)

if where is not None:
arg = self.if_(where, arg, NULL)

Expand All @@ -466,8 +471,10 @@ def _array_collect(self, *, arg, where, order_by):

return out

def visit_ArrayCollect(self, op, *, arg, where, order_by):
return self._array_collect(arg=arg, where=where, order_by=order_by)
def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
return self._array_collect(
arg=arg, where=where, order_by=order_by, ignore_null=ignore_null
)

def visit_First(self, op, *, arg, where, order_by):
out = self._array_collect(arg=arg, where=where, order_by=order_by)
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.Array,
ops.ArrayConcat,
ops.ArrayStringJoin,
ops.ArrayCollect,
ops.ArrayContains,
ops.ArrayFlatten,
ops.ArrayLength,
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ibis.backends.sql.datatypes import TrinoType
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import (
exclude_nulls_from_array_collect,
exclude_unsupported_window_frame_from_ops,
)
from ibis.util import gen_name
Expand All @@ -37,7 +36,6 @@ class TrinoCompiler(SQLGlotCompiler):
agg = AggGen(supports_filter=True, supports_order_by=True)

rewrites = (
exclude_nulls_from_array_collect,
exclude_unsupported_window_frame_from_ops,
*SQLGlotCompiler.rewrites,
)
Expand Down Expand Up @@ -178,6 +176,12 @@ def visit_ArrayContains(self, op, *, arg, other):
NULL,
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_JSONGetItem(self, op, *, arg, index):
fmt = "%d" if op.index.dtype.is_integer() else '"%s"'
return self.f.json_extract(arg, self.f.format(f"$[{fmt}]", index))
Expand Down
8 changes: 0 additions & 8 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,14 +386,6 @@ def exclude_unsupported_window_frame_from_ops(_, **kwargs):
return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,))


@replace(p.ArrayCollect)
def exclude_nulls_from_array_collect(_, **kwargs):
where = ops.NotNull(_.arg)
if _.where is not None:
where = ops.And(where, _.where)
return _.copy(where=where)


# Rewrite rules for lowering a high-level operation into one composed of more
# primitive operations.

Expand Down
Loading

0 comments on commit 21f3c1a

Please sign in to comment.