From 734444caeed571b82272d07d5f788e6726e2b999 Mon Sep 17 00:00:00 2001
From: Deepak Cherian <deepak@cherian.net>
Date: Tue, 7 Nov 2023 14:22:15 -0700
Subject: [PATCH 1/3] Add quantile_tdigest

Co-authored-by: Florian Jetter <fjetter@users.noreply.github.com>
---
 flox/aggregations.py | 13 ++++++++++++-
 flox/sketches.py     | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 1 deletion(-)
 create mode 100644 flox/sketches.py

diff --git a/flox/aggregations.py b/flox/aggregations.py
index b91d191b2..d6ab65f39 100644
--- a/flox/aggregations.py
+++ b/flox/aggregations.py
@@ -8,7 +8,7 @@
 import numpy as np
 from numpy.typing import DTypeLike
 
-from . import aggregate_flox, aggregate_npg, xrutils
+from . import aggregate_flox, aggregate_npg, sketches, xrutils
 from . import xrdtypes as dtypes
 
 if TYPE_CHECKING:
@@ -495,6 +495,16 @@ def _pick_second(*x):
 mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
 nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
 
+
+quantile_tdigest = Aggregation(
+    "quantile_tdigest",
+    numpy=(sketches.tdigest_aggregate,),
+    chunk=(sketches.tdigest_chunk,),
+    combine=(sketches.tdigest_combine,),
+    finalize=sketches.tdigest_aggregate,
+)
+
+
 aggregations = {
     "any": any_,
     "all": all_,
@@ -527,6 +537,7 @@ def _pick_second(*x):
     "nanquantile": nanquantile,
     "mode": mode,
     "nanmode": nanmode,
+    "quantile_tdigest": quantile_tdigest,
 }
 
 
diff --git a/flox/sketches.py b/flox/sketches.py
new file mode 100644
index 000000000..f2afa24b2
--- /dev/null
+++ b/flox/sketches.py
@@ -0,0 +1,35 @@
+import numpy as np
+import numpy_groupies as npg
+
+
+def tdigest_chunk(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, **kwargs):
+    from crick import TDigest  # imported at call time, not when this module is imported
+
+    def _(arr):  # reduce the values of one group to a single TDigest
+        digest = TDigest()
+        # numpy_groupies hands us object arrays; cast back to the input dtype before updating
+        digest.update(arr.astype(array.dtype, copy=False))
+        return digest
+
+    result = npg.aggregate_numpy.aggregate(group_idx, array, func=_, axis=axis, dtype=object)
+    return result
+
+
+def tdigest_combine(digests, axis=-1, keepdims=True):
+    from crick import TDigest  # imported at call time, not when this module is imported
+
+    def _(arr):  # collapse one 1-D slice of digests into a length-1 object array
+        t = TDigest()  # empty digest that receives the merged contents of the slice
+        t.merge(*arr)
+        return np.array([t], dtype=object)
+
+    (axis,) = axis
+    result = np.apply_along_axis(_, axis, digests)
+
+    return result
+
+
+def tdigest_aggregate(digests, q, axis=-1, keepdims=True):
+    """Return a new object array holding ``digest.quantile(q)`` for each digest in ``digests``."""
+    # Elementwise map into a fresh array (input is not overwritten); axis/keepdims are unused.
+    return np.frompyfunc(lambda digest: digest.quantile(q), 1, 1)(digests)

From e8cb8d8bed008d5bb1f3df6abc10c9c6f6f472a7 Mon Sep 17 00:00:00 2001
From: Deepak Cherian <deepak@cherian.net>
Date: Wed, 8 Nov 2023 22:01:52 -0700
Subject: [PATCH 2/3] Fix tdigest chunk/combine; add nanquantile_tdigest

---
 flox/aggregations.py | 20 +++++++++++++++++++-
 flox/sketches.py     | 13 ++++++++++---
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/flox/aggregations.py b/flox/aggregations.py
index d6ab65f39..89cc7dbbf 100644
--- a/flox/aggregations.py
+++ b/flox/aggregations.py
@@ -111,7 +111,10 @@ def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -
     elif not isinstance(dtype, np.dtype):
         dtype = np.dtype(dtype)
     if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
-        dtype = np.result_type(dtype, fill_value)
+        try:
+            dtype = np.result_type(dtype, fill_value)
+        except TypeError:
+            pass
     return dtype
 
 
@@ -496,12 +499,26 @@ def _pick_second(*x):
 nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
 
 
+from crick import TDigest  # NOTE(review): module-level import; this module cannot be imported without crick
+
 quantile_tdigest = Aggregation(
     "quantile_tdigest",
     numpy=(sketches.tdigest_aggregate,),
     chunk=(sketches.tdigest_chunk,),
     combine=(sketches.tdigest_combine,),
     finalize=sketches.tdigest_aggregate,
+    fill_value=TDigest(),  # built once at import time; NOTE(review): confirm this instance is never mutated
+    final_dtype=np.float64,
+)
+
+nanquantile_tdigest = Aggregation(  # NOTE(review): same callables as quantile_tdigest; confirm NaN handling
+    "nanquantile_tdigest",
+    numpy=(sketches.tdigest_aggregate,),
+    chunk=(sketches.tdigest_chunk,),
+    combine=(sketches.tdigest_combine,),
+    finalize=sketches.tdigest_aggregate,
+    fill_value=TDigest(),  # a second, separate TDigest instance, also built at import time
+    final_dtype=np.float64,
 )
 
 
@@ -538,6 +555,7 @@ def _pick_second(*x):
     "mode": mode,
     "nanmode": nanmode,
     "quantile_tdigest": quantile_tdigest,
+    "nanquantile_tdigest": nanquantile_tdigest,
 }
 
 
diff --git a/flox/sketches.py b/flox/sketches.py
index f2afa24b2..2f70d6c55 100644
--- a/flox/sketches.py
+++ b/flox/sketches.py
@@ -11,7 +11,9 @@ def _(arr):
         digest.update(arr.astype(array.dtype, copy=False))
         return digest
 
-    result = npg.aggregate_numpy.aggregate(group_idx, array, func=_, axis=axis, dtype=object)
+    result = npg.aggregate_numpy.aggregate(
+        group_idx, array, func=_, size=size, fill_value=fill_value, axis=axis, dtype=object
+    )
     return result
 
 
@@ -23,8 +25,13 @@ def _(arr):
         t.merge(*arr)
         return np.array([t], dtype=object)
 
-    (axis,) = axis
-    result = np.apply_along_axis(_, axis, digests)
+    if not isinstance(axis, tuple):
+        axis = (axis,)
+
+    # For multiple axes, reduce one axis at a time; each pass merges the previous pass's digests.
+    result = digests
+    for ax in axis:
+        result = np.apply_along_axis(_, ax, result)
 
     return result
 

From 53540d453d67f4e4132e9184c3dbbee459b6ffb8 Mon Sep 17 00:00:00 2001
From: Deepak Cherian <deepak@cherian.net>
Date: Thu, 9 Nov 2023 09:56:16 -0700
Subject: [PATCH 3/3] type ignore crick

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index c507a5222..ade3cd913 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,6 +117,7 @@ exclude=["asv_bench/pkgs"]
 module=[
     "asv_runner.*",
     "cachey",
+    "crick",
     "cftime",
     "dask.*",
     "importlib_metadata",