From fe052ca9c86a66ed2a76c2e2eb0d8890e647cf41 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 11 Sep 2023 16:17:44 -0500 Subject: [PATCH] Add optimization for Head(SortValues) -> nlargest/nsmallest It feels nice. Good follow-on work would be to support set_index. --- dask_expr/_expr.py | 2 +- dask_expr/_shuffle.py | 13 +++++++++++++ dask_expr/tests/test_shuffle.py | 18 ++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index cd36641f2..9ce285c36 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs): operands.append(kwargs.pop(parameter)) except KeyError: operands.append(type(self)._defaults[parameter]) - assert not kwargs + assert not kwargs, kwargs self.operands = operands if self._required_attribute: dep = next(iter(self.dependencies()))._meta diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py index 0d9e0e4d7..c2bed3b48 100644 --- a/dask_expr/_shuffle.py +++ b/dask_expr/_shuffle.py @@ -820,6 +820,19 @@ def _lower(self): ) def _simplify_up(self, parent): + from dask_expr._expr import Head, Tail + from dask_expr._reductions import NLargest, NSmallest + + if isinstance(parent, Head): + if self.ascending: + return NSmallest(self.frame, n=parent.n, _columns=self.by) + else: + return NLargest(self.frame, n=parent.n, _columns=self.by) + if isinstance(parent, Tail): + if self.ascending: + return NLargest(self.frame, n=parent.n, _columns=self.by) + else: + return NSmallest(self.frame, n=parent.n, _columns=self.by) if isinstance(parent, Projection): parent_columns = parent.columns columns = parent_columns + [ diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index d0e92e04e..697ec1da8 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -260,3 +260,21 @@ def test_sort_values_descending(df, pdf): pdf.sort_values(by="y", ascending=False), sort_results=False, ) + + +def test_sort_head_nlargest(df): + a = df.sort_values("x", ascending=False).head(10, compute=False).expr + b = df.nlargest(10, columns=["x"]).expr + assert a.optimize()._name == b.optimize()._name + + a = df.sort_values("x", ascending=True).head(10, compute=False).expr + b = df.nsmallest(10, columns=["x"]).expr + assert a.optimize()._name == b.optimize()._name + + a = df.sort_values("x", ascending=False).tail(10, compute=False).expr + b = df.nsmallest(10, columns=["x"]).expr + assert a.optimize()._name == b.optimize()._name + + a = df.sort_values("x", ascending=True).tail(10, compute=False).expr + b = df.nlargest(10, columns=["x"]).expr + assert a.optimize()._name == b.optimize()._name