diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index b527663a6394f..8cecd30c02d16 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -30,8 +30,6 @@ class DataFusionCompiler(SQLGlotCompiler): post_rewrites = (split_select_distinct_with_order_by,) UNSUPPORTED_OPS = ( - ops.ArgMax, - ops.ArgMin, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayMap, @@ -457,6 +455,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null): where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last_value(arg, where=where, order_by=order_by) + def visit_ArgMin(self, op, *, arg, key, where): + return self.agg.first_value(arg, where=where, order_by=[sge.Ordered(this=key)]) + + def visit_ArgMax(self, op, *, arg, key, where): + return self.agg.first_value( + arg, where=where, order_by=[sge.Ordered(this=key, desc=True)] + ) + def visit_Aggregate(self, op, *, parent, groups, metrics): """Support `GROUP BY` expressions in `SELECT` since DataFusion does not.""" quoted = self.quoted diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index a8a4b95ca1cc5..2ff92c14f3619 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -123,7 +123,6 @@ def mean_udf(s): ] argidx_not_grouped_marks = [ - "datafusion", "impala", "mysql", "mssql", @@ -411,7 +410,6 @@ def mean_and_std(v): [ "impala", "mysql", - "datafusion", "mssql", "druid", "oracle", @@ -431,7 +429,6 @@ def mean_and_std(v): [ "impala", "mysql", - "datafusion", "mssql", "druid", "oracle", @@ -691,7 +688,6 @@ def test_first_last_ordered(alltypes, method, filtered, include_null): @pytest.mark.notimpl( [ - "datafusion", "druid", "exasol", "flink",