From 3634179f4569305cd7ff65ffeea6f2e29ec1560f Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 26 Nov 2024 13:11:06 -0500 Subject: [PATCH] wip --- src/dask_awkward/lib/decorator.py | 301 ++++++++++++++++++------------ tests/test_decorator.py | 26 ++- 2 files changed, 203 insertions(+), 124 deletions(-) diff --git a/src/dask_awkward/lib/decorator.py b/src/dask_awkward/lib/decorator.py index 4f239146..c4866910 100644 --- a/src/dask_awkward/lib/decorator.py +++ b/src/dask_awkward/lib/decorator.py @@ -14,6 +14,8 @@ empty_typetracer, new_array_object, partitionwise_layer, + to_meta, + typetracer_array, ) @@ -52,7 +54,7 @@ def _single_return_map_partitions( ): msg = ( f"{meta=} is not (yet) supported as return type. If possible, " - "you can convert it to ak.Array first, or wrap it with a python container." + "you can convert it to ak.Array, or wrap it with a python container." ) raise NotImplementedError(msg) # don't know? -> put it in a bag @@ -114,6 +116,31 @@ def _multi_return_map_partitions( return tuple(ret) +def _compare_return_vals(left: tp.Any, right: tp.Any) -> None: + def cmp(left, right): + msg = ( + "The provided 'meta' does not match " + "the output type inferred from the pre-run step; " + "got {right}, but expected {left}.".format(left=left, right=right) + ) + if isinstance(left, ak.Array): + if left.layout.form != right.layout.form: + raise ValueError(msg) + + else: + if left != right: + raise ValueError(msg) + + if isinstance(left, tuple) and isinstance(right, tuple): + for left_, right_ in zip(left, right): + cmp(left_, right_) + else: + cmp(left, right) + + +class UntraceableFunctionError(Exception): ... + + @dataclass class mapfilter: """Map a callable across all partitions of any number of collections. @@ -155,77 +182,13 @@ class mapfilter: touch additional objects **explicitly** to get the correct typetracer report. For this, provide a dictionary that maps input argument that's an array to the columns/slice of that array that should be touched. - out_like: tp.Any, optional - If ``None`` (the default), the output will be computed through the default - typetracing pass. If a ak.Array is provided, the output will be mocked for the typetracing - pass as the provided array. This is useful for cases where the output can not be - computed through the default typetracing pass. - - - Returns - ------- - dask_awkward.Array - The new collection. - - Examples - -------- - >>> from coffea.nanoevents import NanoEventsFactory - >>> from coffea.processor.decorator import mapfilter - >>> events, report = NanoEventsFactory.from_root( - {"https://github.com/CoffeaTeam/coffea/raw/master/tests/samples/nano_dy.root": "Events"}, - metadata={"dataset": "Test"}, - uproot_options={"allow_read_errors_with_report": True}, - steps_per_file=2, - ).events() - >>> @mapfilter - def process(events): - # do an emberassing parallel computation - # only eager awkward is allowed here - import awkward as ak - - jets = events.Jet - jets = jets[(jets.pt > 30) & (abs(jets.eta) < 2.4)] - return events[ak.num(jets) == 2] - >>> selected = process(events) - >>> print(process(events).dask) # collapsed into a single node (2.) - HighLevelGraph with 3 layers. - - 0. from-uproot-0e54dc3659a3c020608e28b03f22b0f4 - 1. from-uproot-971b7f00ce02a189422528a5044b08fb - 2. >> print(process.base_fn(events).dask) # call the function as it is (many nodes in the graph); `base_fn` is the function that is wrapped - HighLevelGraph with 13 layers. - - 0. from-uproot-0e54dc3659a3c020608e28b03f22b0f4 - 1. from-uproot-971b7f00ce02a189422528a5044b08fb - 2. Jet-efead9353042e606d7ffd59845f4675d - 3. eta-f31547c2a94efc053977790a489779be - 4. absolute-74ced100c5db654eb0edd810542f724a - 5. less-b33e652814e0cd5157b3c0885087edcb - 6. pt-f50c15fa409e60152de61957d2a4a0d8 - 7. greater-da496609d36631ac857bb15eba6f0ba6 - 8. bitwise-and-a501c0ff0f5bcab618514603d4f78eec - 9. getitem-fc20cad0c32130756d447fc749654d11 - 10. >> from functools import partial - >>> @partial(mapfilter, needs={"events": [("Electron", "pt"), ("Electron", "eta")]}) - def process(events): - # do an emberassing parallel computation - # only eager awkward is allowed here - import awkward as ak - - jets = events.Jet - jets = jets[(jets.pt > 30) & (abs(jets.eta) < 2.4)] - return events[ak.num(jets) == 2] - >>> selected = process(events) - >>> print(dak.necessary_columns(selected)) - {'from-uproot-0e54dc3659a3c020608e28b03f22b0f4': frozenset({'Electron_eta', 'Jet_eta', 'nElectron', 'Jet_pt', 'Electron_pt', 'nJet'})} - + pre_run: bool + Endable/disable the pre-run of the function to get the metadata. This is useful + for cases where heavy computations (e.g. machine learning algorithms) are performed + in the decorated function. If ``True``, the function will be run first with a typetracer + to automatically infer the needed columns and return values. If ``False``, the function + will be run as is on the dask-worker; it is required to provide ``needs`` and ``meta`` + in this case. Default is ``True``. """ base_fn: tp.Callable @@ -235,7 +198,7 @@ def process(events): traverse: bool = True # additional options that are not available in dak.map_partitions needs: tp.Mapping | None = None - out_like: tp.Any = None + pre_run: bool = True def __post_init__(self) -> None: if self.needs is not None and not isinstance(self.needs, tp.Mapping): @@ -252,12 +215,120 @@ def __post_init__(self) -> None: ) raise ValueError(msg) - def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: + def in_args(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Mapping: import inspect ba = inspect.signature(self.base_fn).bind(*args, **kwargs) - in_arguments = ba.arguments + return ba.arguments + + def _pre_run( + self, + *args: tp.Any, + **kwargs: tp.Any, + ) -> tuple[tp.Any, tp.Mapping]: + in_arguments = self.in_args(*args, **kwargs) + + # replace ak.Arrays with typetracers and store the reports + reports = {} + fun_kwargs = {} + args_metas = {arg: to_meta([val])[0] for arg, val in in_arguments.items()} + + # can't typetrace if no ak.Arrays are present + ak_arrays = tuple( + filter(lambda x: isinstance(x, ak.Array), args_metas.values()) + ) + if not ak_arrays: + return None, {} + + def render_buffer_key( + form: ak.forms.Form, + form_key: str, + attribute: str, + ) -> str: + return form_key + + for arg, val in args_metas.items(): + if isinstance(val, ak.Array): + if not ak.backend(val) == "typetracer": + val = typetracer_array(val) + # format key? + tracer, report = ak.typetracer.typetracer_with_report( + val.layout.form_with_key_path(root=()), + highlevel=True, + behavior=val.behavior, + attrs=val.attrs, + buffer_key=render_buffer_key, + ) + reports[arg] = report + fun_kwargs[arg] = tracer + else: + fun_kwargs[arg] = val + # try to run the function once with type tracers + try: + out = self.base_fn(**fun_kwargs) + except Exception as err: + import traceback + + # get line number of where the error occurred + tb = traceback.extract_tb(err.__traceback__) + line_number = tb[-1].lineno + + # add also the reports of the typetracer to the error message, + # and format them as 'needs' wants it to be + needs = self.reports2needs(reports=reports) + + msg = ( + f"This wrapped function '{self.base_fn}' is not traceable. " + f"An error occurred at line {line_number}.\n" + "'mapfilter' can circumvent this by providing the 'needs' and " + "'meta' arguments to the decorator.\n" + "\n- 'needs': mapping where the keys point to input argument " + "dask_awkward arrays and the values to columns/slices that " + "should be touched explicitly. The typetracing step could " + "determine the following necessary columns/slices.\n\n" + f"Typetracer reported the following 'needs':\n" + f"{dict(needs)}\n" + "\n- 'meta': value(s) of what the wrapped function would " + "return. For arrays, only the shape and type matter." + ) + raise UntraceableFunctionError(msg) from err + return out, reports + + @staticmethod + def reports2needs(reports: tp.Mapping) -> dict: + import ast + from collections import defaultdict + + needs = defaultdict(list) + for arg, report in reports.items(): + # this should maybe be differently treated? + keys = set(report.shape_touched) | set(report.data_touched) + for key in keys: + slce = ast.literal_eval(key) + # only strings are actual slice paths to columns, + # `None` or `ints` are path values to non-record array types, + # see: https://github.com/scikit-hep/awkward/pull/3311 + slce = tuple(it for it in slce if isinstance(it, str)) + needs[arg].append(slce) + return needs + + @staticmethod + def replace_arrays_with_typetracers(meta: tp.Any) -> tp.Any: + def _to_tracer(meta: tp.Any) -> tp.Any: + if isinstance(meta, ak.Array): + if not ak.backend(meta) == "typetracer": + meta = typetracer_array(meta) + return meta + + if isinstance(meta, tuple): + meta = tuple(map(_to_tracer, meta)) + else: + meta = _to_tracer(meta) + return meta + + def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: if self.needs is not None: + in_arguments = self.in_args(*args, **kwargs) tobe_touched = set() for arg in self.needs.keys(): if arg in in_arguments: @@ -275,55 +346,49 @@ def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: # touch the objects explicitly for slce in self.needs[arg]: ak.typetracer.touch_data(array[slce]) - if self.out_like is not None: - # check if we're in the typetracing step - if any( - ak.backend(array) == "typetracer" for array in in_arguments.values() - ): - # mock the output as the specified type - if isinstance(self.out_like, (tuple, list)): - output = [] - for out in self.out_like: - if isinstance(out, ak.Array): - if not ak.backend(out) == "typetracer": - out = ak.Array( - out.layout.to_typetracer(forget_length=True) - ) - output.append(out) - else: - output.append(out) - return tuple(output) - else: - if isinstance(self.out_like, ak.Array): - if not ak.backend(self.out_like) == "typetracer": - return ak.Array( - self.out_like.layout.to_typetracer(forget_length=True) - ) - return self.out_like - else: - raise ValueError( - "out_like must be an awkward array in the single return value case." - ) return self.base_fn(*args, **kwargs) def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: + if self.pre_run: + # we can actually return `needs` and `meta` here + # and circumvent the second tracing step + meta, reports = self._pre_run(*args, **kwargs) + + # compare meta(s) + if self.meta: + _compare_return_vals(meta, self.meta) + + # if self.needs is given, extend the the reports + # with all additionally given columns/slices + needs = self.reports2needs(reports=reports) + if self.needs is not None: + for arg, slces in self.needs.items(): + if r_slces := set(needs[arg]): + needs[arg] = list(r_slces | set(slces)) + else: + needs[arg].extend(slces) + fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args( self.wrapped_fn, *args, traverse=self.traverse, **kwargs ) - hlg, meta, deps, name = _map_partitions_prepare( - fn, - *arg_flat_deps_expanded, - *kwarg_flat_deps, - label=self.label, - token=self.token, - meta=self.meta, - output_divisions=None, - ) + try: + hlg, meta, deps, name = _map_partitions_prepare( + fn, + *arg_flat_deps_expanded, + *kwarg_flat_deps, + label=self.label, + token=self.token, + meta=self.replace_arrays_with_typetracers(self.meta), + output_divisions=None, + ) + except Exception as err: + if not self.pre_run: + # put message here that it might help to do a pre-run + pass + raise err from None # check consistent partitioning - # needs to be implemented - # how to get the (correct) partitioning from the deps (any dask collection)? if len(deps) == 0: raise ValueError("Need at least one input that is a dask collection.") elif len(deps) == 1: @@ -331,15 +396,15 @@ def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: else: npart = deps[0].npartitions if not all(dep.npartitions == npart for dep in deps): - msg = "All inputs must have the same partitioning, got:" + msg = "All inputs must have the same number of partitions, got:" for dep in deps: npartitions = dep.npartitions - msg += f"\n{dep} = {npartitions=}" + msg += f"\n{dep}: {npartitions=}" raise ValueError(msg) return _multi_return_map_partitions( hlg=hlg, name=name, - meta=meta, + meta=self.replace_arrays_with_typetracers(meta), npartitions=npart, ) diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 0a30e2e9..9a205a5b 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -47,13 +47,21 @@ def fun(x): def test_mapfilter_needs_outlike(): - ak_array = ak.zip({"pt": [10, 20, 30, 40], "eta": [1, 1, 1, 1]}) + ak_array = ak.zip( + { + "x": [{"foo": [10, 20, 30, 40], "bar": [10, 20, 30, 40]}], + "y": [{"foo": [1, 1, 1, 1], "bar": [1, 1, 1, 1]}], + "z": [0, 0, 0, 0], + } + ) dak_array = dak.from_awkward(ak_array, 2) def untraceable_fun(muons): # a non-traceable computation for ak.typetracer # which needs "pt" column from muons and returns a 1-element array - pt = ak.to_numpy(muons.pt) + muons.y.bar[...] + muons.z[...] + pt = ak.to_numpy(muons.x.foo) return ak.Array([np.sum(pt)]) # first check that the function is not traceable @@ -61,7 +69,13 @@ def untraceable_fun(muons): dak.map_partitions(untraceable_fun, dak_array) # now check that the necessary columns are reported correctly - wrap = partial(dak.mapfilter, needs={"muons": ["pt"]}, out_like=ak.Array([0.0])) - out = wrap(untraceable_fun)(dak_array) - cols = next(iter(dak.report_necessary_columns(out).values())) - assert cols == {"pt"} + wrap = partial( + dak.mapfilter, + needs={"muons": [("x", "foo"), ("z",), ("y", "bar")]}, + meta=ak.Array([0.0]), + pre_run=False, + ) + out = wrap(untraceable_fun)(dak_array) # noqa + # TODO + # cols = next(iter(dak.report_necessary_columns(out).values())) + # assert cols == {"pt"}