From 3310220a72524faaedaf172371a23e3dec127987 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif <jcristharif@gmail.com> Date: Wed, 25 Sep 2024 11:21:14 -0500 Subject: [PATCH] feat(datafusion): implement `argmin`/`argmax` --- ibis/backends/sql/compilers/datafusion.py | 10 ++++++++-- ibis/backends/tests/test_aggregation.py | 4 ---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index b527663a6394f..8cecd30c02d16 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -30,8 +30,6 @@ class DataFusionCompiler(SQLGlotCompiler): post_rewrites = (split_select_distinct_with_order_by,) UNSUPPORTED_OPS = ( - ops.ArgMax, - ops.ArgMin, ops.ArrayDistinct, ops.ArrayFilter, ops.ArrayMap, @@ -457,6 +455,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null): where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last_value(arg, where=where, order_by=order_by) + def visit_ArgMin(self, op, *, arg, key, where): + return self.agg.first_value(arg, where=where, order_by=[sge.Ordered(this=key)]) + + def visit_ArgMax(self, op, *, arg, key, where): + return self.agg.first_value( + arg, where=where, order_by=[sge.Ordered(this=key, desc=True)] + ) + def visit_Aggregate(self, op, *, parent, groups, metrics): """Support `GROUP BY` expressions in `SELECT` since DataFusion does not.""" quoted = self.quoted diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index a8a4b95ca1cc5..2ff92c14f3619 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -123,7 +123,6 @@ def mean_udf(s): ] argidx_not_grouped_marks = [ - "datafusion", "impala", "mysql", "mssql", @@ -411,7 +410,6 @@ def mean_and_std(v): [ "impala", "mysql", - "datafusion", "mssql", "druid", "oracle", @@ -431,7 +429,6 @@ def mean_and_std(v): [ "impala", "mysql", - "datafusion", "mssql", "druid", "oracle", @@ -691,7 +688,6 @@ def test_first_last_ordered(alltypes, method, filtered, include_null): @pytest.mark.notimpl( [ - "datafusion", "druid", "exasol", "flink",