Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: go direct to _map_partitions in functions that receive already flat deps #558

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
30 changes: 27 additions & 3 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,6 +2171,22 @@ def map_partitions(
message += f"- {type(arg)}"
raise TypeError(message)

if len(kwargs) == 0: # non-critical FIXME: use kwarg_flat_deps
non_traversed_deps, _ = unpack_collections(*args, traverse=False)
if len(flat_deps) == len(non_traversed_deps) and all(
id(traversed_dep) == id(non_traversed_dep)
for traversed_dep, non_traversed_dep in zip(flat_deps, non_traversed_deps)
):
return _map_partitions(
base_fn,
*args,
label=label,
token=token,
meta=meta,
output_divisions=output_divisions,
**kwargs,
)

arg_flat_deps_expanded = []
arg_repackers = []
arg_lens_for_repackers = []
Expand Down Expand Up @@ -2541,11 +2557,19 @@ def to_length_zero_arrays(objects: Sequence[Any]) -> tuple[Any, ...]:
return tuple(map(length_zero_array_or_identity, objects))


def map_meta(fn: Callable | ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None:
# NOTE: fn is assumed to be a *packed* function
def map_meta(
fn: Callable | ArgsKwargsPackedFunction, *deps: Any, **kwargs: Any
) -> ak.Array | None:
# NOTE: fn to be a *packed* function (so flat deps or ArgsKwargsPackedFunction)
# if ArgsKwargsPackedFunction we do not allow kwargs
# as defined up in map_partitions. be careful!
if isinstance(fn, ArgsKwargsPackedFunction) and len(kwargs) > 0:
raise ValueError("ArgsKwargsPackedFunctions may not have additional kwargs!")
try:
meta = fn(*to_meta(deps))
if isinstance(fn, ArgsKwargsPackedFunction):
meta = fn(*to_meta(deps))
else:
meta = fn(*to_meta(deps), **kwargs)
return meta
except Exception as err:
# if compute-unknown-meta is False then we don't care about
Expand Down
17 changes: 17 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,10 @@ def test_make_unknown_length():
len(ul_tt1)


def my_power_flat(arg_x, arg_y):
return arg_x**arg_y


def my_power(arg_x, *, kwarg_y=None):
return arg_x**kwarg_y

Expand All @@ -806,6 +810,12 @@ def test_map_partitions_args_and_kwargs_have_collection():
zc = my_power(xc, kwarg_y=yc)
zl = dak.map_partitions(my_power, xl, kwarg_y=yl)

# kwargs that contain collections should be wrapped
if hasattr(zl.dask.layers[zl.name], "task"):
assert isinstance(
zl.dask.layers[zl.name].task.func, dak.lib.core.ArgsKwargsPackedFunction
)

assert_eq(zc, zl)

zd = structured_function(inputs={"x": xc, "y": xc, "z": yc})
Expand All @@ -829,8 +839,14 @@ def test_map_partitions_args_and_kwargs_have_collection():

zg = my_power(xc, kwarg_y=2.0)
zp = dak.map_partitions(my_power, xl, kwarg_y=2.0)
zp_f = dak.map_partitions(my_power_flat, xl, 2.0)

# this invocation of my_power shouldn't be wrapped, no collections
if hasattr(zp_f.dask.layers[zp_f.name], "task"):
assert zp_f.dask.layers[zp_f.name].task.func is my_power_flat

assert_eq(zg, zp)
assert_eq(zg, zp_f)

a = ak.Array(
[
Expand Down Expand Up @@ -860,6 +876,7 @@ def test_map_partitions_args_and_kwargs_have_collection():
ccc=cc,
ddd=dd,
)

assert_eq(res1, res2)


Expand Down
Loading