From ffe335dfc4014c92fe5fd797b3716d760c556ab5 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 24 Sep 2024 13:26:45 -0700 Subject: [PATCH] add workaround for missing rename_axis support --- python/dask_cudf/dask_cudf/expr/_collection.py | 12 ++++++++++++ python/dask_cudf/dask_cudf/expr/_expr.py | 16 +++++++++++++++- python/dask_cudf/dask_cudf/tests/test_core.py | 12 ++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index 97e1dffc65b..4a22a86afa0 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -15,6 +15,7 @@ from dask import config from dask.dataframe.core import is_dataframe_like +from dask.typing import no_default import cudf @@ -90,6 +91,17 @@ def var( ) ) + def rename_axis( + self, mapper=no_default, index=no_default, columns=no_default, axis=0 + ): + from dask_cudf.expr._expr import RenameAxisCudf + + return new_collection( + RenameAxisCudf( + self, mapper=mapper, index=index, columns=columns, axis=axis + ) + ) + class DataFrame(DXDataFrame, CudfFrameBase): @classmethod diff --git a/python/dask_cudf/dask_cudf/expr/_expr.py b/python/dask_cudf/dask_cudf/expr/_expr.py index 8a2c50d3fe7..b284ab3774d 100644 --- a/python/dask_cudf/dask_cudf/expr/_expr.py +++ b/python/dask_cudf/dask_cudf/expr/_expr.py @@ -4,11 +4,12 @@ import dask_expr._shuffle as _shuffle_module from dask_expr import new_collection from dask_expr._cumulative import CumulativeBlockwise -from dask_expr._expr import Elemwise, Expr, VarColumns +from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns from dask_expr._reductions import Reduction, Var from dask.dataframe.core import is_dataframe_like, make_meta, meta_nonempty from dask.dataframe.dispatch import is_categorical_dtype +from dask.typing import no_default import cudf @@ -17,6 +18,19 @@ ## +class RenameAxisCudf(RenameAxis): + # TODO: Remove this after rename_axis is supported in cudf + # (See: https://github.com/rapidsai/cudf/issues/16895) + @staticmethod + def operation(df, index=no_default, **kwargs): + if index != no_default: + df.index.name = index + return df + raise NotImplementedError( + "Only `index` is supported for the cudf backend" + ) + + class ToCudfBackend(Elemwise): # TODO: Inherit from ToBackend when rapids-dask-dependency # is pinned to dask>=2024.8.1 diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py index 7aa0f6320f2..0e86835524c 100644 --- a/python/dask_cudf/dask_cudf/tests/test_core.py +++ b/python/dask_cudf/dask_cudf/tests/test_core.py @@ -1024,3 +1024,15 @@ def test_cov_corr(op, numeric_only): # (See: https://github.com/rapidsai/cudf/issues/12626) expect = getattr(df.to_pandas(), op)(numeric_only=numeric_only) dd.assert_eq(res, expect) + + +def test_rename_axis_after_join(): + df1 = cudf.DataFrame(index=["a", "b", "c"], data=dict(a=[1, 2, 3])) + df1.index.name = "test" + ddf1 = dd.from_pandas(df1, 2) + + df2 = cudf.DataFrame(index=["a", "b", "d"], data=dict(b=[1, 2, 3])) + ddf2 = dd.from_pandas(df2, 2) + result = ddf1.join(ddf2, how="outer") + expected = df1.join(df2, how="outer") + dd.assert_eq(result, expected, check_index=False)