From 3310220a72524faaedaf172371a23e3dec127987 Mon Sep 17 00:00:00 2001
From: Jim Crist-Harif <jcristharif@gmail.com>
Date: Wed, 25 Sep 2024 11:21:14 -0500
Subject: [PATCH] feat(datafusion): implement `argmin`/`argmax`

---
 ibis/backends/sql/compilers/datafusion.py | 10 ++++++++--
 ibis/backends/tests/test_aggregation.py   |  4 ----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py
index b527663a6394f..8cecd30c02d16 100644
--- a/ibis/backends/sql/compilers/datafusion.py
+++ b/ibis/backends/sql/compilers/datafusion.py
@@ -30,8 +30,6 @@ class DataFusionCompiler(SQLGlotCompiler):
     post_rewrites = (split_select_distinct_with_order_by,)
 
     UNSUPPORTED_OPS = (
-        ops.ArgMax,
-        ops.ArgMin,
         ops.ArrayDistinct,
         ops.ArrayFilter,
         ops.ArrayMap,
@@ -457,6 +455,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
             where = cond if where is None else sge.And(this=cond, expression=where)
         return self.agg.last_value(arg, where=where, order_by=order_by)
 
+    def visit_ArgMin(self, op, *, arg, key, where):  # first_value(arg), key ascending
+        return self.agg.first_value(arg, order_by=[sge.Ordered(this=key)], where=where)
+
+    def visit_ArgMax(self, op, *, arg, key, where):
+        """Compile ArgMax as `first_value(arg)` ordered by `key` descending."""
+        descending = [sge.Ordered(this=key, desc=True)]
+        return self.agg.first_value(arg, where=where, order_by=descending)
+
     def visit_Aggregate(self, op, *, parent, groups, metrics):
         """Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
         quoted = self.quoted
diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py
index a8a4b95ca1cc5..2ff92c14f3619 100644
--- a/ibis/backends/tests/test_aggregation.py
+++ b/ibis/backends/tests/test_aggregation.py
@@ -123,7 +123,6 @@ def mean_udf(s):
 ]
 
 argidx_not_grouped_marks = [
-    "datafusion",
     "impala",
     "mysql",
     "mssql",
@@ -411,7 +410,6 @@ def mean_and_std(v):
                     [
                         "impala",
                         "mysql",
-                        "datafusion",
                         "mssql",
                         "druid",
                         "oracle",
@@ -431,7 +429,6 @@ def mean_and_std(v):
                     [
                         "impala",
                         "mysql",
-                        "datafusion",
                         "mssql",
                         "druid",
                         "oracle",
@@ -691,7 +688,6 @@ def test_first_last_ordered(alltypes, method, filtered, include_null):
 
 @pytest.mark.notimpl(
     [
-        "datafusion",
         "druid",
         "exasol",
         "flink",