Skip to content

Commit

Permalink
Update to be backwards compatible with python 3.8 and 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
nnn911 committed Apr 29, 2024
1 parent 51c5f1a commit fb76461
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Install apt dependencies
Expand Down
86 changes: 49 additions & 37 deletions src/scoreBasedDenoising/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,30 +110,36 @@ def getModelPath(self):
raise FileNotFoundError(f"{path} does not exist.")
return Path(self.model_path)
modelDir = impRes.files("graphite.pretrained_models.denoiser")
match self.structure:
case "SiO2":
return modelDir.joinpath("SiO2-denoiser.pt")
case "FCC" | "BCC" | "HCP":
return modelDir.joinpath("Cu-denoiser.pt")
case _:
raise NotImplementedError(
f"No default model path available for: {self.structure}"
)
if self.structure == "SiO2":
return modelDir.joinpath("SiO2-denoiser.pt")
elif (
self.structure == "FCC"
or self.structure == "BCC"
or self.structure == "HCP"
):
return modelDir.joinpath("Cu-denoiser.pt")
else:
raise NotImplementedError(
f"No default model path available for: {self.structure}"
)

def estimateNearestNeighborsDistance(self, data):
finder = NearestNeighborFinder(
ScoreBasedDenoising.numNearestNeigh[self.structure], data
)
match self.structure:
case "SiO2":
idx = np.where(
data.particles["Particle Type"]
== data.particles["Particle Type"].type_by_name("Si").id
)[0]
case "FCC" | "BCC" | "HCP":
idx = None
case _:
raise NotImplementedError
if self.structure == "SiO2":
idx = np.where(
data.particles["Particle Type"]
== data.particles["Particle Type"].type_by_name("Si").id
)[0]
elif (
self.structure == "FCC"
or self.structure == "BCC"
or self.structure == "HCP"
):
idx = None
else:
raise NotImplementedError
_, neighVec = finder.find_all(idx)
return np.mean(np.linalg.norm(neighVec, axis=2))

Expand Down Expand Up @@ -180,15 +186,18 @@ def writeTable(data, y, ylabel, title):
table.y = table.create_property(ylabel, data=y)

def _modify(self, data, frame, **kwargs):
match self.structure:
case "SiO2":
model = self.setupSiO2model(data)
case "FCC" | "BCC" | "HCP":
model = self.setupFccBccHcpModel(data)
case "Custom":
model = self.setupCustomModel()
case _:
raise NotImplementedError
if self.structure == "SiO2":
model = self.setupSiO2model(data)
elif (
self.structure == "FCC"
or self.structure == "BCC"
or self.structure == "HCP"
):
model = self.setupFccBccHcpModel(data)
elif self.structure == "Custom":
model = self.setupCustomModel()
else:
raise NotImplementedError

model = model.to(self.device)
model.eval()
Expand All @@ -214,15 +223,18 @@ def _modify(self, data, frame, **kwargs):
)
data.particles_["Position_"][...] = denoised_atoms

match self.structure:
case "SiO2":
pass
case "FCC" | "BCC" | "HCP":
self.teardownFccBccHcpModel(data)
case "Custom":
pass
case _:
raise NotImplementedError
if self.structure == "SiO2":
pass
elif (
self.structure == "FCC"
or self.structure == "BCC"
or self.structure == "HCP"
):
self.teardownFccBccHcpModel(data)
elif self.structure == "Custom":
pass
else:
raise NotImplementedError

ScoreBasedDenoising.writeTable(data, convergence, "Convergence", "Convergence")
ScoreBasedDenoising.writeTable(
Expand Down

0 comments on commit fb76461

Please sign in to comment.