Skip to content

Commit

Permalink
mapfilter: raise NotImplementedError for non-awkward arrays or datafra…
Browse files Browse the repository at this point in the history
…me-like return types
  • Loading branch information
pfackeldey committed Nov 20, 2024
1 parent 1026471 commit 836b24b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 54 deletions.
72 changes: 31 additions & 41 deletions src/dask_awkward/lib/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,36 +23,38 @@ def _single_return_map_partitions(
meta: tp.Any,
npartitions: int,
) -> tp.Any:
from dask.utils import (
is_arraylike,
is_dataframe_like,
is_index_like,
is_series_like,
)

# ak.Array (this is dak.map_partitions case)
if isinstance(meta, ak.Array):
# convert to typetracer if not already
# this happens when the user provides a concrete array (e.g. np.array)
# and then wraps it with ak.Array as a return type
if not ak.backend(meta) == "typetracer":
meta = ak.to_backend(meta, "typetracer")
return new_array_object(
hlg,
name=name,
meta=meta,
npartitions=npartitions,
)

# TODO: np.array
# from dask.utils import is_arraylike, is_dataframe_like, is_index_like, is_series_like
#
# elif is_arraylike(meta):
# this doesn't work yet, because the graph/chunking is not correct
#
# import numpy as np
# from dask.array.core import new_da_object
# meta = meta[None, ...]
# first = (np.nan,) * npartitions
# rest = ((-1,),) * (meta.ndim - 1)
# chunks = (first, *rest)
# return new_da_object(hlg, name=name, meta=meta, chunks=chunks)

# TODO: dataframe, series, index
# elif (
# is_dataframe_like(meta)
# or is_series_like(meta)
# or is_index_like(meta)
# ): pass

# TODO: array, dataframe, series, index
elif (
is_arraylike(meta)
or is_dataframe_like(meta)
or is_series_like(meta)
or is_index_like(meta)
):
msg = (
f"{meta=} is not (yet) supported as return type. If possible, "
"you can convert it to ak.Array first, or wrap it with a python container."
)
raise NotImplementedError(msg)
# don't know? -> put it in a bag
else:
from dask.bag.core import Bag
Expand Down Expand Up @@ -101,26 +103,14 @@ def _multi_return_map_partitions(
dependencies=[cast(DaskCollection, tmp)],
)
dak_cache[ith_name] = hlg_pick, m_pick

# nested return case -> recurse
if isinstance(m_pick, tuple):
ret.append(
_multi_return_map_partitions(
hlg=hlg_pick,
name=ith_name,
meta=m_pick,
npartitions=npartitions,
)
)
else:
ret.append(
_single_return_map_partitions(
hlg=hlg_pick,
name=ith_name,
meta=m_pick,
npartitions=npartitions,
)
ret.append(
_single_return_map_partitions(
hlg=hlg_pick,
name=ith_name,
meta=m_pick,
npartitions=npartitions,
)
)
return tuple(ret)


Expand Down
20 changes: 7 additions & 13 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,19 @@ class some: ...
@dak.mapfilter
def fun(x):
y = x.foo + 1
return (
y,
ak.sum(y),
some(),
np.ones((1, 4)),
) # add first length-1 dimension to numpy array for 'correct' stacking
return y, (np.sum(y),), some(), ak.Array(np.ones(4))

y, y_sum, something, np_arr = fun(dak_array)
y, y_sum, something, arr = fun(dak_array)

assert ak.all(y.compute() == ak_array.foo + 1)
assert ak.all(y_sum.compute() == np.array([5, 9]))
assert np.all(y_sum.compute() == [np.array(5), np.array(9)])
something = something.compute()
assert len(something) == 2
assert all(isinstance(s, some) for s in something)
np_arrays = np_arr.compute()
assert len(np_arrays) == 2
for arr in np_arrays:
assert arr.shape == (4,)
assert np.all(arr == np.ones(4))
array = arr.compute()
assert len(array) == 8
assert array.ndim == 1
assert ak.all(array == ak.Array(np.ones(8)))


def test_mapfilter_needs_outlike():
Expand Down

0 comments on commit 836b24b

Please sign in to comment.