Skip to content

Commit

Permalink
fix tokenization of ArgsKwargsPackedFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Dec 3, 2024
1 parent 1d4d4e9 commit 49b836f
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1955,6 +1955,19 @@ def __init__(self, the_fn, arg_repackers, kwarg_repacker, arg_lens_for_repackers
self.kwarg_repacker = kwarg_repacker
self.arg_lens_for_repackers = arg_lens_for_repackers

def _repack(self, *args_deps_expanded):
args = []
len_args = 0
for repacker, n_args in zip(self.arg_repackers, self.arg_lens_for_repackers):
args.append(
repacker(args_deps_expanded[len_args : len_args + n_args])[0]
if repacker is not None
else args_deps_expanded[len_args]
)
len_args += n_args
kwargs = self.kwarg_repacker(args_deps_expanded[len_args:])[0]
return args, kwargs

def __call__(self, *args_deps_expanded):
"""This packing function receives a list of strictly
ordered arguments. The first range of arguments,
Expand All @@ -1969,16 +1982,7 @@ def __call__(self, *args_deps_expanded):
The various repackers deal with restructuring the received flattened
list into the shape that self.fn expects.
"""
args = []
len_args = 0
for repacker, n_args in zip(self.arg_repackers, self.arg_lens_for_repackers):
args.append(
repacker(args_deps_expanded[len_args : len_args + n_args])[0]
if repacker is not None
else args_deps_expanded[len_args]
)
len_args += n_args
kwargs = self.kwarg_repacker(args_deps_expanded[len_args:])[0]
args, kwargs = self._repack(*args_deps_expanded)
return self.fn(*args, **kwargs)


Expand All @@ -2000,7 +2004,12 @@ def _map_partitions(
will not be traversed to extract all dask collections, except those in
the first dimension of args or kwargs.
"""
token = token or tokenize(fn, *args, output_divisions, **kwargs)
if isinstance(fn, ArgsKwargsPackedFunction):
token_args, token_kwargs = fn._repack(*args)
token = token or tokenize(fn.fn, *token_args, output_divisions, **token_kwargs)
else:
token = token or tokenize(fn, *args, output_divisions, **kwargs)

label = hyphenize(label or funcname(fn))
name = f"{label}-{token}"
deps = [a for a in args if is_dask_collection(a)] + [
Expand Down

0 comments on commit 49b836f

Please sign in to comment.