add mapfilter decorator
pfackeldey committed Nov 19, 2024
1 parent fcae1a2 commit 4213659
Showing 6 changed files with 542 additions and 70 deletions.
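
The commit exports a new mapfilter decorator from dask_awkward.lib.decorator (the new module itself, presumably src/dask_awkward/lib/decorator.py, is among the six changed files but is not rendered below). The decorator's signature and options are not visible in this diff; as a hedged sketch only, assuming it wraps a function so that it runs once per partition much like map_partitions, usage might look like the following (the function name select and the toy data are hypothetical):

import awkward as ak
import dask_awkward as dak

@dak.mapfilter  # assumed bare-decorator usage; the actual options are not shown in this diff
def select(events):
    # called once per partition with a concrete awkward.Array
    return events[events.x > 0]

events = dak.from_awkward(ak.Array([{"x": 1}, {"x": -2}, {"x": 3}]), npartitions=2)
selected = select(events)  # lazily builds a new dask_awkward.Array
print(selected.compute().to_list())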
3 changes: 3 additions & 0 deletions .gitignore
@@ -130,3 +130,6 @@ venv.bak/

# mypy
.mypy_cache/

# pyright lsp
pyrightconfig.json
1 change: 1 addition & 0 deletions src/dask_awkward/__init__.py
@@ -18,6 +18,7 @@
    map_partitions,
    partition_compatibility,
)
from dask_awkward.lib.decorator import mapfilter
from dask_awkward.lib.describe import backend, fields
from dask_awkward.lib.inspect import (
    report_necessary_buffers,
1 change: 1 addition & 0 deletions src/dask_awkward/lib/__init__.py
@@ -7,6 +7,7 @@
    map_partitions,
    partition_compatibility,
)
from dask_awkward.lib.decorator import mapfilter
from dask_awkward.lib.describe import backend, fields
from dask_awkward.lib.inspect import (
    report_necessary_buffers,
179 changes: 109 additions & 70 deletions src/dask_awkward/lib/core.py
@@ -1982,31 +1982,35 @@ def __call__(self, *args_deps_expanded):
        return self.fn(*args, **kwargs)


def _map_partitions(
def _new_dak_array_divisions(
    dak_array: Array, output_divisions: int | None = None
) -> tuple:
    in_divisions = dak_array.divisions
    if output_divisions is not None:
        if output_divisions == 1:
            new_divisions = dak_array.divisions
        else:
            new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))  # type: ignore[operator]
    else:
        new_divisions = in_divisions
    return new_divisions
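

# --- Illustration only (not part of the commit) -------------------------------
# With known divisions, the helper above multiplies every division boundary by
# output_divisions and leaves them unchanged when output_divisions is 1 or None.
# Hypothetical values:
#
#     in_divisions = (0, 10, 20)          # a 2-partition array
#     tuple(x * 3 for x in in_divisions)  # output_divisions=3 -> (0, 30, 60)
#     output_divisions in (1, None)       # -> (0, 10, 20), unchanged
# -------------------------------------------------------------------------------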


def _map_partitions_prepare(
    fn: Callable,
    *args: Any,
    label: str | None = None,
    token: str | None = None,
    meta: Any | None = None,
    output_divisions: int | None = None,
    **kwargs: Any,
) -> Array:
    """Map a callable across all partitions of any number of collections.
    No wrapper is used to flatten the function arguments. This is meant for
    dask-awkward internal use or in situations where input data are sanitized.
    The parameters of this function are otherwise the same as map_partitions,
    but the limitation that args, kwargs must be non-nested and flat. They
    will not be traversed to extract all dask collections, except those in
    the first dimension of args or kwargs.
    """
) -> tuple:
    token = token or tokenize(fn, *args, output_divisions, **kwargs)
    label = hyphenize(label or funcname(fn))
    name = f"{label}-{token}"
    deps = [a for a in args if is_dask_collection(a)] + [
        v for v in kwargs.values() if is_dask_collection(v)
    ]
    dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps))

    if name in dak_cache:
        hlg, meta = dak_cache[name]
@@ -2027,22 +2031,46 @@ def _map_partitions(
            dependencies=deps,
        )

        if len(dak_arrays) == 0:
            raise TypeError(
                "at least one argument passed to map_partitions "
                "should be a dask_awkward.Array collection."
            )
        dak_cache[name] = hlg, meta
    in_npartitions = dak_arrays[0].npartitions
    in_divisions = dak_arrays[0].divisions
    if output_divisions is not None:
        if output_divisions == 1:
            new_divisions = dak_arrays[0].divisions
        else:
            new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))
    else:
        new_divisions = in_divisions
    return hlg, meta, deps, name


def _map_partitions(
    fn: Callable,
    *args: Any,
    label: str | None = None,
    token: str | None = None,
    meta: Any | None = None,
    output_divisions: int | None = None,
    **kwargs: Any,
) -> Array:
"""Map a callable across all partitions of any number of collections.
No wrapper is used to flatten the function arguments. This is meant for
dask-awkward internal use or in situations where input data are sanitized.
The parameters of this function are otherwise the same as map_partitions,
but the limitation that args, kwargs must be non-nested and flat. They
will not be traversed to extract all dask collections, except those in
the first dimension of args or kwargs.
"""
    hlg, meta, deps, name = _map_partitions_prepare(
        fn,
        *args,
        label=label,
        token=token,
        meta=meta,
        output_divisions=output_divisions,
        **kwargs,
    )
    dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps))
    if len(dak_arrays) == 0:
        raise TypeError(
            "at least one argument passed to map_partitions "
            "should be a dask_awkward.Array collection."
        )
    first = dak_arrays[0]
    new_divisions = _new_dak_array_divisions(first, output_divisions)
    if output_divisions is not None:
        return new_array_object(
            hlg,
@@ -2055,10 +2083,62 @@ def _map_partitions(
            hlg,
            name=name,
            meta=meta,
            npartitions=in_npartitions,
            npartitions=first.npartitions,
        )
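

# --- Illustration only (not part of the commit) -------------------------------
# The "non-nested and flat" contract in the docstring above means any
# dask_awkward.Array must appear directly in *args; nothing is traversed or
# repacked at this level. A hypothetical direct call (internal API):
#
#     import awkward as ak
#     import dask_awkward as dak
#     from dask_awkward.lib.core import _map_partitions
#
#     x = dak.from_awkward(ak.Array([1, 2, 3, 4]), npartitions=2)
#     doubled = _map_partitions(lambda part, factor: part * factor, x, 2)
#     doubled.compute().to_list()  # [2, 4, 6, 8]
#
# A nested structure such as {"arr": x} would *not* be unpacked here; that is
# what the public map_partitions (via _to_packed_fn_args) takes care of.
# -------------------------------------------------------------------------------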


def _to_packed_fn_args(
    base_fn: Callable,
    *args: Any,
    traverse: bool = True,
    **kwargs: Any,
) -> tuple:
    opt_touch_all = kwargs.pop("opt_touch_all", None)
    if opt_touch_all is not None:
        warnings.warn(
            "The opt_touch_all argument does nothing.\n"
            "This warning will be removed in a future version of dask-awkward "
            "and the function call will likely fail."
        )

    kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse)
    flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse)

    if len(flat_deps) == 0:
        message = (
            "map_partitions expects at least one Dask collection instance, "
            "you are passing non-Dask collections to dask-awkward code.\n"
            "observed argument types:\n"
        )
        for arg in args:
            message += f"- {type(arg)}"
        raise TypeError(message)

    arg_flat_deps_expanded = []
    arg_repackers = []
    arg_lens_for_repackers = []
    for arg in args:
        this_arg_flat_deps, repacker = unpack_collections(arg, traverse=traverse)
        if (
            len(this_arg_flat_deps) > 0
        ):  # if the deps list is empty this arg does not contain any dask collection, no need to repack!
            arg_flat_deps_expanded.extend(this_arg_flat_deps)
            arg_repackers.append(repacker)
            arg_lens_for_repackers.append(len(this_arg_flat_deps))
        else:
            arg_flat_deps_expanded.append(arg)
            arg_repackers.append(None)
            arg_lens_for_repackers.append(1)

    packed_fn = ArgsKwargsPackedFunction(
        base_fn,
        arg_repackers,
        kwarg_repacker,
        arg_lens_for_repackers,
    )
    return packed_fn, arg_flat_deps_expanded, kwarg_flat_deps


def map_partitions(
    base_fn: Callable,
    *args: Any,
@@ -2139,49 +2219,8 @@ def map_partitions(
    This is effectively the same as `d = c * a`
    """

    opt_touch_all = kwargs.pop("opt_touch_all", None)
    if opt_touch_all is not None:
        warnings.warn(
            "The opt_touch_all argument does nothing.\n"
            "This warning will be removed in a future version of dask-awkward "
            "and the function call will likely fail."
        )

    kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse)
    flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse)

    if len(flat_deps) == 0:
        message = (
            "map_partitions expects at least one Dask collection instance, "
            "you are passing non-Dask collections to dask-awkward code.\n"
            "observed argument types:\n"
        )
        for arg in args:
            message += f"- {type(arg)}"
        raise TypeError(message)

    arg_flat_deps_expanded = []
    arg_repackers = []
    arg_lens_for_repackers = []
    for arg in args:
        this_arg_flat_deps, repacker = unpack_collections(arg, traverse=traverse)
        if (
            len(this_arg_flat_deps) > 0
        ):  # if the deps list is empty this arg does not contain any dask collection, no need to repack!
            arg_flat_deps_expanded.extend(this_arg_flat_deps)
            arg_repackers.append(repacker)
            arg_lens_for_repackers.append(len(this_arg_flat_deps))
        else:
            arg_flat_deps_expanded.append(arg)
            arg_repackers.append(None)
            arg_lens_for_repackers.append(1)

    fn = ArgsKwargsPackedFunction(
        base_fn,
        arg_repackers,
        kwarg_repacker,
        arg_lens_for_repackers,
    fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args(
        base_fn, *args, traverse=traverse, **kwargs
    )
    return _map_partitions(
        fn,
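
The rendered diff is truncated at this point; the remaining changed files (including the new decorator module and, presumably, its tests) are not shown here. For orientation, the refactored public map_partitions visible in the last hunk is essentially a thin composition of the helpers introduced above, roughly as in this simplified sketch (not the literal source; label/token/meta/output_divisions forwarding omitted):

def map_partitions_sketch(base_fn, *args, traverse=True, **kwargs):
    # 1. flatten nested args/kwargs into flat dask collections plus a packed
    #    function that knows how to rebuild the original call structure
    fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args(
        base_fn, *args, traverse=traverse, **kwargs
    )
    # 2. build the partitionwise graph (via _map_partitions_prepare) and wrap it
    #    in a new dask_awkward.Array, with divisions from _new_dak_array_divisions
    return _map_partitions(fn, *arg_flat_deps_expanded, *kwarg_flat_deps)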