Skip to content

Commit

Permalink
properly refactor mapfilter and prerun functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Nov 27, 2024
1 parent 3634179 commit 2f3ef20
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 155 deletions.
2 changes: 1 addition & 1 deletion src/dask_awkward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
map_partitions,
partition_compatibility,
)
from dask_awkward.lib.decorator import mapfilter
from dask_awkward.lib.describe import backend, fields
from dask_awkward.lib.inspect import (
report_necessary_buffers,
report_necessary_columns,
sample,
)
from dask_awkward.lib.mapfilter import mapfilter

necessary_columns = report_necessary_columns # Export for backwards compatibility.

Expand Down
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
map_partitions,
partition_compatibility,
)
from dask_awkward.lib.decorator import mapfilter
from dask_awkward.lib.describe import backend, fields
from dask_awkward.lib.inspect import (
report_necessary_buffers,
Expand All @@ -28,6 +27,7 @@
from dask_awkward.lib.io.json import from_json, to_json
from dask_awkward.lib.io.parquet import from_parquet, to_parquet
from dask_awkward.lib.io.text import from_text
from dask_awkward.lib.mapfilter import mapfilter
from dask_awkward.lib.operations import concatenate
from dask_awkward.lib.reducers import (
all,
Expand Down
293 changes: 148 additions & 145 deletions src/dask_awkward/lib/decorator.py → src/dask_awkward/lib/mapfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dask.highlevelgraph import HighLevelGraph
from dask.typing import DaskCollection

from dask_awkward.lib.core import Array as DakArray
from dask_awkward.lib.core import (
_map_partitions_prepare,
_to_packed_fn_args,
Expand Down Expand Up @@ -141,6 +142,119 @@ def cmp(left, right):
class UntraceableFunctionError(Exception): ...


def _func_args(fun: tp.Callable, *args: tp.Any, **kwargs: tp.Any) -> tp.Mapping:
import inspect

ba = inspect.signature(fun).bind(*args, **kwargs)
return ba.arguments


def reports2needs(reports: tp.Mapping) -> dict:
    """Convert typetracer reports into the ``needs`` format of ``mapfilter``.

    Every touched buffer key of a report is a string holding the literal of a
    key path. It is parsed and reduced to its string items, which are the
    actual column names of the slice.

    Parameters
    ----------
    reports : Mapping
        Argument name -> report object exposing ``shape_touched`` and
        ``data_touched`` (iterables of key-path literals).

    Returns
    -------
    dict
        A ``defaultdict(list)`` mapping each argument name to the list of
        column slices (tuples of strings) that were touched. Slices are unique
        and listed in the order they were first reported (shape before data).
        Arguments without any touched key get no entry.
    """
    import ast
    from collections import defaultdict
    from itertools import chain

    needs = defaultdict(list)
    for arg, report in reports.items():
        # this should maybe be differently treated?
        touched_keys = chain(report.shape_touched, report.data_touched)
        slices = (
            # only strings are actual slice paths to columns,
            # `None` or `ints` are path values to non-record array types,
            # see: https://github.com/scikit-hep/awkward/pull/3311
            tuple(it for it in ast.literal_eval(key) if isinstance(it, str))
            for key in touched_keys
        )
        # Different keys can collapse to the same slice once the non-string
        # items are dropped; `dict.fromkeys` de-duplicates while keeping a
        # reproducible first-seen order (iterating a `set` would not).
        unique_slices = list(dict.fromkeys(slices))
        if unique_slices:
            needs[arg] = unique_slices
    return needs


def _replace_arrays_with_typetracers(meta: tp.Any) -> tp.Any:
    """Swap arrays in ``meta`` for their typetracer counterparts.

    Parameters
    ----------
    meta : Any
        A single object or a tuple of objects. Only the top level of a tuple
        is inspected; other containers are handled as a single object.

    Returns
    -------
    Any
        ``meta`` with every ``ak.Array`` that is not on the typetracer backend
        passed through ``typetracer_array`` and every dask-awkward array
        replaced by its meta. Anything else is passed through unchanged. A
        tuple input yields a tuple output.
    """

    def _to_tracer(obj: tp.Any) -> tp.Any:
        if isinstance(obj, ak.Array):
            if ak.backend(obj) != "typetracer":
                obj = typetracer_array(obj)
        elif isinstance(obj, DakArray):
            # `to_meta` converts a sequence and hands back a sequence, so the
            # single converted entry has to be unwrapped again.
            obj = to_meta([obj])[0]
        return obj

    if isinstance(meta, tuple):
        return tuple(map(_to_tracer, meta))
    return _to_tracer(meta)


def prerun(
    fun: tp.Callable, *args: tp.Any, **kwargs: tp.Any
) -> tuple[tp.Any, tp.Mapping]:
    """Run ``fun`` once on typetracers and record which buffers it touches.

    Every argument that is (or whose meta is) an ``ak.Array`` is replaced by a
    reporting typetracer before ``fun`` is called; all other arguments are
    passed through as they are.

    Parameters
    ----------
    fun : Callable
        The function to trace. It is invoked with keyword arguments only.
    *args, **kwargs
        The arguments ``fun`` would be called with.

    Returns
    -------
    tuple
        ``(out, reports)`` where ``out`` is whatever ``fun`` returned for the
        typetracer inputs and ``reports`` maps argument names to their
        typetracer report. If no argument is an ``ak.Array``, nothing is run
        and ``(None, {})`` is returned.

    Raises
    ------
    UntraceableFunctionError
        If ``fun`` raises while being run on the typetracers. The original
        exception is chained as the cause.
    """
    in_arguments = _func_args(fun, *args, **kwargs)

    # replace ak.Arrays with typetracers and store the reports
    reports = {}
    fun_kwargs = {}
    args_metas = {arg: to_meta([val])[0] for arg, val in in_arguments.items()}

    # can't typetrace if no ak.Arrays are present
    if not any(isinstance(val, ak.Array) for val in args_metas.values()):
        return None, {}

    def render_buffer_key(
        form: ak.forms.Form,
        form_key: str,
        attribute: str,
    ) -> str:
        # report buffers by their form key only, i.e. by their key path
        return form_key

    # prepare function arguments
    for arg, val in args_metas.items():
        if isinstance(val, ak.Array):
            if not ak.backend(val) == "typetracer":
                val = typetracer_array(val)
            tracer, report = ak.typetracer.typetracer_with_report(
                val.layout.form_with_key_path(root=()),
                highlevel=True,
                behavior=val.behavior,
                attrs=val.attrs,
                buffer_key=render_buffer_key,
            )
            reports[arg] = report
            fun_kwargs[arg] = tracer
        else:
            fun_kwargs[arg] = val

    # try to run the function once with type tracers
    try:
        out = fun(**fun_kwargs)
    except Exception as err:
        import traceback

        # get line number of where the error occurred in the provided function
        # traceback 0: this function, 1: the provided function, >1: the rest of the stack
        tb = traceback.extract_tb(err.__traceback__)
        # The call itself may fail before any frame of `fun` exists (e.g. a
        # signature mismatch or a function without a Python frame); then only
        # this function's frame is available and indexing `tb[1]` would raise
        # an IndexError that hides the error message built below.
        frame = tb[1] if len(tb) > 1 else tb[-1]
        line_number = frame.lineno

        # add also the reports of the typetracer to the error message,
        # and format them as 'needs' wants it to be
        needs = dict(reports2needs(reports=reports))

        msg = (
            f"This wrapped function '{fun}' is not traceable. "
            f"An error occurred at line {line_number}.\n"
            "'mapfilter' can circumvent this by providing the 'needs' and "
            "'meta' arguments to the decorator.\n"
            "\n- 'needs': mapping where the keys point to input argument "
            "dask_awkward arrays and the values to columns/slices that "
            "should be touched explicitly. The typetracing step could "
            "determine the following necessary columns/slices.\n\n"
            f"Typetracer reported the following 'needs':\n"
            f"{needs}\n"
            "\n- 'meta': value(s) of what the wrapped function would "
            "return. For arrays, only the shape and type matter."
        )
        raise UntraceableFunctionError(msg) from err
    return out, reports


@dataclass
class mapfilter:
"""Map a callable across all partitions of any number of collections.
Expand Down Expand Up @@ -182,13 +296,6 @@ class mapfilter:
touch additional objects **explicitly** to get the correct typetracer report.
For this, provide a dictionary that maps input argument that's an array to
the columns/slice of that array that should be touched.
pre_run: bool
Enable/disable the pre-run of the function to get the metadata. This is useful
for cases where heavy computations (e.g. machine learning algorithms) are performed
in the decorated function. If ``True``, the function will be run first with a typetracer
to automatically infer the needed columns and return values. If ``False``, the function
will be run as is on the dask-worker; it is required to provide ``needs`` and ``meta``
in this case. Default is ``True``.
"""

base_fn: tp.Callable
Expand All @@ -198,7 +305,6 @@ class mapfilter:
traverse: bool = True
# additional options that are not available in dak.map_partitions
needs: tp.Mapping | None = None
pre_run: bool = True

def __post_init__(self) -> None:
if self.needs is not None and not isinstance(self.needs, tp.Mapping):
Expand All @@ -215,120 +321,9 @@ def __post_init__(self) -> None:
)
raise ValueError(msg)

def in_args(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Mapping:
import inspect

ba = inspect.signature(self.base_fn).bind(*args, **kwargs)
return ba.arguments

def _pre_run(
self,
*args: tp.Any,
**kwargs: tp.Any,
) -> tuple[tp.Any, tp.Mapping]:
in_arguments = self.in_args(*args, **kwargs)

# replace ak.Arrays with typetracers and store the reports
reports = {}
fun_kwargs = {}
args_metas = {arg: to_meta([val])[0] for arg, val in in_arguments.items()}

# can't typetrace if no ak.Arrays are present
ak_arrays = tuple(
filter(lambda x: isinstance(x, ak.Array), args_metas.values())
)
if not ak_arrays:
return None, {}

def render_buffer_key(
form: ak.forms.Form,
form_key: str,
attribute: str,
) -> str:
return form_key

for arg, val in args_metas.items():
if isinstance(val, ak.Array):
if not ak.backend(val) == "typetracer":
val = typetracer_array(val)
# format key?
tracer, report = ak.typetracer.typetracer_with_report(
val.layout.form_with_key_path(root=()),
highlevel=True,
behavior=val.behavior,
attrs=val.attrs,
buffer_key=render_buffer_key,
)
reports[arg] = report
fun_kwargs[arg] = tracer
else:
fun_kwargs[arg] = val
# try to run the function once with type tracers
try:
out = self.base_fn(**fun_kwargs)
except Exception as err:
import traceback

# get line number of where the error occurred
tb = traceback.extract_tb(err.__traceback__)
line_number = tb[-1].lineno

# add also the reports of the typetracer to the error message,
# and format them as 'needs' wants it to be
needs = self.reports2needs(reports=reports)

msg = (
f"This wrapped function '{self.base_fn}' is not traceable. "
f"An error occurred at line {line_number}.\n"
"'mapfilter' can circumvent this by providing the 'needs' and "
"'meta' arguments to the decorator.\n"
"\n- 'needs': mapping where the keys point to input argument "
"dask_awkward arrays and the values to columns/slices that "
"should be touched explicitly. The typetracing step could "
"determine the following necessary columns/slices.\n\n"
f"Typetracer reported the following 'needs':\n"
f"{dict(needs)}\n"
"\n- 'meta': value(s) of what the wrapped function would "
"return. For arrays, only the shape and type matter."
)
raise UntraceableFunctionError(msg) from err
return out, reports

@staticmethod
def reports2needs(reports: tp.Mapping) -> dict:
import ast
from collections import defaultdict

needs = defaultdict(list)
for arg, report in reports.items():
# this should maybe be differently treated?
keys = set(report.shape_touched) | set(report.data_touched)
for key in keys:
slce = ast.literal_eval(key)
# only strings are actual slice paths to columns,
# `None` or `ints` are path values to non-record array types,
# see: https://github.com/scikit-hep/awkward/pull/3311
slce = tuple(it for it in slce if isinstance(it, str))
needs[arg].append(slce)
return needs

@staticmethod
def replace_arrays_with_typetracers(meta: tp.Any) -> tp.Any:
def _to_tracer(meta: tp.Any) -> tp.Any:
if isinstance(meta, ak.Array):
if not ak.backend(meta) == "typetracer":
meta = typetracer_array(meta)
return meta

if isinstance(meta, tuple):
meta = tuple(map(_to_tracer, meta))
else:
meta = _to_tracer(meta)
return meta

def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
in_arguments = _func_args(self.base_fn, *args, **kwargs)
if self.needs is not None:
in_arguments = self.in_args(*args, **kwargs)
tobe_touched = set()
for arg in self.needs.keys():
if arg in in_arguments:
Expand All @@ -346,47 +341,55 @@ def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
# touch the objects explicitly
for slce in self.needs[arg]:
ak.typetracer.touch_data(array[slce])

if self.meta is not None:
ak_arrays = [
arg for arg in in_arguments.values() if isinstance(arg, ak.Array)
]
if all(ak.backend(arr) == "typetracer" for arr in ak_arrays):
# if the meta is known, we can use it to skip the tracing step
return self.meta
return self.base_fn(*args, **kwargs)

def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
if self.pre_run:
# we can actually return `needs` and `meta` here
# and circumvent the second tracing step
meta, reports = self._pre_run(*args, **kwargs)

# compare meta(s)
if self.meta:
_compare_return_vals(meta, self.meta)

# if self.needs is given, extend the reports
# with all additionally given columns/slices
needs = self.reports2needs(reports=reports)
if self.needs is not None:
for arg, slces in self.needs.items():
if r_slces := set(needs[arg]):
needs[arg] = list(r_slces | set(slces))
else:
needs[arg].extend(slces)

fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args(
self.wrapped_fn, *args, traverse=self.traverse, **kwargs
)

arg_flat_deps_expanded = _replace_arrays_with_typetracers(
arg_flat_deps_expanded
)
kwarg_flat_deps = _replace_arrays_with_typetracers(kwarg_flat_deps)
meta = _replace_arrays_with_typetracers(self.meta)
in_typetracing_mode = arg_flat_deps_expanded or kwarg_flat_deps or meta

try:
hlg, meta, deps, name = _map_partitions_prepare(
fn,
*arg_flat_deps_expanded,
*kwarg_flat_deps,
label=self.label,
token=self.token,
meta=self.replace_arrays_with_typetracers(self.meta),
meta=meta,
output_divisions=None,
)
except Exception as err:
if not self.pre_run:
# put message here that it might help to do a pre-run
pass
raise err from None
# if there's a problem with typetracing, we can report it and recommend a 'prerun'
if in_typetracing_mode:
fn_args = _func_args(self.base_fn, *args, **kwargs)
sig_str = ", ".join(f"{k}={v}" for k, v in fn_args.items())
msg = (
f"Failed to trace the function '{self.base_fn}'. "
"You can use 'needs' and 'meta' to circumvent this step. "
"For this, it might be helpful to do a pre-run of the function:"
f"\n\n\tfrom dask_awkward.lib.mapfilter import prerun"
f"\n\n\tprerun({self.base_fn.__name__}, {sig_str})"
f"\n\nThis may help to infer the correct `needs` for `mapfilter`."
)
raise UntraceableFunctionError(msg) from err
# otherwise, just raise the error - whatever it is
else:
raise err from None

# check consistent partitioning
if len(deps) == 0:
Expand All @@ -405,6 +408,6 @@ def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
return _multi_return_map_partitions(
hlg=hlg,
name=name,
meta=self.replace_arrays_with_typetracers(meta),
meta=meta,
npartitions=npart,
)
Loading

0 comments on commit 2f3ef20

Please sign in to comment.