Skip to content

Commit

Permalink
feat(api): support ignore_null in first/last
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Aug 7, 2024
1 parent 91aaec4 commit cd1ecab
Show file tree
Hide file tree
Showing 27 changed files with 384 additions and 134 deletions.
36 changes: 36 additions & 0 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,42 @@ def agg(df):

return agg

@classmethod
def visit(cls, op: ops.First, arg, where, order_by, include_null):
    """Reduce ``arg`` to its first element.

    Nulls are dropped before picking unless ``include_null`` is true, and
    ``None`` is produced when nothing is left to pick from.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty.
    """
    if order_by:
        raise UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )

    def pick_first(values):
        candidates = values if include_null else values.dropna()
        if len(candidates):
            return candidates.iat[0]
        return None

    def first(df):
        # dask Series are reduced through `Series.reduction`; any other
        # input is handed to the reducer as-is.
        if isinstance(df, dd.Series):
            return df.reduction(pick_first)
        return pick_first(df)

    return cls.agg(first, arg, where)

@classmethod
def visit(cls, op: ops.Last, arg, where, order_by, include_null):
    """Reduce ``arg`` to its last element.

    Nulls are dropped before picking unless ``include_null`` is true, and
    ``None`` is produced when nothing is left to pick from.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty.
    """
    if order_by:
        raise UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )

    def pick_last(values):
        candidates = values if include_null else values.dropna()
        if len(candidates):
            return candidates.iat[-1]
        return None

    def last(df):
        # dask Series are reduced through `Series.reduction`; any other
        # input is handed to the reducer as-is.
        if isinstance(df, dd.Series):
            return df.reduction(pick_last)
        return pick_last(df)

    return cls.agg(last, arg, where)

@classmethod
def visit(cls, op: ops.Correlation, left, right, where, how):
if how == "pop":
Expand Down
12 changes: 1 addition & 11 deletions ibis/backends/dask/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,14 @@
}


def maybe_pandas_reduction(func):
    """Wrap ``func`` so it can be applied to dask Series and other inputs.

    The returned callable routes a ``dd.Series`` through its ``reduction``
    method with ``func``; any other input is passed straight to ``func``.
    """

    def inner(df):
        if isinstance(df, dd.Series):
            return df.reduction(func)
        return func(df)

    return inner


reductions = {
**pandas_kernels.reductions,
ops.Mode: lambda x: x.mode().loc[0],
ops.ApproxMedian: lambda x: x.median_approximate(),
ops.BitAnd: lambda x: x.reduction(np.bitwise_and.reduce),
ops.BitOr: lambda x: x.reduction(np.bitwise_or.reduce),
ops.BitXor: lambda x: x.reduction(np.bitwise_xor.reduce),
ops.Arbitrary: lambda x: x.reduction(pandas_kernels.first),
# Window functions are calculated locally using pandas
ops.Last: maybe_pandas_reduction(pandas_kernels.last),
ops.First: maybe_pandas_reduction(pandas_kernels.first),
ops.Arbitrary: lambda x: x.reduction(pandas_kernels.arbitrary),
}

serieswise = {
Expand Down
34 changes: 32 additions & 2 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,16 +320,46 @@ def visit(cls, op: ops.StandardDev, arg, where, how):
return cls.agg(lambda x: x.std(ddof=ddof), arg, where)

@classmethod
def visit(cls, op: ops.ArrayCollect, arg, where, order_by, ignore_null):
def visit(cls, op: ops.ArrayCollect, arg, where, order_by, include_null):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
return cls.agg(
(lambda x: x.dropna().tolist() if ignore_null else x.tolist()), arg, where
(lambda x: x.tolist() if include_null else x.dropna().tolist()), arg, where
)

@classmethod
def visit(cls, op: ops.First, arg, where, order_by, include_null):
    """Reduce ``arg`` to its first element.

    Nulls are dropped before picking unless ``include_null`` is true, and
    ``None`` is produced when nothing is left to pick from.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty.
    """
    if order_by:
        raise UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )

    def first(values):
        kept = values if include_null else values.dropna()
        if not len(kept):
            return None
        return kept.iat[0]

    return cls.agg(first, arg, where)

@classmethod
def visit(cls, op: ops.Last, arg, where, order_by, include_null):
    """Reduce ``arg`` to its last element.

    Nulls are dropped before picking unless ``include_null`` is true, and
    ``None`` is produced when nothing is left to pick from.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty.
    """
    if order_by:
        raise UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )

    def last(values):
        kept = values if include_null else values.dropna()
        if not len(kept):
            return None
        return kept.iat[-1]

    return cls.agg(last, arg, where)

@classmethod
def visit(cls, op: ops.Correlation, left, right, where, how):
if where is None:
Expand Down
13 changes: 2 additions & 11 deletions ibis/backends/pandas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,18 +260,11 @@ def round_serieswise(arg, digits):
return np.round(arg, digits).astype("float64")


def first(arg):
# first excludes null values unless they're all null
def arbitrary(arg):
    """Return the first non-null element of ``arg``, or ``None`` if there is none."""
    non_null = arg.dropna()
    if len(non_null) == 0:
        return None
    return non_null.iat[0]


def last(arg):
    """Return the last non-null element of ``arg``, or ``None`` if there is none."""
    non_null = arg.dropna()
    if len(non_null) == 0:
        return None
    return non_null.iat[-1]


reductions = {
ops.Min: lambda x: x.min(),
ops.Max: lambda x: x.max(),
Expand All @@ -286,9 +279,7 @@ def last(arg):
ops.BitAnd: lambda x: np.bitwise_and.reduce(x.values),
ops.BitOr: lambda x: np.bitwise_or.reduce(x.values),
ops.BitXor: lambda x: np.bitwise_xor.reduce(x.values),
ops.Last: last,
ops.First: first,
ops.Arbitrary: first,
ops.Arbitrary: arbitrary,
ops.CountDistinct: lambda x: x.nunique(),
ops.ApproxCountDistinct: lambda x: x.nunique(),
}
Expand Down
5 changes: 2 additions & 3 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,8 +741,7 @@ def execute_reduction(op, **kw):
def execute_first_last(op, **kw):
arg = translate(op.arg, **kw)

# polars doesn't ignore nulls by default for these methods
predicate = arg.is_not_null()
predicate = True if getattr(op, "include_null", False) else arg.is_not_null()
if op.where is not None:
predicate &= translate(op.where, **kw)

Expand Down Expand Up @@ -991,7 +990,7 @@ def array_column(op, **kw):
def array_collect(op, in_group_by=False, **kw):
arg = translate(op.arg, **kw)

predicate = arg.is_not_null() if op.ignore_null else True
predicate = True if op.include_null else arg.is_not_null()
if op.where is not None:
predicate &= translate(op.where, **kw)

Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,15 +330,13 @@ class SQLGlotCompiler(abc.ABC):
ops.Degrees: "degrees",
ops.DenseRank: "dense_rank",
ops.Exp: "exp",
ops.First: "first",
FirstValue: "first_value",
ops.GroupConcat: "group_concat",
ops.IfElse: "if",
ops.IsInf: "isinf",
ops.IsNan: "isnan",
ops.JSONGetItem: "json_extract",
ops.LPad: "lpad",
ops.Last: "last",
LastValue: "last_value",
ops.Levenshtein: "levenshtein",
ops.Ln: "ln",
Expand Down
35 changes: 27 additions & 8 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,14 @@ def visit_StringToTimestamp(self, op, *, arg, format_str):
return self.f.parse_timestamp(format_str, arg, timezone)
return self.f.parse_datetime(format_str, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if where is not None and include_null:
raise com.UnsupportedOperationError(
"Combining `include_null=True` and `where` is not supported "
"by bigquery"
)
out = self.agg.array_agg(arg, where=where, order_by=order_by)
if ignore_null:
if not include_null:
out = sge.IgnoreNulls(this=out)
return out

Expand Down Expand Up @@ -690,26 +695,40 @@ def visit_TimestampRange(self, op, *, start, stop, step):
self.f.generate_timestamp_array, start, stop, step, op.step.dtype
)

def visit_First(self, op, *, arg, where, order_by):
def visit_First(self, op, *, arg, where, order_by, include_null):
if where is not None:
arg = self.if_(where, arg, NULL)
if include_null:
raise com.UnsupportedOperationError(
"Combining `include_null=True` and `where` is not supported "
"by bigquery"
)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

array = self.f.array_agg(
sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)),
)
if not include_null:
arg = sge.IgnoreNulls(this=arg)

array = self.f.array_agg(sge.Limit(this=arg, expression=sge.convert(1)))
return array[self.f.safe_offset(0)]

def visit_Last(self, op, *, arg, where, order_by):
def visit_Last(self, op, *, arg, where, order_by, include_null):
if where is not None:
arg = self.if_(where, arg, NULL)
if include_null:
raise com.UnsupportedOperationError(
"Combining `include_null=True` and `where` is not supported "
"by bigquery"
)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg)))
if not include_null:
arg = sge.IgnoreNulls(this=arg)

array = self.f.array_reverse(self.f.array_agg(arg))
return array[self.f.safe_offset(0)]

def visit_ArrayFilter(self, op, *, arg, body, param):
Expand Down
22 changes: 17 additions & 5 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,11 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.ExtractWeekOfYear: "toISOWeek",
ops.ExtractYear: "toYear",
ops.ExtractIsoYear: "toISOYear",
ops.First: "any",
ops.IntegerRange: "range",
ops.IsInf: "isInfinite",
ops.IsNan: "isNaN",
ops.IsNull: "isNull",
ops.LStrip: "trimLeft",
ops.Last: "anyLast",
ops.Ln: "log",
ops.Log10: "log10",
ops.MapKeys: "mapKeys",
Expand Down Expand Up @@ -603,13 +601,27 @@ def visit_ArrayUnion(self, op, *, left, right):
def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
return self.f.arrayZip(*arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if not ignore_null:
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`ignore_null=False` is not supported by the pyspark backend"
"`include_null=True` is not supported by the clickhouse backend"
)
return self.agg.groupArray(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, include_null):
    """Compile ``First`` to the ``any`` aggregate.

    ``where`` and ``order_by`` are forwarded to the aggregate builder.
    ``include_null=True`` is rejected with ``UnsupportedOperationError``.
    """
    if not include_null:
        return self.agg.any(arg, where=where, order_by=order_by)
    raise com.UnsupportedOperationError(
        "`include_null=True` is not supported by the clickhouse backend"
    )

def visit_Last(self, op, *, arg, where, order_by, include_null):
    """Compile ``Last`` to the ``anyLast`` aggregate.

    ``where`` and ``order_by`` are forwarded to the aggregate builder.
    ``include_null=True`` is rejected with ``UnsupportedOperationError``.
    """
    if not include_null:
        return self.agg.anyLast(arg, where=where, order_by=order_by)
    raise com.UnsupportedOperationError(
        "`include_null=True` is not supported by the clickhouse backend"
    )

def visit_CountDistinctStar(
self, op: ops.CountDistinctStar, *, where, **_: Any
) -> str:
Expand Down
18 changes: 10 additions & 8 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ def visit_ArrayRepeat(self, op, *, arg, times):
def visit_ArrayPosition(self, op, *, arg, other):
return self.f.coalesce(self.f.array_position(arg, other), 0)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)
Expand Down Expand Up @@ -425,14 +425,16 @@ def visit_StringConcat(self, op, *, arg):
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

def visit_First(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
def visit_First(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
def visit_Last(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Aggregate(self, op, *, parent, groups, metrics):
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ class DruidCompiler(SQLGlotCompiler):
ops.DateFromYMD,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.First,
ops.IntervalFromInteger,
ops.IsNan,
ops.IsInf,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
Expand Down
18 changes: 10 additions & 8 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def visit_ArrayDistinct(self, op, *, arg):
),
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)
Expand Down Expand Up @@ -510,14 +510,16 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement):
arg, pattern, replacement, "g", dialect=self.dialect
)

def visit_First(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
def visit_First(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
def visit_Last(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last(arg, where=where, order_by=order_by)

def visit_Quantile(self, op, *, arg, quantile, where):
Expand Down
Loading

0 comments on commit cd1ecab

Please sign in to comment.