Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Nov 26, 2024
1 parent 836b24b commit 3634179
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 124 deletions.
301 changes: 183 additions & 118 deletions src/dask_awkward/lib/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
empty_typetracer,
new_array_object,
partitionwise_layer,
to_meta,
typetracer_array,
)


Expand Down Expand Up @@ -52,7 +54,7 @@ def _single_return_map_partitions(
):
msg = (
f"{meta=} is not (yet) supported as return type. If possible, "
"you can convert it to ak.Array first, or wrap it with a python container."
"you can convert it to ak.Array, or wrap it with a python container."
)
raise NotImplementedError(msg)
# don't know? -> put it in a bag
Expand Down Expand Up @@ -114,6 +116,31 @@ def _multi_return_map_partitions(
return tuple(ret)


def _compare_return_vals(left: tp.Any, right: tp.Any) -> None:
def cmp(left, right):
msg = (
"The provided 'meta' does not match "
"the output type inferred from the pre-run step; "
"got {right}, but expected {left}.".format(left=left, right=right)
)
if isinstance(left, ak.Array):
if left.layout.form != right.layout.form:
raise ValueError(msg)

else:
if left != right:
raise ValueError(msg)

if isinstance(left, tuple) and isinstance(right, tuple):
for left_, right_ in zip(left, right):
cmp(left_, right_)
else:
cmp(left, right)


class UntraceableFunctionError(Exception): ...


@dataclass
class mapfilter:
"""Map a callable across all partitions of any number of collections.
Expand Down Expand Up @@ -155,77 +182,13 @@ class mapfilter:
touch additional objects **explicitly** to get the correct typetracer report.
For this, provide a dictionary that maps input argument that's an array to
the columns/slice of that array that should be touched.
out_like: tp.Any, optional
If ``None`` (the default), the output will be computed through the default
typetracing pass. If a ak.Array is provided, the output will be mocked for the typetracing
pass as the provided array. This is useful for cases where the output can not be
computed through the default typetracing pass.
Returns
-------
dask_awkward.Array
The new collection.
Examples
--------
>>> from coffea.nanoevents import NanoEventsFactory
>>> from coffea.processor.decorator import mapfilter
>>> events, report = NanoEventsFactory.from_root(
{"https://github.com/CoffeaTeam/coffea/raw/master/tests/samples/nano_dy.root": "Events"},
metadata={"dataset": "Test"},
uproot_options={"allow_read_errors_with_report": True},
steps_per_file=2,
).events()
>>> @mapfilter
def process(events):
# do an emberassing parallel computation
# only eager awkward is allowed here
import awkward as ak
jets = events.Jet
jets = jets[(jets.pt > 30) & (abs(jets.eta) < 2.4)]
return events[ak.num(jets) == 2]
>>> selected = process(events)
>>> print(process(events).dask) # collapsed into a single node (2.)
HighLevelGraph with 3 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x11700d640>
0. from-uproot-0e54dc3659a3c020608e28b03f22b0f4
1. from-uproot-971b7f00ce02a189422528a5044b08fb
2. <dask-awkward.lib.core.ArgsKwargsPackedFunction ob-c9ee010d2e5671a2805f6d5106040d55
>>> print(process.base_fn(events).dask) # call the function as it is (many nodes in the graph); `base_fn` is the function that is wrapped
HighLevelGraph with 13 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x136e3d910>
0. from-uproot-0e54dc3659a3c020608e28b03f22b0f4
1. from-uproot-971b7f00ce02a189422528a5044b08fb
2. Jet-efead9353042e606d7ffd59845f4675d
3. eta-f31547c2a94efc053977790a489779be
4. absolute-74ced100c5db654eb0edd810542f724a
5. less-b33e652814e0cd5157b3c0885087edcb
6. pt-f50c15fa409e60152de61957d2a4a0d8
7. greater-da496609d36631ac857bb15eba6f0ba6
8. bitwise-and-a501c0ff0f5bcab618514603d4f78eec
9. getitem-fc20cad0c32130756d447fc749654d11
10. <dask-awkward.lib.core.ArgsKwargsPackedFunction ob-0d3090f1c746eafd595782bcacd30d69
11. equal-a4642445fb4e5da0b852c2813966568a
12. getitem-f951afb4c4d4b527553f5520f6765e43
# if you want to touch additional objects explicitly, because they are not touched by the standard typetracer (i.e. due to 'opaque' operations)
# you can provide a dict of slices that should be touched directly to the decorator, e.g.:
>>> from functools import partial
>>> @partial(mapfilter, needs={"events": [("Electron", "pt"), ("Electron", "eta")]})
def process(events):
# do an emberassing parallel computation
# only eager awkward is allowed here
import awkward as ak
jets = events.Jet
jets = jets[(jets.pt > 30) & (abs(jets.eta) < 2.4)]
return events[ak.num(jets) == 2]
>>> selected = process(events)
>>> print(dak.necessary_columns(selected))
{'from-uproot-0e54dc3659a3c020608e28b03f22b0f4': frozenset({'Electron_eta', 'Jet_eta', 'nElectron', 'Jet_pt', 'Electron_pt', 'nJet'})}
pre_run: bool
Endable/disable the pre-run of the function to get the metadata. This is useful
for cases where heavy computations (e.g. machine learning algorithms) are performed
in the decorated function. If ``True``, the function will be run first with a typetracer
to automatically infer the needed columns and return values. If ``False``, the function
will be run as is on the dask-worker; it is required to provide ``needs`` and ``meta``
in this case. Default is ``True``.
"""

base_fn: tp.Callable
Expand All @@ -235,7 +198,7 @@ def process(events):
traverse: bool = True
# additional options that are not available in dak.map_partitions
needs: tp.Mapping | None = None
out_like: tp.Any = None
pre_run: bool = True

def __post_init__(self) -> None:
if self.needs is not None and not isinstance(self.needs, tp.Mapping):
Expand All @@ -252,12 +215,120 @@ def __post_init__(self) -> None:
)
raise ValueError(msg)

def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
def in_args(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Mapping:
import inspect

ba = inspect.signature(self.base_fn).bind(*args, **kwargs)
in_arguments = ba.arguments
return ba.arguments

def _pre_run(
self,
*args: tp.Any,
**kwargs: tp.Any,
) -> tuple[tp.Any, tp.Mapping]:
in_arguments = self.in_args(*args, **kwargs)

# replace ak.Arrays with typetracers and store the reports
reports = {}
fun_kwargs = {}
args_metas = {arg: to_meta([val])[0] for arg, val in in_arguments.items()}

# can't typetrace if no ak.Arrays are present
ak_arrays = tuple(
filter(lambda x: isinstance(x, ak.Array), args_metas.values())
)
if not ak_arrays:
return None, {}

def render_buffer_key(
form: ak.forms.Form,
form_key: str,
attribute: str,
) -> str:
return form_key

for arg, val in args_metas.items():
if isinstance(val, ak.Array):
if not ak.backend(val) == "typetracer":
val = typetracer_array(val)
# format key?
tracer, report = ak.typetracer.typetracer_with_report(
val.layout.form_with_key_path(root=()),
highlevel=True,
behavior=val.behavior,
attrs=val.attrs,
buffer_key=render_buffer_key,
)
reports[arg] = report
fun_kwargs[arg] = tracer
else:
fun_kwargs[arg] = val
# try to run the function once with type tracers
try:
out = self.base_fn(**fun_kwargs)
except Exception as err:
import traceback

# get line number of where the error occurred
tb = traceback.extract_tb(err.__traceback__)
line_number = tb[-1].lineno

# add also the reports of the typetracer to the error message,
# and format them as 'needs' wants it to be
needs = self.reports2needs(reports=reports)

msg = (
f"This wrapped function '{self.base_fn}' is not traceable. "
f"An error occurred at line {line_number}.\n"
"'mapfilter' can circumvent this by providing the 'needs' and "
"'meta' arguments to the decorator.\n"
"\n- 'needs': mapping where the keys point to input argument "
"dask_awkward arrays and the values to columns/slices that "
"should be touched explicitly. The typetracing step could "
"determine the following necessary columns/slices.\n\n"
f"Typetracer reported the following 'needs':\n"
f"{dict(needs)}\n"
"\n- 'meta': value(s) of what the wrapped function would "
"return. For arrays, only the shape and type matter."
)
raise UntraceableFunctionError(msg) from err
return out, reports

@staticmethod
def reports2needs(reports: tp.Mapping) -> dict:
import ast
from collections import defaultdict

needs = defaultdict(list)
for arg, report in reports.items():
# this should maybe be differently treated?
keys = set(report.shape_touched) | set(report.data_touched)
for key in keys:
slce = ast.literal_eval(key)
# only strings are actual slice paths to columns,
# `None` or `ints` are path values to non-record array types,
# see: https://github.com/scikit-hep/awkward/pull/3311
slce = tuple(it for it in slce if isinstance(it, str))
needs[arg].append(slce)
return needs

@staticmethod
def replace_arrays_with_typetracers(meta: tp.Any) -> tp.Any:
def _to_tracer(meta: tp.Any) -> tp.Any:
if isinstance(meta, ak.Array):
if not ak.backend(meta) == "typetracer":
meta = typetracer_array(meta)
return meta

if isinstance(meta, tuple):
meta = tuple(map(_to_tracer, meta))
else:
meta = _to_tracer(meta)
return meta

def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
if self.needs is not None:
in_arguments = self.in_args(*args, **kwargs)
tobe_touched = set()
for arg in self.needs.keys():
if arg in in_arguments:
Expand All @@ -275,71 +346,65 @@ def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
# touch the objects explicitly
for slce in self.needs[arg]:
ak.typetracer.touch_data(array[slce])
if self.out_like is not None:
# check if we're in the typetracing step
if any(
ak.backend(array) == "typetracer" for array in in_arguments.values()
):
# mock the output as the specified type
if isinstance(self.out_like, (tuple, list)):
output = []
for out in self.out_like:
if isinstance(out, ak.Array):
if not ak.backend(out) == "typetracer":
out = ak.Array(
out.layout.to_typetracer(forget_length=True)
)
output.append(out)
else:
output.append(out)
return tuple(output)
else:
if isinstance(self.out_like, ak.Array):
if not ak.backend(self.out_like) == "typetracer":
return ak.Array(
self.out_like.layout.to_typetracer(forget_length=True)
)
return self.out_like
else:
raise ValueError(
"out_like must be an awkward array in the single return value case."
)
return self.base_fn(*args, **kwargs)

def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
if self.pre_run:
# we can actually return `needs` and `meta` here
# and circumvent the second tracing step
meta, reports = self._pre_run(*args, **kwargs)

# compare meta(s)
if self.meta:
_compare_return_vals(meta, self.meta)

# if self.needs is given, extend the the reports
# with all additionally given columns/slices
needs = self.reports2needs(reports=reports)
if self.needs is not None:
for arg, slces in self.needs.items():
if r_slces := set(needs[arg]):
needs[arg] = list(r_slces | set(slces))
else:
needs[arg].extend(slces)

fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args(
self.wrapped_fn, *args, traverse=self.traverse, **kwargs
)

hlg, meta, deps, name = _map_partitions_prepare(
fn,
*arg_flat_deps_expanded,
*kwarg_flat_deps,
label=self.label,
token=self.token,
meta=self.meta,
output_divisions=None,
)
try:
hlg, meta, deps, name = _map_partitions_prepare(
fn,
*arg_flat_deps_expanded,
*kwarg_flat_deps,
label=self.label,
token=self.token,
meta=self.replace_arrays_with_typetracers(self.meta),
output_divisions=None,
)
except Exception as err:
if not self.pre_run:
# put message here that it might help to do a pre-run
pass
raise err from None

# check consistent partitioning
# needs to be implemented
# how to get the (correct) partitioning from the deps (any dask collection)?
if len(deps) == 0:
raise ValueError("Need at least one input that is a dask collection.")
elif len(deps) == 1:
npart = deps[0].npartitions
else:
npart = deps[0].npartitions
if not all(dep.npartitions == npart for dep in deps):
msg = "All inputs must have the same partitioning, got:"
msg = "All inputs must have the same number of partitions, got:"
for dep in deps:
npartitions = dep.npartitions
msg += f"\n{dep} = {npartitions=}"
msg += f"\n{dep}: {npartitions=}"
raise ValueError(msg)

return _multi_return_map_partitions(
hlg=hlg,
name=name,
meta=meta,
meta=self.replace_arrays_with_typetracers(meta),
npartitions=npart,
)
Loading

0 comments on commit 3634179

Please sign in to comment.