Skip to content

Commit

Permalink
feat(api): support ignore_null in first/last
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Aug 6, 2024
1 parent 21f3c1a commit 8d10a3e
Show file tree
Hide file tree
Showing 22 changed files with 238 additions and 77 deletions.
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,15 +330,13 @@ class SQLGlotCompiler(abc.ABC):
ops.Degrees: "degrees",
ops.DenseRank: "dense_rank",
ops.Exp: "exp",
ops.First: "first",
FirstValue: "first_value",
ops.GroupConcat: "group_concat",
ops.IfElse: "if",
ops.IsInf: "isinf",
ops.IsNan: "isnan",
ops.JSONGetItem: "json_extract",
ops.LPad: "lpad",
ops.Last: "last",
LastValue: "last_value",
ops.Levenshtein: "levenshtein",
ops.Ln: "ln",
Expand Down
16 changes: 10 additions & 6 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,26 +690,30 @@ def visit_TimestampRange(self, op, *, start, stop, step):
self.f.generate_timestamp_array, start, stop, step, op.step.dtype
)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` as ``array_agg(<arg> LIMIT 1)[safe_offset(0)]``.

    Parameters
    ----------
    op
        The operation node being compiled (not read here).
    arg
        Compiled expression to aggregate.
    where
        Optional compiled filter; ``None`` means no filter.
    order_by
        Compiled ordering keys; falsy means no explicit ordering.
    ignore_null
        When true the aggregated argument is wrapped in ``IGNORE NULLS``.

    Raises
    ------
    UnsupportedOperationError
        If ``where`` is given together with ``ignore_null=False``.
    """
    if where is not None:
        if not ignore_null:
            # The filter is emulated below by turning rejected rows into
            # NULL. If NULLs are kept, such a placeholder NULL could be
            # returned as the "first" value, so refuse the combination.
            import ibis.common.exceptions as com

            raise com.UnsupportedOperationError(
                "Combining `ignore_null=False` and `where` is not supported "
                "by the bigquery backend"
            )
        arg = self.if_(where, arg, NULL)

    if order_by:
        arg = sge.Order(this=arg, expressions=order_by)

    if ignore_null:
        arg = sge.IgnoreNulls(this=arg)

    # LIMIT 1 keeps only the first collected element.
    array = self.f.array_agg(sge.Limit(this=arg, expression=sge.convert(1)))
    return array[self.f.safe_offset(0)]

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` as ``array_reverse(array_agg(<arg>))[safe_offset(0)]``.

    Parameters
    ----------
    op
        The operation node being compiled (not read here).
    arg
        Compiled expression to aggregate.
    where
        Optional compiled filter; ``None`` means no filter.
    order_by
        Compiled ordering keys; falsy means no explicit ordering.
    ignore_null
        When true the aggregated argument is wrapped in ``IGNORE NULLS``.

    Raises
    ------
    UnsupportedOperationError
        If ``where`` is given together with ``ignore_null=False``.
    """
    if where is not None:
        if not ignore_null:
            # The filter is emulated below by turning rejected rows into
            # NULL. If NULLs are kept, such a placeholder NULL could be
            # returned as the "last" value, so refuse the combination.
            import ibis.common.exceptions as com

            raise com.UnsupportedOperationError(
                "Combining `ignore_null=False` and `where` is not supported "
                "by the bigquery backend"
            )
        arg = self.if_(where, arg, NULL)

    if order_by:
        arg = sge.Order(this=arg, expressions=order_by)

    if ignore_null:
        arg = sge.IgnoreNulls(this=arg)

    # Reverse the collected array so the last element sits at offset 0.
    array = self.f.array_reverse(self.f.array_agg(arg))
    return array[self.f.safe_offset(0)]

def visit_ArrayFilter(self, op, *, arg, body, param):
Expand Down
18 changes: 15 additions & 3 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,11 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.ExtractWeekOfYear: "toISOWeek",
ops.ExtractYear: "toYear",
ops.ExtractIsoYear: "toISOYear",
ops.First: "any",
ops.IntegerRange: "range",
ops.IsInf: "isInfinite",
ops.IsNan: "isNaN",
ops.IsNull: "isNull",
ops.LStrip: "trimLeft",
ops.Last: "anyLast",
ops.Ln: "log",
ops.Log10: "log10",
ops.MapKeys: "mapKeys",
Expand Down Expand Up @@ -606,10 +604,24 @@ def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``ArrayCollect`` through ``self.agg.groupArray``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the clickhouse backend"
        )
    return self.agg.groupArray(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.any``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the clickhouse backend"
        )
    return self.agg.any(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.anyLast``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the clickhouse backend"
        )
    return self.agg.anyLast(arg, where=where, order_by=order_by)

def visit_CountDistinctStar(
self, op: ops.CountDistinctStar, *, where, **_: Any
) -> str:
Expand Down
14 changes: 8 additions & 6 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,14 +425,16 @@ def visit_StringConcat(self, op, *, arg):
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first_value``.

    When ``ignore_null`` is true, an ``arg IS NOT NULL`` condition is
    AND-ed into ``where`` (or used as the filter when ``where`` is ``None``).
    """
    if ignore_null:
        cond = arg.is_(sg.not_(NULL, copy=False))
        where = cond if where is None else sge.And(this=cond, expression=where)
    return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.last_value``.

    When ``ignore_null`` is true, an ``arg IS NOT NULL`` condition is
    AND-ed into ``where`` (or used as the filter when ``where`` is ``None``).
    """
    if ignore_null:
        cond = arg.is_(sg.not_(NULL, copy=False))
        where = cond if where is None else sge.And(this=cond, expression=where)
    return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Aggregate(self, op, *, parent, groups, metrics):
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ class DruidCompiler(SQLGlotCompiler):
ops.DateFromYMD,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.First,
ops.IntervalFromInteger,
ops.IsNan,
ops.IsInf,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
Expand Down
14 changes: 8 additions & 6 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,16 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement):
arg, pattern, replacement, "g", dialect=self.dialect
)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first``.

    When ``ignore_null`` is true, an ``arg IS NOT NULL`` condition is
    AND-ed into ``where`` (or used as the filter when ``where`` is ``None``).
    """
    if ignore_null:
        cond = arg.is_(sg.not_(NULL, copy=False))
        where = cond if where is None else sge.And(this=cond, expression=where)
    return self.agg.first(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.last``.

    When ``ignore_null`` is true, an ``arg IS NOT NULL`` condition is
    AND-ed into ``where`` (or used as the filter when ``where`` is ``None``).
    """
    if ignore_null:
        cond = arg.is_(sg.not_(NULL, copy=False))
        where = cond if where is None else sge.And(this=cond, expression=where)
    return self.agg.last(arg, where=where, order_by=order_by)

def visit_Quantile(self, op, *, arg, quantile, where):
Expand Down
16 changes: 14 additions & 2 deletions ibis/backends/sql/compilers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ class ExasolCompiler(SQLGlotCompiler):
ops.Log10: "log10",
ops.All: "min",
ops.Any: "max",
ops.First: "first_value",
ops.Last: "last_value",
}

@staticmethod
Expand Down Expand Up @@ -136,6 +134,20 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):

return sge.GroupConcat(this=arg, separator=sep)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first_value``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the exasol backend"
        )
    return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.last_value``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the exasol backend"
        )
    return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_StartsWith(self, op, *, arg, start):
    """Compile ``StartsWith`` as ``left(arg, length(start)) = start``."""
    prefix = self.f.left(arg, self.f.length(start))
    return prefix.eq(start)

Expand Down
16 changes: 14 additions & 2 deletions ibis/backends/sql/compilers/flink.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.ArrayRemove: "array_remove",
ops.ArrayUnion: "array_union",
ops.ExtractDayOfYear: "dayofyear",
ops.First: "first_value",
ops.Last: "last_value",
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.Power: "power",
Expand Down Expand Up @@ -307,6 +305,20 @@ def visit_ArraySlice(self, op, *, arg, start, stop):

return self.f.array_slice(*args)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first_value``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the flink backend"
        )
    return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.last_value``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the flink backend"
        )
    return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Not(self, op, *, arg):
    """Compile ``Not`` by casting ``arg`` to boolean and negating it."""
    as_boolean = self.cast(arg, dt.boolean)
    return sg.not_(as_boolean)

Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ class ImpalaCompiler(SQLGlotCompiler):
ops.Covariance,
ops.DateDelta,
ops.ExtractDayOfYear,
ops.First,
ops.Last,
ops.Levenshtein,
ops.Map,
ops.Median,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,12 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.DateDiff,
ops.DateSub,
ops.EndsWith,
ops.First,
ops.IntervalAdd,
ops.IntervalFromInteger,
ops.IntervalMultiply,
ops.IntervalSubtract,
ops.IsInf,
ops.IsNan,
ops.Last,
ops.LPad,
ops.Levenshtein,
ops.Map,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def POS_INF(self):
ops.ArrayFlatten,
ops.ArrayMap,
ops.Covariance,
ops.First,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.Mode,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ class OracleCompiler(SQLGlotCompiler):
ops.ArrayFlatten,
ops.ArrayMap,
ops.ArrayStringJoin,
ops.First,
ops.Last,
ops.Mode,
ops.MultiQuantile,
ops.RegexSplit,
Expand Down
14 changes: 14 additions & 0 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,20 @@ def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the postgres backend"
        )
    return self.agg.first(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.last``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the postgres backend"
        )
    return self.agg.last(arg, where=where, order_by=order_by)

def visit_Log2(self, op, *, arg):
return self.cast(
self.f.log(
Expand Down
16 changes: 11 additions & 5 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,17 @@ def visit_FirstValue(self, op, *, arg):
def visit_LastValue(self, op, *, arg):
    """Compile ``LastValue`` as ``last(arg)`` wrapped in ``IGNORE NULLS``."""
    last_call = self.f.last(arg)
    return sge.IgnoreNulls(this=last_call)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first``.

    The aggregate is wrapped in ``IGNORE NULLS`` only when ``ignore_null``
    is true; otherwise it is returned as built.
    """
    out = self.agg.first(arg, where=where, order_by=order_by)
    if ignore_null:
        out = sge.IgnoreNulls(this=out)
    return out

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` through ``self.agg.last``.

    The aggregate is wrapped in ``IGNORE NULLS`` only when ``ignore_null``
    is true; otherwise it is returned as built.
    """
    out = self.agg.last(arg, where=where, order_by=order_by)
    if ignore_null:
        out = sge.IgnoreNulls(this=out)
    return out

def visit_Arbitrary(self, op, *, arg, where):
# For Spark>=3.4 we could use any_value here
Expand Down
12 changes: 10 additions & 2 deletions ibis/backends/sql/compilers/risingwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,22 @@ class RisingWaveCompiler(PostgresCompiler):
def visit_DateNow(self, op):
    """Compile ``DateNow`` as the current timestamp cast to a date."""
    now = sge.CurrentTimestamp()
    return self.cast(now, dt.date)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` through ``self.agg.first_value``.

    Raises
    ------
    UnsupportedOperationError
        If ``ignore_null`` is false, or if no ``order_by`` is given.
    """
    if not ignore_null:
        raise com.UnsupportedOperationError(
            "`ignore_null=False` is not supported by the risingwave backend"
        )
    if not order_by:
        raise com.UnsupportedOperationError(
            "RisingWave requires an `order_by` be specified in `first`"
        )
    return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by):
def visit_Last(self, op, *, arg, where, order_by, ignore_null):
if not ignore_null:
raise com.UnsupportedOperationError(
"`ignore_null=False` is not supported by the risingwave backend"
)
if not order_by:
raise com.UnsupportedOperationError(
"RisingWave requires an `order_by` be specified in `last`"
Expand Down
14 changes: 9 additions & 5 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9}
return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short])

def _array_collect(self, *, arg, where, order_by, ignore_null=True):
def _array_collect(self, *, arg, where, order_by, ignore_null):
if not ignore_null:
raise com.UnsupportedOperationError(
"`ignore_null=False` is not supported by the snowflake backend"
Expand All @@ -476,12 +476,16 @@ def visit_ArrayCollect(self, op, *, arg, where, order_by, ignore_null):
arg=arg, where=where, order_by=order_by, ignore_null=ignore_null
)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` as element 0 of the collected array.

    ``ignore_null`` is forwarded to ``self._array_collect``, which builds
    the array expression.
    """
    out = self._array_collect(
        arg=arg, where=where, order_by=order_by, ignore_null=ignore_null
    )
    return self.f.get(out, 0)

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` as the final element of the collected array.

    ``ignore_null`` is forwarded to ``self._array_collect``, which builds
    the array expression.
    """
    out = self._array_collect(
        arg=arg, where=where, order_by=order_by, ignore_null=ignore_null
    )
    # Index of the final element: array_size(out) - 1.
    return self.f.get(out, self.f.array_size(out) - 1)

def visit_GroupConcat(self, op, *, arg, where, sep, order_by):
Expand Down
10 changes: 8 additions & 2 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.BitOr: "_ibis_bit_or",
ops.BitAnd: "_ibis_bit_and",
ops.BitXor: "_ibis_bit_xor",
ops.First: "_ibis_first",
ops.Last: "_ibis_last",
ops.Mode: "_ibis_mode",
ops.Time: "time",
ops.Date: "date",
Expand Down Expand Up @@ -249,6 +247,14 @@ def visit_UnwrapJSONBoolean(self, op, *, arg):
NULL,
)

def visit_First(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``First`` via the ``_ibis_first`` aggregate family.

    ``_ibis_first_ignore_null`` is used when ``ignore_null`` is true,
    ``_ibis_first`` otherwise.
    """
    func = "_ibis_first_ignore_null" if ignore_null else "_ibis_first"
    return self.agg[func](arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, ignore_null):
    """Compile ``Last`` via the ``_ibis_last`` aggregate family.

    ``_ibis_last_ignore_null`` is used when ``ignore_null`` is true,
    ``_ibis_last`` otherwise.
    """
    func = "_ibis_last_ignore_null" if ignore_null else "_ibis_last"
    return self.agg[func](arg, where=where, order_by=order_by)

def visit_Variance(self, op, *, arg, how, where):
    """Compile ``Variance`` via the ``_ibis_var_<how>`` aggregate.

    The aggregate name is built from ``op.how``; the ``how`` keyword is
    accepted but not read.
    """
    func = f"_ibis_var_{op.how}"
    return self.agg[func](arg, where=where)

Expand Down
Loading

0 comments on commit 8d10a3e

Please sign in to comment.