diff --git a/src/dask_awkward/lib/decorator.py b/src/dask_awkward/lib/decorator.py index aa04d83f..4f239146 100644 --- a/src/dask_awkward/lib/decorator.py +++ b/src/dask_awkward/lib/decorator.py @@ -23,36 +23,38 @@ def _single_return_map_partitions( meta: tp.Any, npartitions: int, ) -> tp.Any: + from dask.utils import ( + is_arraylike, + is_dataframe_like, + is_index_like, + is_series_like, + ) + # ak.Array (this is dak.map_partitions case) if isinstance(meta, ak.Array): + # convert to typetracer if not already + # this happens when the user provides a concrete array (e.g. np.array) + # and then wraps it with ak.Array as a return type + if not ak.backend(meta) == "typetracer": + meta = ak.to_backend(meta, "typetracer") return new_array_object( hlg, name=name, meta=meta, npartitions=npartitions, ) - - # TODO: np.array - # from dask.utils import is_arraylike, is_dataframe_like, is_index_like, is_series_like - # - # elif is_arraylike(meta): - # this doesn't work yet, because the graph/chunking is not correct - # - # import numpy as np - # from dask.array.core import new_da_object - # meta = meta[None, ...] - # first = (np.nan,) * npartitions - # rest = ((-1,),) * (meta.ndim - 1) - # chunks = (first, *rest) - # return new_da_object(hlg, name=name, meta=meta, chunks=chunks) - - # TODO: dataframe, series, index - # elif ( - # is_dataframe_like(meta) - # or is_series_like(meta) - # or is_index_like(meta) - # ): pass - + # TODO: array, dataframe, series, index + elif ( + is_arraylike(meta) + or is_dataframe_like(meta) + or is_series_like(meta) + or is_index_like(meta) + ): + msg = ( + f"{meta=} is not (yet) supported as return type. If possible, " + "you can convert it to ak.Array first, or wrap it with a python container." + ) + raise NotImplementedError(msg) # don't know? -> put it in a bag else: from dask.bag.core import Bag @@ -101,26 +103,14 @@ def _multi_return_map_partitions( dependencies=[cast(DaskCollection, tmp)], ) dak_cache[ith_name] = hlg_pick, m_pick - - # nested return case -> recurse - if isinstance(m_pick, tuple): - ret.append( - _multi_return_map_partitions( - hlg=hlg_pick, - name=ith_name, - meta=m_pick, - npartitions=npartitions, - ) - ) - else: - ret.append( - _single_return_map_partitions( - hlg=hlg_pick, - name=ith_name, - meta=m_pick, - npartitions=npartitions, - ) + ret.append( + _single_return_map_partitions( + hlg=hlg_pick, + name=ith_name, + meta=m_pick, + npartitions=npartitions, ) + ) return tuple(ret) diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 2e279c46..0a30e2e9 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -31,25 +31,19 @@ class some: ... @dak.mapfilter def fun(x): y = x.foo + 1 - return ( - y, - ak.sum(y), - some(), - np.ones((1, 4)), - ) # add first length-1 dimension to numpy array for 'correct' stacking + return y, (np.sum(y),), some(), ak.Array(np.ones(4)) - y, y_sum, something, np_arr = fun(dak_array) + y, y_sum, something, arr = fun(dak_array) assert ak.all(y.compute() == ak_array.foo + 1) - assert ak.all(y_sum.compute() == np.array([5, 9])) + assert np.all(y_sum.compute() == [np.array(5), np.array(9)]) something = something.compute() assert len(something) == 2 assert all(isinstance(s, some) for s in something) - np_arrays = np_arr.compute() - assert len(np_arrays) == 2 - for arr in np_arrays: - assert arr.shape == (4,) - assert np.all(arr == np.ones(4)) + array = arr.compute() + assert len(array) == 8 + assert array.ndim == 1 + assert ak.all(array == ak.Array(np.ones(8))) def test_mapfilter_needs_outlike():