Skip to content

Commit

Permalink
Format black
Browse files Browse the repository at this point in the history
  • Loading branch information
pyiron-runner committed Nov 30, 2023
1 parent af5bd09 commit 5907595
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions pyiron_base/storage/hdfio.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@
# imported. We can work around this by defining here an explicit map that
# _to_object can use to find the new modules and update the HDF5 files
_MODULE_CONVERSION_DICT = {
"pyiron_base.generic.datacontainer": "pyiron_base.storage.datacontainer",
"pyiron_base.generic.inputlist": "pyiron_base.storage.inputlist",
"pyiron_base.generic.flattenedstorage": "pyiron_base.storage.flattenedstorage",
"pyiron_base.table.datamining": "pyiron_base.jobs.datamining",
"pyiron_base.generic.datacontainer": "pyiron_base.storage.datacontainer",
"pyiron_base.generic.inputlist": "pyiron_base.storage.inputlist",
"pyiron_base.generic.flattenedstorage": "pyiron_base.storage.flattenedstorage",
"pyiron_base.table.datamining": "pyiron_base.jobs.datamining",
}


def add_module_conversion_path(old: str, new: str):
"""
Add a new module conversion.
Expand All @@ -64,7 +65,10 @@ def add_module_conversion_path(old: str, new: str):
if old not in _MODULE_CONVERSION_DICT:
_MODULE_CONVERSION_DICT[old] = new
elif _MODULE_CONVERSION_DICT[old] != new:
raise ValueError(f"Module path '{old}' already found in conversion dict, pointing to '{new}'!")
raise ValueError(
f"Module path '{old}' already found in conversion dict, pointing to '{new}'!"
)


def patch_sys_module():
"""
Expand Down Expand Up @@ -119,7 +123,8 @@ def _import_class(class_name):
class_module_path = ".".join(class_path)
# ugly dynamic import, but only needed to log the warning anyway
from pyiron_base.jobs.job.jobtype import JobTypeChoice
job_class_dict = JobTypeChoice().job_class_dict # access global singleton

job_class_dict = JobTypeChoice().job_class_dict # access global singleton
if internal_class_name in job_class_dict:
module_path = job_class_dict[internal_class_name]
# entries in the job_class_dict are either strings of modules or fully
Expand Down Expand Up @@ -186,6 +191,7 @@ def _to_object(hdf, class_name=None, **kwargs):
obj.from_hdf(hdf=hdf.open(".."), group_name=hdf.h5_path.split("/")[-1])
return obj


def open_hdf5(filename, mode="r", swmr=False):
if swmr and mode != "r":
store = h5py.File(filename, mode=mode, libver="latest")
Expand All @@ -194,6 +200,7 @@ def open_hdf5(filename, mode="r", swmr=False):
else:
return h5py.File(filename, mode=mode, libver="latest", swmr=swmr)


class FileHDFio(HasGroups, MutableMapping):
"""
Class that provides all info to access a h5 file. This class is based on h5io.py, which allows to
Expand Down Expand Up @@ -1515,6 +1522,7 @@ def create_project_from_hdf5(self):
"""
return self._project.__class__(path=self.file_path)


class DummyHDFio(HasGroups):
"""
A dummy ProjectHDFio implementation to serialize objects into a dict
Expand Down Expand Up @@ -1636,8 +1644,7 @@ def create_group(self, name: str):
d = self._dict.get(name, None)
if d is None:
self._dict[name] = d = type(self)(
self.project,
os.path.join(self.h5_path, name), cont={}, root=self
self.project, os.path.join(self.h5_path, name), cont={}, root=self
)
elif isinstance(d, DummyHDFio):
pass
Expand All @@ -1649,7 +1656,11 @@ def _list_nodes(self):
return [k for k, v in self._dict.items() if not isinstance(v, DummyHDFio)]

def _list_groups(self):
return [k for k, v in self._dict.items() if isinstance(v, DummyHDFio) and not v._empty()]
return [
k
for k, v in self._dict.items()
if isinstance(v, DummyHDFio) and not v._empty()
]

def __contains__(self, item):
return item in self._dict
Expand Down Expand Up @@ -1708,6 +1719,7 @@ def unwrap(v):
if isinstance(v, DummyHDFio):
return v.to_dict()
return v

return {k: unwrap(v) for k, v in self._dict.items()}

def to_object(self, class_name=None, **kwargs):
Expand All @@ -1729,7 +1741,10 @@ def to_object(self, class_name=None, **kwargs):
def _empty(self):
if len(self._dict) == 0:
return True
return len(self.list_nodes())==0 and all(self[g]._empty() for g in self.list_groups())
return len(self.list_nodes()) == 0 and all(
self[g]._empty() for g in self.list_groups()
)


def _get_safe_filename(file_name):
file_path_no_ext, file_ext = os.path.splitext(file_name)
Expand Down

0 comments on commit 5907595

Please sign in to comment.