diff --git a/heparchy/read/hdf.py b/heparchy/read/hdf.py
index 4739324..3223fbd 100644
--- a/heparchy/read/hdf.py
+++ b/heparchy/read/hdf.py
@@ -12,8 +12,8 @@
 from functools import cached_property
 from collections.abc import Mapping
 from typing import (
-    Any, List, Dict, Sequence, Type, Iterator, Union, Set, TypeVar,
-    Callable, Generic, Tuple)
+    Any, List, Dict, Sequence, Type, Iterator, Union, Set,
+    TypeVar, Callable, Generic, Tuple, Optional)
 
 import numpy as np
 import h5py
@@ -35,6 +35,7 @@ def _export(procedure: ExportType) -> ExportType:
 
 
 _NOT_NUMPY_ERR = ValueError("Stored data type is corrupted.")
+_NO_EVENT_ERR = AttributeError("Event reader not pointing to event.")
 _BUILTIN_PROPS: Set[str] = set()
 _BUILTIN_METADATA = {  # TODO: work out non-hardcoded
     "HdfEventReader": {"num_pcls", "mask_keys"},
@@ -68,39 +69,18 @@ def _stored_keys(attrs: AttributeManager, key_attr_name: str) -> Iterator[str]:
         yield name
 
 
-def _mask_iter(reader: HdfEventReader) -> Iterator[str]:
-    key_attr_name = "mask_keys"
-    grp = reader._grp
-    if key_attr_name not in grp.attrs:
-        dtype = np.dtype("
+def _grp_key_iter(reader: Group) -> Iterator[str]:
+    yield from reader.keys()
 
 
-def _custom_iter(reader: HdfEventReader) -> Iterator[str]:
-    key_attr_name = "custom_keys"
-    grp = reader._grp
-    if key_attr_name not in grp.attrs:
-        names = set(grp.keys()) - set(reader.masks.keys())
-        for name in (names - _BUILTIN_PROPS):
-            yield name
-    else:
-        yield from _stored_keys(grp.attrs, key_attr_name)
-
-
-def _meta_iter(reader: ReaderType) -> Iterator[str]:
+def _meta_iter(reader: Group) -> Iterator[str]:
     key_attr_name = "custom_meta_keys"
-    grp = reader._grp
-    if key_attr_name not in grp.attrs:
-        names = set(grp.attrs.keys())
+    if key_attr_name not in reader.attrs:
+        names = set(reader.attrs.keys())
         for name in (names - _BUILTIN_METADATA[reader.__class__.__name__]):
             yield name
     else:
-        yield from _stored_keys(grp.attrs, key_attr_name)
+        yield from _stored_keys(reader.attrs, key_attr_name)
 
 
 @_export
@@ -133,10 +113,16 @@ class MapReader(Generic[MapValue], Mapping[str, MapValue]):
     """
     def __init__(self,
                  reader: ReaderType,
+                 grp_attr_name: str,
                  iter_func: Callable[..., Iterator[str]]) -> None:
         self._reader = reader
+        self._grp_attr_name = grp_attr_name
         self._iter_func = iter_func
 
+    @property
+    def _grp(self) -> Group:
+        return getattr(self._reader, self._grp_attr_name)
+
     def __repr__(self) -> str:
         dset_repr = ""
         kv = ", ".join(map(lambda k: f"\'{k}\': {dset_repr}", self))
@@ -149,8 +135,8 @@ def __getitem__(self, name: str) -> MapValue:
         if name not in set(self):
             raise KeyError("No data stored with this name")
         if self._iter_func.__name__ == "_meta_iter":
-            return self._reader._grp.attrs[name]  # type: ignore
-        data = self._reader._grp[name]
+            return self._grp.attrs[name]  # type: ignore
+        data = self._grp[name]
         if not isinstance(data, Dataset):
             raise _NOT_NUMPY_ERR
         return data[...]  # type: ignore
@@ -162,7 +148,7 @@ def __delitem__(self, name: str) -> None:
         raise ReadOnlyError("Value is read-only")
 
     def __iter__(self) -> Iterator[str]:
-        yield from self._iter_func(self._reader)
+        yield from self._iter_func(self._grp)
 
 
 def _type_error_str(data: Any, dtype: type) -> str:
@@ -231,11 +217,37 @@ class HdfEventReader(EventReaderBase):
         Read-only dictionary-like interface to access user-defined
         metadata on the event.
     """
-    __slots__ = ('_name', '_grp')
+    def __init__(self) -> None:
+        self.__name: Optional[str] = None
+        self.__grp: Optional[Group] = None
+        self._custom_grp: Group
+        self._mask_grp: Group
+
+    @property
+    def _name(self) -> str:
+        if self.__name is None:
+            raise _NO_EVENT_ERR
+        return self.__name
 
-    def __init__(self, evt_data) -> None:
-        self._name: str = evt_data[0]
-        self._grp: Group = evt_data[1]
+    @_name.setter
+    def _name(self, val: str) -> None:
+        self.__name = val
+
+    @property
+    def _grp(self) -> Group:
+        if self.__grp is None:
+            raise _NO_EVENT_ERR
+        return self.__grp
+
+    @_grp.setter
+    def _grp(self, val: Group) -> None:
+        self.__grp = val
+        custom = val["custom"]
+        mask = val["masks"]
+        if (not isinstance(custom, Group)) or (not isinstance(mask, Group)):
+            raise ValueError("Group not found")
+        self._custom_grp = custom
+        self._mask_grp = mask
 
     @property
     def name(self) -> str:
@@ -322,7 +334,7 @@ def mask(self, name: str) -> BoolVector:
 
     @cached_property
     def masks(self) -> MapReader[BoolVector]:
-        return MapReader[BoolVector](self, _mask_iter)
+        return MapReader[BoolVector](self, "_mask_grp", _grp_key_iter)
 
     @deprecated
     def get_custom(self, name: str) -> AnyVector:
@@ -331,7 +343,7 @@
 
     @cached_property
     def custom(self) -> MapReader[AnyVector]:
-        return MapReader[AnyVector](self, _custom_iter)
+        return MapReader[AnyVector](self, "_custom_grp", _grp_key_iter)
 
     @deprecated
     def get_custom_meta(self, name: str) -> Any:
@@ -340,7 +352,7 @@
 
     @cached_property
     def custom_meta(self) -> MapReader[Any]:
-        return MapReader[Any](self, _meta_iter)
+        return MapReader[Any](self, "_grp", _meta_iter)
 
     def copy(self) -> HdfEventReader:
         """Returns a deep copy of the event object."""
@@ -361,8 +373,8 @@ class HdfProcessReader(ProcessReaderBase):
     key : str
         The name of the process to be opened.
 
-    Attributes (read-only)
-    ----------------------
+    Attributes
+    ----------
     process_string : str
         The MadGraph string representation of the process.
     string : str, deprecated
@@ -404,13 +416,13 @@ class HdfProcessReader(ProcessReaderBase):
     """
     def __init__(self, file_obj: HdfReader, key: str) -> None:
-        self._evt = HdfEventReader(evt_data=(None, None))
+        self._evt = HdfEventReader()
         grp = file_obj._buffer[key]
         if not isinstance(grp, Group):
             raise KeyError(f"{key} is not a process")
         self._grp: Group = grp
         self._meta: MetaDictType = dict(file_obj._buffer[key].attrs)
-        self.custom_meta: MapReader[Any] = MapReader(self, _meta_iter)
+        self.custom_meta: MapReader[Any] = MapReader(self, "_grp", _meta_iter)
 
     def __len__(self) -> int:
         return int(self._meta["num_evts"])
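The reader-side changes above replace `HdfEventReader`'s `evt_data` tuple constructor with lazily validated `_name`/`_grp` properties, and `MapReader` now resolves its backing group at access time through `getattr(self._reader, self._grp_attr_name)`. The net effect is that one event reader can be built empty and repointed from event to event without invalidating its cached `MapReader`s. A minimal sketch of that behaviour — `h5_file` and `event_name` are hypothetical stand-ins; only the class, attribute, and error names come from this diff:

```python
reader = HdfEventReader()          # replaces HdfEventReader(evt_data=(None, None))
meta = reader.custom_meta          # cached MapReader over the event group's attrs

try:
    dict(meta)                     # resolving reader._grp fails while unset...
except AttributeError as err:
    print(err)                     # "Event reader not pointing to event."

reader._name = event_name          # repoint the shared reader instance
reader._grp = h5_file[event_name]  # setter also binds the "custom"/"masks" subgroups
print(dict(meta))                  # the same cached MapReader now reads the new event
```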
""" - __slots__ = ('_name', '_grp') + def __init__(self) -> None: + self.__name: Optional[str] = None + self.__grp: Optional[Group] = None + self._custom_grp: Group + self._mask_grp: Group + + @property + def _name(self) -> str: + if self.__name is None: + raise _NO_EVENT_ERR + return self.__name - def __init__(self, evt_data) -> None: - self._name: str = evt_data[0] - self._grp: Group = evt_data[1] + @_name.setter + def _name(self, val: str) -> None: + self.__name = val + + @property + def _grp(self) -> Group: + if self.__grp is None: + raise _NO_EVENT_ERR + return self.__grp + + @_grp.setter + def _grp(self, val: Group) -> None: + self.__grp = val + custom = val["custom"] + mask = val["masks"] + if (not isinstance(custom, Group)) or (not isinstance(mask, Group)): + raise ValueError("Group not found") + self._custom_grp = custom + self._mask_grp = mask @property def name(self) -> str: @@ -322,7 +334,7 @@ def mask(self, name: str) -> BoolVector: @cached_property def masks(self) -> MapReader[BoolVector]: - return MapReader[BoolVector](self, _mask_iter) + return MapReader[BoolVector](self, "_mask_grp", _grp_key_iter) @deprecated def get_custom(self, name: str) -> AnyVector: @@ -331,7 +343,7 @@ def get_custom(self, name: str) -> AnyVector: @cached_property def custom(self) -> MapReader[AnyVector]: - return MapReader[AnyVector](self, _custom_iter) + return MapReader[AnyVector](self, "_custom_grp", _grp_key_iter) @deprecated def get_custom_meta(self, name: str) -> Any: @@ -340,7 +352,7 @@ def get_custom_meta(self, name: str) -> Any: @cached_property def custom_meta(self) -> MapReader[Any]: - return MapReader[Any](self, _meta_iter) + return MapReader[Any](self, "_grp", _meta_iter) def copy(self) -> HdfEventReader: """Returns a deep copy of the event object.""" @@ -361,8 +373,8 @@ class HdfProcessReader(ProcessReaderBase): key : str The name of the process to be opened. - Attributes (read-only) - ---------------------- + Attributes + ---------- process_string : str The MadGraph string representation of the process. string : str, deprecated @@ -404,13 +416,13 @@ class HdfProcessReader(ProcessReaderBase): """ def __init__(self, file_obj: HdfReader, key: str) -> None: - self._evt = HdfEventReader(evt_data=(None, None)) + self._evt = HdfEventReader() grp = file_obj._buffer[key] if not isinstance(grp, Group): raise KeyError(f"{key} is not a process") self._grp: Group = grp self._meta: MetaDictType = dict(file_obj._buffer[key].attrs) - self.custom_meta: MapReader[Any] = MapReader(self, _meta_iter) + self.custom_meta: MapReader[Any] = MapReader(self, "_grp", _meta_iter) def __len__(self) -> int: return int(self._meta["num_evts"]) diff --git a/heparchy/write/hdf.py b/heparchy/write/hdf.py index fc0f014..3ee1d07 100644 --- a/heparchy/write/hdf.py +++ b/heparchy/write/hdf.py @@ -12,6 +12,7 @@ from pathlib import Path import warnings from enum import Enum +from functools import partial import numpy.typing as npt @@ -63,26 +64,61 @@ class OverwriteWarning(RuntimeWarning): """ +def _mk_dset(grp: Group, name: str, data: AnyVector, shape: tuple, + dtype: npt.DTypeLike, compression: Compression, + compression_level: Optional[int]) -> None: + """Generic dataset creation and population function. + Wrap in methods exposed to the user interface. 
+ """ + if name in grp: + warnings.warn(f"Overwriting {name}", OverwriteWarning) + del grp[name] + # check data can be broadcasted to dataset: + if data.squeeze().shape != shape: + raise ValueError( + f"Input data shape {data.shape} " + f"incompatible with dataset shape {shape}." + ) + kwargs: Dict[str, Any] = dict( + name=name, + shape=shape, + dtype=dtype, + shuffle=True, + compression=compression.value, + ) + cmprs_lvl = compression_level + if cmprs_lvl is not None: + kwargs["compression_opts"] = cmprs_lvl + dset = grp.create_dataset(**kwargs) + dset[...] = data + + def _mask_setter(writer: WriterType, name: str, data: BoolVector) -> None: if not isinstance(writer, HdfEventWriter): raise ValueError("Can't set masks on processes") writer._set_num_pcls(data) - writer._mk_dset( + _mk_dset( + writer._mask_grp, name=name, data=data, - shape=(writer.num_pcls,), + shape=data.shape, dtype=writer._types.bool, + compression=writer._proc._file_obj._cmprs, + compression_level=writer._proc._file_obj._cmprs_lvl, ) def _custom_setter(writer: WriterType, name: str, data: AnyVector) -> None: if not isinstance(writer, HdfEventWriter): raise ValueError("Can't set custom datasets on processes") - writer._mk_dset( + _mk_dset( + writer._custom_grp, name=name, data=data, shape=data.shape, dtype=data.dtype, + compression=writer._proc._file_obj._cmprs, + compression_level=writer._proc._file_obj._cmprs_lvl, ) @@ -215,6 +251,8 @@ def __init__(self, proc: HdfProcessWriter) -> None: self._num_pcls: Optional[int] = None self._num_edges = 0 self._grp: Group + self._custom_grp: Group + self._mask_grp: Group self.masks: MapWriter[BoolVector] self.custom: MapWriter[AnyVector] self.custom_meta: MapWriter[Any] @@ -222,9 +260,17 @@ def __init__(self, proc: HdfProcessWriter) -> None: def __enter__(self: HdfEventWriter) -> HdfEventWriter: self._grp = self._proc._grp.create_group( event_key_format(self._idx)) + self._custom_grp = self._grp.create_group("custom") + self._mask_grp = self._grp.create_group("masks") self.masks = MapWriter(self, _mask_setter) self.custom = MapWriter(self, _custom_setter) self.custom_meta = MapWriter(self, _meta_setter) + self._mk_dset = partial( + _mk_dset, + grp=self._grp, + compression=self._proc._file_obj._cmprs, + compression_level=self._proc._file_obj._cmprs_lvl, + ) return self def __exit__(self, exc_type, exc_value, exc_traceback) -> None: @@ -254,35 +300,6 @@ def _set_num_pcls(self, data: AnyVector) -> None: else: return - def _mk_dset( - self, name: str, data: AnyVector, shape: tuple, - dtype: npt.DTypeLike, is_mask: bool = False) -> None: - """Generic dataset creation and population function. - Wrap in methods exposed to the user interface. - """ - if name in self._grp: - warnings.warn(f"Overwriting {name}", OverwriteWarning) - del self._grp[name] - # check data can be broadcasted to dataset: - if data.squeeze().shape != shape: - raise ValueError( - f"Input data shape {data.shape} " - f"incompatible with dataset shape {shape}." - ) - kwargs: Dict[str, Any] = dict( - name=name, - shape=shape, - dtype=dtype, - shuffle=True, - compression=self._proc._file_obj._cmprs.value, - ) - cmprs_lvl = self._proc._file_obj._cmprs_lvl - if cmprs_lvl is not None: - kwargs["compression_opts"] = cmprs_lvl - dset = self._grp.create_dataset(**kwargs) - dset[...] = data - dset.attrs["mask"] = is_mask - @property def edges(self) -> AnyVector: raise WriteOnlyError(_WRITE_ONLY_MSG)