diff --git a/src/uproot/_dask.py b/src/uproot/_dask.py index c1c0100be..ea4390a32 100644 --- a/src/uproot/_dask.py +++ b/src/uproot/_dask.py @@ -23,7 +23,6 @@ from uproot.behaviors.TBranch import HasBranches, TBranch, _regularize_step_size if TYPE_CHECKING: - from awkward._nplikes.typetracer import TypeTracerReport from awkward.forms import Form from awkward.highlevel import Array as AwkArray @@ -882,14 +881,16 @@ def impl(form, column_path): buffer_key: Final[str] = "{form_key}-{attribute}" def parse_buffer_key(self, buffer_key: str) -> tuple[str, str]: - form_key, attribute = buffer_key.rsplit("-", maxsplit=1) + form_key, *attribute = buffer_key.rsplit("-", maxsplit=1) return form_key, attribute def keys_for_buffer_keys(self, buffer_keys: frozenset[str]) -> frozenset[str]: keys: set[str] = set() for buffer_key in buffer_keys: # Identify form key - form_key, attribute = buffer_key.rsplit("-", maxsplit=1) + form_key, attribute = buffer_key.replace("@.", ".").rsplit( + "-", maxsplit=1 + ) # Identify key from form_key keys.add(self._form_key_to_key[form_key]) return frozenset(keys) @@ -959,6 +960,10 @@ class UprootReadMixin: interp_options: dict[str, Any] allow_read_errors_with_report: bool | tuple[type[BaseException], ...] + @property + def behavior(self): + return self.form_mapping_info.behavior + @property def allowed_exceptions(self): if isinstance(self.allow_read_errors_with_report, tuple): @@ -1026,84 +1031,19 @@ def read_tree( assert tree.source # we must be reading something here return out, tree.source.performance_counters - def mock(self) -> AwkArray: - awkward = uproot.extras.awkward() - return awkward.typetracer.typetracer_from_form( - self.expected_form, - highlevel=True, - behavior=self.form_mapping_info.behavior, - ) - - def mock_empty(self, backend) -> AwkArray: - awkward = uproot.extras.awkward() - return awkward.to_backend( - self.expected_form.length_zero_array(highlevel=False), - backend=backend, - highlevel=True, - behavior=self.form_mapping_info.behavior, - ) - - def prepare_for_projection(self) -> tuple[AwkArray, TypeTracerReport, dict]: - awkward = uproot.extras.awkward() - dask_awkward = uproot.extras.dask_awkward() - - # A form mapping will (may) remap the base form into a new form - # The remapped form can be queried for structural information - - # Build typetracer and associated report object - meta, report = awkward.typetracer.typetracer_with_report( - self.expected_form, - highlevel=True, - behavior=self.form_mapping_info.behavior, - buffer_key=self.form_mapping_info.buffer_key, - ) + @property + def form(self): + return self.expected_form - return ( - meta, - report, - { - "trace": dask_awkward.lib.utils.trace_form_structure( - self.expected_form, - buffer_key=self.form_mapping_info.buffer_key, - ), - "form_info": self.form_mapping_info, - }, - ) + def project(self, columns) -> T: + from dask_awkward.lib.utils import _buf_to_col - def project(self: T, *, report: TypeTracerReport, state: dict) -> T: - keys = self.necessary_columns(report=report, state=state) + keys = [_buf_to_col(c).replace(".", "_") for c in columns] + if not isinstance(self.form_mapping_info, TrivialFormMappingInfo): + roots = {_.split("_", 1)[0] for _ in keys if "_" in _} + keys.extend([f"n{_}" for _ in roots]) return self.project_keys(keys) - def necessary_columns( - self, *, report: TypeTracerReport, state: dict - ) -> frozenset[str]: - ## Read from stash - # Form hierarchy information - form_key_to_parent_form_key: dict = state["trace"][ - "form_key_to_parent_form_key" - ] - # Buffer hierarchy information - form_key_to_buffer_keys: dict = state["trace"]["form_key_to_buffer_keys"] - # Restructured form information - form_info = state["form_info"] - - # Require the data of metadata buffers above shape-only requests - dask_awkward = uproot.extras.dask_awkward() - data_buffers = { - *report.data_touched, - *dask_awkward.lib.utils.buffer_keys_required_to_compute_shapes( - form_info.parse_buffer_key, - report.shape_touched, - form_key_to_parent_form_key, - form_key_to_buffer_keys, - ), - } - - # Determine which TTree keys need to be read - return form_info.keys_for_buffer_keys(data_buffers) & frozenset( - self.common_keys - ) - def project_keys(self: T, keys: frozenset[str]) -> T: raise NotImplementedError