Skip to content

Commit

Permalink
Add optimization for Head(SortValues) -> nlargest/nsmallest
Browse files Browse the repository at this point in the history
It feels nice.  Good follow-on work would be to support set_index.
  • Loading branch information
mrocklin committed Sep 13, 2023
1 parent b485668 commit fe052ca
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs):
operands.append(kwargs.pop(parameter))
except KeyError:
operands.append(type(self)._defaults[parameter])
assert not kwargs
assert not kwargs, kwargs
self.operands = operands
if self._required_attribute:
dep = next(iter(self.dependencies()))._meta
Expand Down
13 changes: 13 additions & 0 deletions dask_expr/_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,19 @@ def _lower(self):
)

def _simplify_up(self, parent):
from dask_expr._expr import Head, Tail
from dask_expr._reductions import NLargest, NSmallest

if isinstance(parent, Head):
if self.ascending:
return NSmallest(self.frame, n=parent.n, _columns=self.by)
else:
return NLargest(self.frame, n=parent.n, _columns=self.by)
if isinstance(parent, Tail):
if self.ascending:
return NLargest(self.frame, n=parent.n, _columns=self.by)
else:
return NSmallest(self.frame, n=parent.n, _columns=self.by)
if isinstance(parent, Projection):
parent_columns = parent.columns
columns = parent_columns + [
Expand Down
18 changes: 18 additions & 0 deletions dask_expr/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,21 @@ def test_sort_values_descending(df, pdf):
pdf.sort_values(by="y", ascending=False),
sort_results=False,
)


def test_sort_head_nlargest(df):
a = df.sort_values("x", ascending=False).head(10, compute=False).expr
b = df.nlargest(10, columns=["x"]).expr
assert a.optimize()._name == b.optimize()._name

a = df.sort_values("x", ascending=True).head(10, compute=False).expr
b = df.nsmallest(10, columns=["x"]).expr
assert a.optimize()._name == b.optimize()._name

a = df.sort_values("x", ascending=False).tail(10, compute=False).expr
b = df.nsmallest(10, columns=["x"]).expr
assert a.optimize()._name == b.optimize()._name

a = df.sort_values("x", ascending=True).tail(10, compute=False).expr
b = df.nlargest(10, columns=["x"]).expr
assert a.optimize()._name == b.optimize()._name

0 comments on commit fe052ca

Please sign in to comment.