breaking: dict interfaces now read / write from sub-groups
jacanchaplais committed Nov 10, 2022
1 parent fde249e commit e5182e7
Showing 2 changed files with 104 additions and 75 deletions.
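The change at a glance: each event group now carries dedicated "masks" and "custom" sub-groups, and the dict-style interfaces resolve keys against those sub-groups instead of sniffing dtypes or consulting "mask_keys"/"custom_keys" attributes on a flat event group. A minimal sketch of the resulting layout in plain h5py (file, process, and dataset names here are illustrative, not from the commit):

import h5py
import numpy as np

with h5py.File("demo.hdf5", "w") as f:
    # Event groups now contain dedicated sub-groups for user data.
    evt = f.create_group("my_process/event_0")
    masks = evt.create_group("masks")
    custom = evt.create_group("custom")
    masks.create_dataset("signal", data=np.ones(10, dtype=np.bool_))
    custom.create_dataset("weights", data=np.linspace(0.0, 1.0, 10))
    # Key listing reduces to plain group iteration, with no dtype
    # filtering and no bookkeeping attributes:
    print(list(masks.keys()))   # ['signal']
    print(list(custom.keys()))  # ['weights']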
98 changes: 55 additions & 43 deletions heparchy/read/hdf.py
@@ -12,8 +12,8 @@
from functools import cached_property
from collections.abc import Mapping
from typing import (
Any, List, Dict, Sequence, Type, Iterator, Union, Set, TypeVar,
Callable, Generic, Tuple)
Any, List, Dict, Sequence, Type, Iterator, Union, Set,
TypeVar, Callable, Generic, Tuple, Optional)

import numpy as np
import h5py
@@ -35,6 +35,7 @@ def _export(procedure: ExportType) -> ExportType:


_NOT_NUMPY_ERR = ValueError("Stored data type is corrupted.")
_NO_EVENT_ERR = AttributeError("Event reader not pointing to event.")
_BUILTIN_PROPS: Set[str] = set()
_BUILTIN_METADATA = { # TODO: work out non-hardcoded
"HdfEventReader": {"num_pcls", "mask_keys"},
@@ -68,39 +69,18 @@ def _stored_keys(attrs: AttributeManager, key_attr_name: str) -> Iterator[str]:
yield name


def _mask_iter(reader: HdfEventReader) -> Iterator[str]:
key_attr_name = "mask_keys"
grp = reader._grp
if key_attr_name not in grp.attrs:
dtype = np.dtype("<?")
for name, dset in grp.items():
if dset.dtype != dtype:
continue
yield name
else:
yield from _stored_keys(grp.attrs, key_attr_name)
def _grp_key_iter(reader: Group) -> Iterator[str]:
yield from reader.keys()


def _custom_iter(reader: HdfEventReader) -> Iterator[str]:
key_attr_name = "custom_keys"
grp = reader._grp
if key_attr_name not in grp.attrs:
names = set(grp.keys()) - set(reader.masks.keys())
for name in (names - _BUILTIN_PROPS):
yield name
else:
yield from _stored_keys(grp.attrs, key_attr_name)


def _meta_iter(reader: ReaderType) -> Iterator[str]:
def _meta_iter(reader: Group) -> Iterator[str]:
key_attr_name = "custom_meta_keys"
grp = reader._grp
if key_attr_name not in grp.attrs:
names = set(grp.attrs.keys())
if key_attr_name not in reader.attrs:
names = set(reader.attrs.keys())
for name in (names - _BUILTIN_METADATA[reader.__class__.__name__]):
yield name
else:
yield from _stored_keys(grp.attrs, key_attr_name)
yield from _stored_keys(reader.attrs, key_attr_name)


@_export
Expand Down Expand Up @@ -133,10 +113,16 @@ class MapReader(Generic[MapValue], Mapping[str, MapValue]):
"""
def __init__(self,
reader: ReaderType,
grp_attr_name: str,
iter_func: Callable[..., Iterator[str]]) -> None:
self._reader = reader
self._grp_attr_name = grp_attr_name
self._iter_func = iter_func

@property
def _grp(self) -> Group:
return getattr(self._reader, self._grp_attr_name)

def __repr__(self) -> str:
dset_repr = "<Read-Only Data>"
kv = ", ".join(map(lambda k: f"\'{k}\': {dset_repr}", self))
@@ -149,8 +135,8 @@ def __getitem__(self, name: str) -> MapValue:
if name not in set(self):
raise KeyError("No data stored with this name")
if self._iter_func.__name__ == "_meta_iter":
return self._reader._grp.attrs[name] # type: ignore
data = self._reader._grp[name]
return self._grp.attrs[name] # type: ignore
data = self._grp[name]
if not isinstance(data, Dataset):
raise _NOT_NUMPY_ERR
return data[...] # type: ignore
@@ -162,7 +148,7 @@ def __delitem__(self, name: str) -> None:
raise ReadOnlyError("Value is read-only")

def __iter__(self) -> Iterator[str]:
yield from self._iter_func(self._reader)
yield from self._iter_func(self._grp)


def _type_error_str(data: Any, dtype: type) -> str:
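Worth noting in the MapReader rework above: the backing group is no longer fetched eagerly from the reader, but resolved by attribute name on every access. A condensed sketch of the pattern (class trimmed to the relevant parts; names taken from the diff):

class MapReaderSketch:
    """Trimmed illustration of MapReader's late group binding."""

    def __init__(self, reader, grp_attr_name, iter_func):
        self._reader = reader
        self._grp_attr_name = grp_attr_name  # e.g. "_mask_grp" or "_custom_grp"
        self._iter_func = iter_func

    @property
    def _grp(self):
        # Resolved on each use, so the mapping stays valid even when the
        # reader is re-pointed at a different event group.
        return getattr(self._reader, self._grp_attr_name)

    def __iter__(self):
        yield from self._iter_func(self._grp)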
Expand Down Expand Up @@ -231,11 +217,37 @@ class HdfEventReader(EventReaderBase):
Read-only dictionary-like interface to access user-defined
metadata on the event.
"""
__slots__ = ('_name', '_grp')
def __init__(self) -> None:
self.__name: Optional[str] = None
self.__grp: Optional[Group] = None
self._custom_grp: Group
self._mask_grp: Group

@property
def _name(self) -> str:
if self.__name is None:
raise _NO_EVENT_ERR
return self.__name

def __init__(self, evt_data) -> None:
self._name: str = evt_data[0]
self._grp: Group = evt_data[1]
@_name.setter
def _name(self, val: str) -> None:
self.__name = val

@property
def _grp(self) -> Group:
if self.__grp is None:
raise _NO_EVENT_ERR
return self.__grp

@_grp.setter
def _grp(self, val: Group) -> None:
self.__grp = val
custom = val["custom"]
mask = val["masks"]
if (not isinstance(custom, Group)) or (not isinstance(mask, Group)):
raise ValueError("Group not found")
self._custom_grp = custom
self._mask_grp = mask

@property
def name(self) -> str:
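HdfEventReader now starts detached instead of being constructed around a (name, group) tuple: the private name and group are Optional, and the properties raise _NO_EVENT_ERR until an event is attached. Attaching a group also resolves and validates the two sub-groups up front. A condensed sketch of the guard (error messages copied from the diff, class trimmed):

import h5py

class EventPointerSketch:
    def __init__(self):
        self.__grp = None  # detached until an event group is attached

    @property
    def _grp(self):
        if self.__grp is None:
            raise AttributeError("Event reader not pointing to event.")
        return self.__grp

    @_grp.setter
    def _grp(self, val):
        custom, mask = val["custom"], val["masks"]
        if not (isinstance(custom, h5py.Group) and isinstance(mask, h5py.Group)):
            raise ValueError("Group not found")
        self.__grp = val
        self._custom_grp, self._mask_grp = custom, mask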
Expand Down Expand Up @@ -322,7 +334,7 @@ def mask(self, name: str) -> BoolVector:

@cached_property
def masks(self) -> MapReader[BoolVector]:
return MapReader[BoolVector](self, _mask_iter)
return MapReader[BoolVector](self, "_mask_grp", _grp_key_iter)

@deprecated
def get_custom(self, name: str) -> AnyVector:
Expand All @@ -331,7 +343,7 @@ def get_custom(self, name: str) -> AnyVector:

@cached_property
def custom(self) -> MapReader[AnyVector]:
return MapReader[AnyVector](self, _custom_iter)
return MapReader[AnyVector](self, "_custom_grp", _grp_key_iter)

@deprecated
def get_custom_meta(self, name: str) -> Any:
Expand All @@ -340,7 +352,7 @@ def get_custom_meta(self, name: str) -> Any:

@cached_property
def custom_meta(self) -> MapReader[Any]:
return MapReader[Any](self, _meta_iter)
return MapReader[Any](self, "_grp", _meta_iter)

def copy(self) -> HdfEventReader:
"""Returns a deep copy of the event object."""
@@ -361,8 +373,8 @@ class HdfProcessReader(ProcessReaderBase):
key : str
The name of the process to be opened.
Attributes (read-only)
----------------------
Attributes
----------
process_string : str
The MadGraph string representation of the process.
string : str, deprecated
@@ -404,13 +416,13 @@
"""
def __init__(self, file_obj: HdfReader, key: str) -> None:
self._evt = HdfEventReader(evt_data=(None, None))
self._evt = HdfEventReader()
grp = file_obj._buffer[key]
if not isinstance(grp, Group):
raise KeyError(f"{key} is not a process")
self._grp: Group = grp
self._meta: MetaDictType = dict(file_obj._buffer[key].attrs)
self.custom_meta: MapReader[Any] = MapReader(self, _meta_iter)
self.custom_meta: MapReader[Any] = MapReader(self, "_grp", _meta_iter)

def __len__(self) -> int:
return int(self._meta["num_evts"])
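Put together, the read side keeps its mapping semantics while sourcing masks and custom datasets from the new sub-groups. A hypothetical reading session (the constructor, subscripting, and iteration conventions shown here are assumptions, not part of this diff; only the mapping accesses are):

from heparchy.read.hdf import HdfReader

with HdfReader("events.hdf5") as hep_file:    # assumed context manager
    proc = hep_file["my_process"]             # assumed subscript access
    for event in proc:                        # assumed event iteration
        signal = event.masks["signal"]        # read from the masks/ sub-group
        weights = event.custom["weights"]     # read from the custom/ sub-group
        meta = dict(event.custom_meta)        # attrs-backed metadata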
81 changes: 49 additions & 32 deletions heparchy/write/hdf.py
@@ -12,6 +12,7 @@
from pathlib import Path
import warnings
from enum import Enum
from functools import partial


import numpy.typing as npt
@@ -63,26 +64,61 @@ class OverwriteWarning(RuntimeWarning):
"""


def _mk_dset(grp: Group, name: str, data: AnyVector, shape: tuple,
dtype: npt.DTypeLike, compression: Compression,
compression_level: Optional[int]) -> None:
"""Generic dataset creation and population function.
Wrap in methods exposed to the user interface.
"""
if name in grp:
warnings.warn(f"Overwriting {name}", OverwriteWarning)
del grp[name]
# check data can be broadcasted to dataset:
if data.squeeze().shape != shape:
raise ValueError(
f"Input data shape {data.shape} "
f"incompatible with dataset shape {shape}."
)
kwargs: Dict[str, Any] = dict(
name=name,
shape=shape,
dtype=dtype,
shuffle=True,
compression=compression.value,
)
cmprs_lvl = compression_level
if cmprs_lvl is not None:
kwargs["compression_opts"] = cmprs_lvl
dset = grp.create_dataset(**kwargs)
dset[...] = data
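Hoisting _mk_dset to module level makes the target group an explicit argument, so the same routine can serve both sub-groups as well as the event group itself. The shape guard and compression plumbing are unchanged; in plain h5py the created dataset corresponds to roughly the following (the Compression enum member name is an assumption):

import h5py
import numpy as np

with h5py.File("scratch.hdf5", "w") as f:
    grp = f.create_group("event_0/custom")
    data = np.arange(6.0)
    # Roughly what _mk_dset(grp, name="pt", data=data, shape=data.shape,
    # dtype=data.dtype, compression=Compression.GZIP, compression_level=4)
    # would produce:
    dset = grp.create_dataset(
        "pt", shape=data.shape, dtype=data.dtype,
        shuffle=True, compression="gzip", compression_opts=4,
    )
    dset[...] = data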


def _mask_setter(writer: WriterType, name: str, data: BoolVector) -> None:
if not isinstance(writer, HdfEventWriter):
raise ValueError("Can't set masks on processes")
writer._set_num_pcls(data)
writer._mk_dset(
_mk_dset(
writer._mask_grp,
name=name,
data=data,
shape=(writer.num_pcls,),
shape=data.shape,
dtype=writer._types.bool,
compression=writer._proc._file_obj._cmprs,
compression_level=writer._proc._file_obj._cmprs_lvl,
)


def _custom_setter(writer: WriterType, name: str, data: AnyVector) -> None:
if not isinstance(writer, HdfEventWriter):
raise ValueError("Can't set custom datasets on processes")
writer._mk_dset(
_mk_dset(
writer._custom_grp,
name=name,
data=data,
shape=data.shape,
dtype=data.dtype,
compression=writer._proc._file_obj._cmprs,
compression_level=writer._proc._file_obj._cmprs_lvl,
)
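From the user's side, the setters above mean that assigning into the mappings writes datasets beneath the matching sub-group. A hypothetical writing session (the HdfWriter entry points shown here are assumptions; only the mapping assignments follow from this diff):

import numpy as np
from heparchy.write.hdf import HdfWriter

with HdfWriter("events.hdf5") as hep_file:            # assumed entry point
    with hep_file.new_process("my_process") as proc:  # assumed helper
        with proc.new_event() as event:               # assumed helper
            event.masks["signal"] = np.ones(10, dtype=np.bool_)  # -> masks/
            event.custom["weights"] = np.linspace(0.0, 1.0, 10)  # -> custom/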


@@ -215,16 +251,26 @@ def __init__(self, proc: HdfProcessWriter) -> None:
self._num_pcls: Optional[int] = None
self._num_edges = 0
self._grp: Group
self._custom_grp: Group
self._mask_grp: Group
self.masks: MapWriter[BoolVector]
self.custom: MapWriter[AnyVector]
self.custom_meta: MapWriter[Any]

def __enter__(self: HdfEventWriter) -> HdfEventWriter:
self._grp = self._proc._grp.create_group(
event_key_format(self._idx))
self._custom_grp = self._grp.create_group("custom")
self._mask_grp = self._grp.create_group("masks")
self.masks = MapWriter(self, _mask_setter)
self.custom = MapWriter(self, _custom_setter)
self.custom_meta = MapWriter(self, _meta_setter)
self._mk_dset = partial(
_mk_dset,
grp=self._grp,
compression=self._proc._file_obj._cmprs,
compression_level=self._proc._file_obj._cmprs_lvl,
)
return self

def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
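The functools.partial binding in __enter__ replaces the old bound method: the event group and the file-level compression settings are fixed once per event, and later internal calls supply only the dataset-specific arguments by keyword. A small sketch of the idea:

from functools import partial

def mk_dset(grp, name, data, shape, dtype, compression, compression_level):
    ...  # stands in for the module-level _mk_dset above

bound = partial(mk_dset, grp="<event group>", compression="gzip",
                compression_level=4)
# Internal call sites keep the old method-like shape, e.g.:
# bound(name="pmu", data=pmu, shape=pmu.shape, dtype=pmu.dtype)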
@@ -254,35 +300,6 @@ def _set_num_pcls(self, data: AnyVector) -> None:
else:
return

def _mk_dset(
self, name: str, data: AnyVector, shape: tuple,
dtype: npt.DTypeLike, is_mask: bool = False) -> None:
"""Generic dataset creation and population function.
Wrap in methods exposed to the user interface.
"""
if name in self._grp:
warnings.warn(f"Overwriting {name}", OverwriteWarning)
del self._grp[name]
# check data can be broadcasted to dataset:
if data.squeeze().shape != shape:
raise ValueError(
f"Input data shape {data.shape} "
f"incompatible with dataset shape {shape}."
)
kwargs: Dict[str, Any] = dict(
name=name,
shape=shape,
dtype=dtype,
shuffle=True,
compression=self._proc._file_obj._cmprs.value,
)
cmprs_lvl = self._proc._file_obj._cmprs_lvl
if cmprs_lvl is not None:
kwargs["compression_opts"] = cmprs_lvl
dset = self._grp.create_dataset(**kwargs)
dset[...] = data
dset.attrs["mask"] = is_mask

@property
def edges(self) -> AnyVector:
raise WriteOnlyError(_WRITE_ONLY_MSG)
