Skip to content

Commit

Permalink
Add dask-cudf workaround for missing rename_axis support in cudf (#…
Browse files Browse the repository at this point in the history
…16899)

See #16895
Closes #16892

Dask-expr uses `rename_axis`, which is not supported by cudf yet. This is a temporary workaround until #16895 is resolved.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Mads R. B. Kristensen (https://github.com/madsbk)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #16899
  • Loading branch information
rjzamora authored Sep 25, 2024
1 parent dbe5528 commit 75c5c83
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
12 changes: 12 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from dask import config
from dask.dataframe.core import is_dataframe_like
from dask.typing import no_default

import cudf

Expand Down Expand Up @@ -90,6 +91,17 @@ def var(
)
)

def rename_axis(
    self, mapper=no_default, index=no_default, columns=no_default, axis=0
):
    """Build a new collection from a ``RenameAxisCudf`` expression.

    All arguments are forwarded unchanged, as keywords, to
    ``RenameAxisCudf`` together with this collection as the frame.
    """
    # Function-scope import, kept as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from dask_cudf.expr._expr import RenameAxisCudf

    renamed_expr = RenameAxisCudf(
        self, mapper=mapper, index=index, columns=columns, axis=axis
    )
    return new_collection(renamed_expr)


class DataFrame(DXDataFrame, CudfFrameBase):
@classmethod
Expand Down
16 changes: 15 additions & 1 deletion python/dask_cudf/dask_cudf/expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import dask_expr._shuffle as _shuffle_module
from dask_expr import new_collection
from dask_expr._cumulative import CumulativeBlockwise
from dask_expr._expr import Elemwise, Expr, VarColumns
from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns
from dask_expr._reductions import Reduction, Var

from dask.dataframe.core import is_dataframe_like, make_meta, meta_nonempty
from dask.dataframe.dispatch import is_categorical_dtype
from dask.typing import no_default

import cudf

Expand All @@ -17,6 +18,19 @@
##


class RenameAxisCudf(RenameAxis):
    """``RenameAxis`` variant whose operation only renames the index.

    Used as a stand-in while cudf lacks ``rename_axis``.
    """

    # TODO: Remove this after rename_axis is supported in cudf
    # (See: https://github.com/rapidsai/cudf/issues/16895)
    @staticmethod
    def operation(df, index=no_default, **kwargs):
        """Set ``df.index.name`` to ``index`` and return ``df``.

        Parameters
        ----------
        df
            Frame-like object exposing a writable ``index.name``.
        index
            New name for the index. Must be provided.
        **kwargs
            Any other keyword arguments are accepted and ignored.

        Raises
        ------
        NotImplementedError
            If ``index`` was not supplied (i.e. it is ``no_default``).
        """
        # ``no_default`` is a sentinel, so test identity rather than
        # equality: ``!=`` would defer to the ``__ne__`` of whatever
        # object the caller passed as ``index``.
        if index is no_default:
            raise NotImplementedError(
                "Only `index` is supported for the cudf backend"
            )
        # NOTE: the name is assigned in place on ``df``; no copy is made.
        df.index.name = index
        return df


class ToCudfBackend(Elemwise):
# TODO: Inherit from ToBackend when rapids-dask-dependency
# is pinned to dask>=2024.8.1
Expand Down
12 changes: 12 additions & 0 deletions python/dask_cudf/dask_cudf/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,3 +1027,15 @@ def test_cov_corr(op, numeric_only):
# (See: https://github.com/rapidsai/cudf/issues/12626)
expect = getattr(df.to_pandas(), op)(numeric_only=numeric_only)
dd.assert_eq(res, expect)


def test_rename_axis_after_join():
    """Outer-join two dask-cudf frames where the left index is named.

    The dask result is compared against the plain cudf join, ignoring
    the index.
    """
    left = cudf.DataFrame(index=["a", "b", "c"], data={"a": [1, 2, 3]})
    left.index.name = "test"
    right = cudf.DataFrame(index=["a", "b", "d"], data={"b": [1, 2, 3]})

    # Two partitions per collection, as positional ``npartitions``.
    left_ddf = dd.from_pandas(left, 2)
    right_ddf = dd.from_pandas(right, 2)

    expected = left.join(right, how="outer")
    result = left_ddf.join(right_ddf, how="outer")
    dd.assert_eq(result, expected, check_index=False)

0 comments on commit 75c5c83

Please sign in to comment.