Fix metadata after implicit array conversion from Dask cuDF #16842

Merged 11 commits on Sep 25, 2024
79 changes: 55 additions & 24 deletions python/dask_cudf/dask_cudf/expr/_collection.py
@@ -202,27 +202,58 @@ class Index(DXIndex, CudfFrameBase):
##


-try:
-    from dask_expr._backends import create_array_collection
-
-    @get_collection_type.register_lazy("cupy")
-    def _register_cupy():
-        import cupy
-
-        @get_collection_type.register(cupy.ndarray)
-        def get_collection_type_cupy_array(_):
-            return create_array_collection
-
-    @get_collection_type.register_lazy("cupyx")
-    def _register_cupyx():
-        # Needed for cuml
-        from cupyx.scipy.sparse import spmatrix
-
-        @get_collection_type.register(spmatrix)
-        def get_collection_type_csr_matrix(_):
-            return create_array_collection
-
-except ImportError:
-    # Older version of dask-expr.
-    # Implicit conversion to array won't work.
-    pass
+def _create_array_collection_with_meta(expr):
+    # NOTE: This is the GPU compatible version of
+    # `new_dd_object` for DataFrame -> Array conversion.
+    # This can be removed if dask#11017 is resolved
+    # (See: https://github.com/dask/dask/issues/11017)
+    import numpy as np
+
+    import dask.array as da
+    from dask.blockwise import Blockwise
+    from dask.highlevelgraph import HighLevelGraph
+
+    result = expr.optimize()
+    dsk = result.__dask_graph__()
+    name = result._name
+    meta = result._meta
+    divisions = result.divisions
+    chunks = ((np.nan,) * (len(divisions) - 1),) + tuple(
+        (d,) for d in meta.shape[1:]
+    )
+    if len(chunks) > 1:
+        if isinstance(dsk, HighLevelGraph):
+            layer = dsk.layers[name]
+        else:
+            # dask-expr provides a dict only
+            layer = dsk
+        if isinstance(layer, Blockwise):
+            layer.new_axes["j"] = chunks[1][0]
+            layer.output_indices = layer.output_indices + ("j",)
+        else:
+            suffix = (0,) * (len(chunks) - 1)
+            for i in range(len(chunks[0])):
+                layer[(name, i) + suffix] = layer.pop((name, i))
+
+    return da.Array(dsk, name=name, chunks=chunks, meta=meta)
+
+
+@get_collection_type.register_lazy("cupy")
+def _register_cupy():
+    import cupy
+
+    get_collection_type.register(
+        cupy.ndarray,
+        lambda x: _create_array_collection_with_meta,
+    )
+
+
+@get_collection_type.register_lazy("cupyx")
+def _register_cupyx():
+    # Needed for cuml
+    from cupyx.scipy.sparse import spmatrix
+
+    get_collection_type.register(
+        spmatrix,
+        lambda x: _create_array_collection_with_meta,
+    )
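
For reference, here is a small worked example (not part of the PR) of the chunk computation above, under assumed inputs: a three-partition frame, so `divisions` has four entries, whose meta is a zero-row, two-column frame. Per-partition row counts are unknown at graph-construction time, so they become NaN chunk sizes, while the column dimension is taken from the meta shape:

    import numpy as np

    # Assumed inputs for illustration: 3 partitions, meta.shape == (0, 2)
    divisions = (0, 10, 20, 30)
    meta_shape = (0, 2)

    # Same expression as in _create_array_collection_with_meta
    chunks = ((np.nan,) * (len(divisions) - 1),) + tuple(
        (d,) for d in meta_shape[1:]
    )
    print(chunks)  # ((nan, nan, nan), (2,))
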
17 changes: 10 additions & 7 deletions python/dask_cudf/dask_cudf/tests/test_core.py
@@ -16,6 +16,7 @@

import dask_cudf
from dask_cudf.tests.utils import (
+    QUERY_PLANNING_ON,
    require_dask_expr,
    skip_dask_expr,
    xfail_dask_expr,
@@ -950,12 +951,16 @@ def test_implicit_array_conversion_cupy():
    def func(x):
        return x.values

-    # Need to compute the dask collection for now.
-    # See: https://github.com/dask/dask/issues/11017
-    result = ds.map_partitions(func, meta=s.values).compute()
-    expect = func(s)
+    result = ds.map_partitions(func, meta=s.values)

-    dask.array.assert_eq(result, expect)
+    if QUERY_PLANNING_ON:
+        # Check Array and round-tripped DataFrame
+        dask.array.assert_eq(result, func(s))
+        dd.assert_eq(result.to_dask_dataframe(), s, check_index=False)
+    else:
+        # Legacy version still carries numpy metadata
+        # See: https://github.com/dask/dask/issues/11017
+        dask.array.assert_eq(result.compute(), func(s))


def test_implicit_array_conversion_cupy_sparse():
@@ -967,8 +972,6 @@ def test_implicit_array_conversion_cupy_sparse():
    def func(x):
        return cupyx.scipy.sparse.csr_matrix(x.values)

-    # Need to compute the dask collection for now.
-    # See: https://github.com/dask/dask/issues/11017
    result = ds.map_partitions(func, meta=s.values).compute()
    expect = func(s)

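
As a rough sketch of the user-visible effect (assumptions: names and behavior inferred from the tests above, and the query-planning/dask-expr backend is enabled), an implicit DataFrame-to-Array conversion now yields a dask Array whose metadata is a cupy array rather than a numpy one:

    import cudf
    import dask_cudf

    s = cudf.Series(range(10))
    ds = dask_cudf.from_cudf(s, npartitions=2)

    # The mapped function returns a cupy.ndarray, so map_partitions
    # produces a dask Array; get_collection_type(meta) now routes cupy
    # metadata through _create_array_collection_with_meta instead of
    # dask-expr's numpy-only default.
    darr = ds.map_partitions(lambda x: x.values, meta=s.values)

    print(type(darr._meta))  # cupy.ndarray, so device metadata survives
    print(darr.chunks)       # ((nan, nan),): per-partition lengths unknown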