Skip to content

Commit

Permalink
Simplify implementation of PMGDir.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Oct 30, 2024
1 parent 5404615 commit bd9fba9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
26 changes: 13 additions & 13 deletions src/pymatgen/io/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,38 +475,38 @@ def reset(self):
changed.
"""
# Note that py3.12 has Path.walk(). But we need to use os.walk to ensure backwards compatibility for now.
self.files = [str((Path(d) / f).relative_to(self.path)) for d, _, fnames in os.walk(self.path) for f in fnames]

self._parsed_files: dict[str, Any] = {}
self._files: dict[str, Any] = {
str((Path(d) / f).relative_to(self.path)): None for d, _, fnames in os.walk(self.path) for f in fnames
}

def __contains__(self, item):
return item in self.files
return item in self._files

def __len__(self):
return len(self.files)
return len(self._files)

def __iter__(self):
return iter(self.files)
return iter(self._files)

def __getitem__(self, item):
if item in self._parsed_files:
return self._parsed_files[item]
if self._files.get(item):
return self._files.get(item)
fpath = self.path / item

if not (self.path / item).exists():
raise ValueError(f"{item} not found in {self.path}. List of files are {self.files}.")
raise ValueError(f"{item} not found in {self.path}. List of files are {self._files.keys()}.")

for k, cls_ in PMGDir.FILE_MAPPINGS.items():
if k in item:
modname, classname = cls_.rsplit(".", 1)
module = importlib.import_module(modname)
class_ = getattr(module, classname)
try:
self._parsed_files[item] = class_.from_file(fpath)
self._files[item] = class_.from_file(fpath)
except AttributeError:
self._parsed_files[item] = class_(fpath)
self._files[item] = class_(fpath)

return self._parsed_files[item]
return self._files[item]

warnings.warn(
f"No parser defined for {item}. Contents are returned as a string.",
Expand All @@ -522,7 +522,7 @@ def get_files_by_name(self, name: str) -> dict[str, Any]:
Returns:
{filename: object from PMGDir[filename]}
"""
return {f: self[f] for f in self.files if name in f}
return {f: self[f] for f in self._files if name in f}

def __repr__(self):
return f"PMGDir({self.path})"
5 changes: 3 additions & 2 deletions tests/io/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_getitem(self):

d = PMGDir(f"{TEST_FILES_DIR}/io/vasp/fixtures/scan_relaxation")
assert len(d) == 2
assert "vasprun.xml.gz" in d.files
assert "vasprun.xml.gz" in d
assert "OUTCAR" in d
assert d["vasprun.xml.gz"].incar["METAGGA"] == "R2scan"

Expand All @@ -57,4 +57,5 @@ def test_getitem(self):
assert all("OUTCAR" for k in outcars)

d.reset()
assert len(d._parsed_files) == 0
for v in d._files.values():
assert v is None

0 comments on commit bd9fba9

Please sign in to comment.