Skip to content

Commit

Permalink
feat(api): support ignore_null in first/last
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Aug 6, 2024
1 parent d584721 commit bb073c3
Show file tree
Hide file tree
Showing 27 changed files with 352 additions and 102 deletions.
36 changes: 36 additions & 0 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,42 @@ def agg(df):

return agg

@classmethod
def visit(cls, op: ops.First, arg, where, order_by, ignore_null):
    """Build the ``First`` reduction and hand it to ``cls.agg``.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty.
    When ``ignore_null`` is truthy, nulls are dropped before the leading
    element is taken; an empty input reduces to ``None``.
    """
    if order_by:
        raise UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )

    def first(df):
        def inner(arg):
            values = arg.dropna() if ignore_null else arg
            if len(values):
                return values.iat[0]
            return None

        # Anything that is not a dd.Series is reduced directly; a dd.Series
        # goes through its own `reduction` machinery.
        # NOTE(review): presumably `reduction` reuses `inner` for the final
        # aggregate step, so a `None` coming from an empty partition could be
        # picked as the result when `ignore_null` is false -- confirm.
        if isinstance(df, dd.Series):
            return df.reduction(inner)
        return inner(df)

    return cls.agg(first, arg, where)

@classmethod
def visit(cls, op: ops.Last, arg, where, order_by, ignore_null):
    """Build the ``Last`` reduction and hand it to ``cls.agg``.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty.
    When ``ignore_null`` is truthy, nulls are dropped before the trailing
    element is taken; an empty input reduces to ``None``.
    """
    if order_by:
        raise UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )

    def last(df):
        def inner(arg):
            values = arg.dropna() if ignore_null else arg
            if len(values):
                return values.iat[-1]
            return None

        # Anything that is not a dd.Series is reduced directly; a dd.Series
        # goes through its own `reduction` machinery.
        # NOTE(review): presumably `reduction` reuses `inner` for the final
        # aggregate step, so a `None` coming from an empty partition could be
        # picked as the result when `ignore_null` is false -- confirm.
        if isinstance(df, dd.Series):
            return df.reduction(inner)
        return inner(df)

    return cls.agg(last, arg, where)

@classmethod
def visit(cls, op: ops.Correlation, left, right, where, how):
if how == "pop":
Expand Down
12 changes: 1 addition & 11 deletions ibis/backends/dask/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,14 @@
}


def maybe_pandas_reduction(func):
    """Wrap ``func`` so it can be applied to dask and non-dask inputs alike.

    The returned callable routes a ``dd.Series`` through its ``reduction``
    method with ``func`` and calls ``func`` directly on any other input.
    """

    def inner(df):
        if isinstance(df, dd.Series):
            return df.reduction(func)
        return func(df)

    return inner


reductions = {
**pandas_kernels.reductions,
ops.Mode: lambda x: x.mode().loc[0],
ops.ApproxMedian: lambda x: x.median_approximate(),
ops.BitAnd: lambda x: x.reduction(np.bitwise_and.reduce),
ops.BitOr: lambda x: x.reduction(np.bitwise_or.reduce),
ops.BitXor: lambda x: x.reduction(np.bitwise_xor.reduce),
ops.Arbitrary: lambda x: x.reduction(pandas_kernels.first),
# Window functions are calculated locally using pandas
ops.Last: maybe_pandas_reduction(pandas_kernels.last),
ops.First: maybe_pandas_reduction(pandas_kernels.first),
ops.Arbitrary: lambda x: x.reduction(pandas_kernels.arbitrary),
}

serieswise = {
Expand Down
30 changes: 30 additions & 0 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,36 @@ def visit(cls, op: ops.ArrayCollect, arg, where, order_by, ignore_null):
(lambda x: x.dropna().tolist() if ignore_null else x.tolist()), arg, where
)

@classmethod
def visit(cls, op: ops.First, arg, where, order_by, ignore_null):
    """Build the ``First`` reduction and hand it to ``cls.agg``.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty.
    When ``ignore_null`` is truthy, nulls are dropped before the leading
    element is taken; an empty input reduces to ``None``.
    """
    if order_by:
        raise UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )

    def first(arg):
        values = arg.dropna() if ignore_null else arg
        if len(values):
            return values.iat[0]
        return None

    return cls.agg(first, arg, where)

@classmethod
def visit(cls, op: ops.Last, arg, where, order_by, ignore_null):
    """Build the ``Last`` reduction and hand it to ``cls.agg``.

    Raises ``UnsupportedOperationError`` when ``order_by`` is non-empty.
    When ``ignore_null`` is truthy, nulls are dropped before the trailing
    element is taken; an empty input reduces to ``None``.
    """
    if order_by:
        raise UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )

    def last(arg):
        values = arg.dropna() if ignore_null else arg
        if len(values):
            return values.iat[-1]
        return None

    return cls.agg(last, arg, where)

@classmethod
def visit(cls, op: ops.Correlation, left, right, where, how):
if where is None:
Expand Down
13 changes: 2 additions & 11 deletions ibis/backends/pandas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,18 +260,11 @@ def round_serieswise(arg, digits):
return np.round(arg, digits).astype("float64")


def first(arg):
# first excludes null values unless they're all null
def arbitrary(arg):
arg = arg.dropna()
return arg.iat[0] if len(arg) else None


def last(arg):
# last excludes null values unless they're all null
arg = arg.dropna()
return arg.iat[-1] if len(arg) else None


reductions = {
ops.Min: lambda x: x.min(),
ops.Max: lambda x: x.max(),
Expand All @@ -286,9 +279,7 @@ def last(arg):
ops.BitAnd: lambda x: np.bitwise_and.reduce(x.values),
ops.BitOr: lambda x: np.bitwise_or.reduce(x.values),
ops.BitXor: lambda x: np.bitwise_xor.reduce(x.values),
ops.Last: last,
ops.First: first,
ops.Arbitrary: first,
ops.Arbitrary: arbitrary,
ops.CountDistinct: lambda x: x.nunique(),
ops.ApproxCountDistinct: lambda x: x.nunique(),
}
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,8 +741,7 @@ def execute_reduction(op, **kw):
def execute_first_last(op, **kw):
arg = translate(op.arg, **kw)

# polars doesn't ignore nulls by default for these methods
predicate = arg.is_not_null()
predicate = arg.is_not_null() if getattr(op, "ignore_null", True) else True
if op.where is not None:
predicate &= translate(op.where, **kw)

Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,15 +330,13 @@ class SQLGlotCompiler(abc.ABC):
ops.Degrees: "degrees",
ops.DenseRank: "dense_rank",
ops.Exp: "exp",
ops.First: "first",
FirstValue: "first_value",
ops.GroupConcat: "group_concat",
ops.IfElse: "if",
ops.IsInf: "isinf",
ops.IsNan: "isnan",
ops.JSONGetItem: "json_extract",
ops.LPad: "lpad",
ops.Last: "last",
LastValue: "last_value",
ops.Levenshtein: "levenshtein",
ops.Ln: "ln",
Expand Down
31 changes: 25 additions & 6 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ def visit_StringToTimestamp(self, op, *, arg, format_str):
return self.f.parse_datetime(format_str, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if where is not None and not ignore_null:
raise com.UnsupportedOperationError(
"Combining `ignore_null=False` and `where` is not supported "
"by bigquery"
)
out = self.agg.array_agg(arg, where=where, order_by=order_by)
if ignore_null:
out = sge.IgnoreNulls(this=out)
Expand Down Expand Up @@ -690,26 +695,40 @@ def visit_TimestampRange(self, op, *, start, stop, step):
self.f.generate_timestamp_array, start, stop, step, op.step.dtype
)

def visit_First(self, op, *, arg, where, order_by):
def visit_First(self, op, *, arg, where, order_by, ignore_null):
if where is not None:
arg = self.if_(where, arg, NULL)
if not ignore_null:
raise com.UnsupportedOperationError(
"Combining `ignore_null=False` and `where` is not supported "
"by bigquery"
)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

array = self.f.array_agg(
sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)),
)
if ignore_null:
arg = sge.IgnoreNulls(this=arg)

array = self.f.array_agg(sge.Limit(this=arg, expression=sge.convert(1)))
return array[self.f.safe_offset(0)]

def visit_Last(self, op, *, arg, where, order_by):
def visit_Last(self, op, *, arg, where, order_by, ignore_null):
if where is not None:
arg = self.if_(where, arg, NULL)
if not ignore_null:
raise com.UnsupportedOperationError(
"Combining `ignore_null=False` and `where` is not supported "
"by bigquery"
)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg)))
if ignore_null:
arg = sge.IgnoreNulls(this=arg)

array = self.f.array_reverse(self.f.array_agg(arg))
return array[self.f.safe_offset(0)]

def visit_ArrayFilter(self, op, *, arg, body, param):
Expand Down
18 changes: 15 additions & 3 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,11 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.ExtractWeekOfYear: "toISOWeek",
ops.ExtractYear: "toYear",
ops.ExtractIsoYear: "toISOYear",
ops.First: "any",
ops.IntegerRange: "range",
ops.IsInf: "isInfinite",
ops.IsNan: "isNaN",
ops.IsNull: "isNull",
ops.LStrip: "trimLeft",
ops.Last: "anyLast",
ops.Ln: "log",
ops.Log10: "log10",
ops.MapKeys: "mapKeys",
Expand Down Expand Up @@ -606,10 +604,24 @@ def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
if not ignore_null:
raise com.UnsupportedOperationError(
"`ignore_null=False` is not supported by the pyspark backend"
"`ignore_null=False` is not supported by the clickhouse backend"
)
return self.agg.groupArray(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.any``.

    ``where`` and ``order_by`` are forwarded unchanged. Raises
    ``com.UnsupportedOperationError`` when ``ignore_null`` is falsy.
    """
    if ignore_null:
        return self.agg.any(arg, where=where, order_by=order_by)
    raise com.UnsupportedOperationError(
        "`ignore_null=False` is not supported by the clickhouse backend"
    )

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.anyLast``.

    ``where`` and ``order_by`` are forwarded unchanged. Raises
    ``com.UnsupportedOperationError`` when ``ignore_null`` is falsy.
    """
    if ignore_null:
        return self.agg.anyLast(arg, where=where, order_by=order_by)
    raise com.UnsupportedOperationError(
        "`ignore_null=False` is not supported by the clickhouse backend"
    )

def visit_CountDistinctStar(
self, op: ops.CountDistinctStar, *, where, **_: Any
) -> str:
Expand Down
14 changes: 8 additions & 6 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,14 +425,16 @@ def visit_StringConcat(self, op, *, arg):
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

def visit_First(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
def visit_First(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
def visit_Last(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Aggregate(self, op, *, parent, groups, metrics):
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ class DruidCompiler(SQLGlotCompiler):
ops.DateFromYMD,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.First,
ops.IntervalFromInteger,
ops.IsNan,
ops.IsInf,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
Expand Down
14 changes: 8 additions & 6 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,16 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement):
arg, pattern, replacement, "g", dialect=self.dialect
)

def visit_First(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
def visit_First(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
def visit_Last(self, op, *, arg, where, order_by, ignore_null):
if ignore_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last(arg, where=where, order_by=order_by)

def visit_Quantile(self, op, *, arg, quantile, where):
Expand Down
16 changes: 14 additions & 2 deletions ibis/backends/sql/compilers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ class ExasolCompiler(SQLGlotCompiler):
ops.Log10: "log10",
ops.All: "min",
ops.Any: "max",
ops.First: "first_value",
ops.Last: "last_value",
}

@staticmethod
Expand Down Expand Up @@ -136,6 +134,20 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):

return sge.GroupConcat(this=arg, separator=sep)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first_value``.

    ``where`` and ``order_by`` are forwarded unchanged. Raises
    ``com.UnsupportedOperationError`` when ``ignore_null`` is falsy.
    """
    if ignore_null:
        return self.agg.first_value(arg, where=where, order_by=order_by)
    raise com.UnsupportedOperationError(
        "`ignore_null=False` is not supported by the exasol backend"
    )

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.last_value``.

    ``where`` and ``order_by`` are forwarded unchanged. Raises
    ``com.UnsupportedOperationError`` when ``ignore_null`` is falsy.
    """
    if ignore_null:
        return self.agg.last_value(arg, where=where, order_by=order_by)
    raise com.UnsupportedOperationError(
        "`ignore_null=False` is not supported by the exasol backend"
    )

def visit_StartsWith(self, op, *, arg, start):
    """Compile ``StartsWith`` as an equality test on a prefix of ``arg``.

    The prefix is ``self.f.left(arg, n)`` with ``n = self.f.length(start)``,
    and it is compared against ``start`` with ``.eq``.
    """
    prefix_length = self.f.length(start)
    prefix = self.f.left(arg, prefix_length)
    return prefix.eq(start)

Expand Down
16 changes: 14 additions & 2 deletions ibis/backends/sql/compilers/flink.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.ArrayRemove: "array_remove",
ops.ArrayUnion: "array_union",
ops.ExtractDayOfYear: "dayofyear",
ops.First: "first_value",
ops.Last: "last_value",
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.Power: "power",
Expand Down Expand Up @@ -307,6 +305,20 @@ def visit_ArraySlice(self, op, *, arg, start, stop):

return self.f.array_slice(*args)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first_value``.

    ``where`` and ``order_by`` are forwarded unchanged. Raises
    ``com.UnsupportedOperationError`` when ``ignore_null`` is falsy.
    """
    if ignore_null:
        return self.agg.first_value(arg, where=where, order_by=order_by)
    raise com.UnsupportedOperationError(
        "`ignore_null=False` is not supported by the flink backend"
    )

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.last_value``.

    ``where`` and ``order_by`` are forwarded unchanged. Raises
    ``com.UnsupportedOperationError`` when ``ignore_null`` is falsy.
    """
    if ignore_null:
        return self.agg.last_value(arg, where=where, order_by=order_by)
    raise com.UnsupportedOperationError(
        "`ignore_null=False` is not supported by the flink backend"
    )

def visit_Not(self, op, *, arg):
    """Compile ``Not`` by negating ``arg`` after casting it to ``dt.boolean``."""
    as_boolean = self.cast(arg, dt.boolean)
    return sg.not_(as_boolean)

Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ class ImpalaCompiler(SQLGlotCompiler):
ops.Covariance,
ops.DateDelta,
ops.ExtractDayOfYear,
ops.First,
ops.Last,
ops.Levenshtein,
ops.Map,
ops.Median,
Expand Down
Loading

0 comments on commit bb073c3

Please sign in to comment.