diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index d7c1a4e0..9518d0d8 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1955,6 +1955,19 @@ def __init__(self, the_fn, arg_repackers, kwarg_repacker, arg_lens_for_repackers self.kwarg_repacker = kwarg_repacker self.arg_lens_for_repackers = arg_lens_for_repackers + def _repack(self, *args_deps_expanded): + args = [] + len_args = 0 + for repacker, n_args in zip(self.arg_repackers, self.arg_lens_for_repackers): + args.append( + repacker(args_deps_expanded[len_args : len_args + n_args])[0] + if repacker is not None + else args_deps_expanded[len_args] + ) + len_args += n_args + kwargs = self.kwarg_repacker(args_deps_expanded[len_args:])[0] + return args, kwargs + def __call__(self, *args_deps_expanded): """This packing function receives a list of strictly ordered arguments. The first range of arguments, @@ -1969,16 +1982,7 @@ def __call__(self, *args_deps_expanded): The various repackers deal with restructuring the received flattened list into the shape that self.fn expects. """ - args = [] - len_args = 0 - for repacker, n_args in zip(self.arg_repackers, self.arg_lens_for_repackers): - args.append( - repacker(args_deps_expanded[len_args : len_args + n_args])[0] - if repacker is not None - else args_deps_expanded[len_args] - ) - len_args += n_args - kwargs = self.kwarg_repacker(args_deps_expanded[len_args:])[0] + args, kwargs = self._repack(*args_deps_expanded) return self.fn(*args, **kwargs) @@ -2000,7 +2004,12 @@ def _map_partitions( will not be traversed to extract all dask collections, except those in the first dimension of args or kwargs. """ - token = token or tokenize(fn, *args, output_divisions, **kwargs) + if isinstance(fn, ArgsKwargsPackedFunction): + token_args, token_kwargs = fn._repack(*args) + token = token or tokenize(fn.fn, *token_args, output_divisions, **token_kwargs) + else: + token = token or tokenize(fn, *args, output_divisions, **kwargs) + label = hyphenize(label or funcname(fn)) name = f"{label}-{token}" deps = [a for a in args if is_dask_collection(a)] + [