From fb764616652e6dd28cde995a992f36c664fc3c43 Mon Sep 17 00:00:00 2001 From: Daniel Utt Date: Mon, 29 Apr 2024 17:41:42 +0200 Subject: [PATCH] Update to be backwards compatible with python 3.8 and 3.9 --- .github/workflows/python-tests.yml | 2 +- src/scoreBasedDenoising/__init__.py | 86 ++++++++++++++++------------- 2 files changed, 50 insertions(+), 38 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index d0d4a9a..9cf0ad4 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -18,7 +18,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 - name: Install apt dependencies diff --git a/src/scoreBasedDenoising/__init__.py b/src/scoreBasedDenoising/__init__.py index dc64c02..2dcd948 100644 --- a/src/scoreBasedDenoising/__init__.py +++ b/src/scoreBasedDenoising/__init__.py @@ -110,30 +110,36 @@ def getModelPath(self): raise FileNotFoundError(f"{path} does not exist.") return Path(self.model_path) modelDir = impRes.files("graphite.pretrained_models.denoiser") - match self.structure: - case "SiO2": - return modelDir.joinpath("SiO2-denoiser.pt") - case "FCC" | "BCC" | "HCP": - return modelDir.joinpath("Cu-denoiser.pt") - case _: - raise NotImplementedError( - f"No default model path available for: {self.structure}" - ) + if self.structure == "SiO2": + return modelDir.joinpath("SiO2-denoiser.pt") + elif ( + self.structure == "FCC" + or self.structure == "BCC" + or self.structure == "HCP" + ): + return modelDir.joinpath("Cu-denoiser.pt") + else: + raise NotImplementedError( + f"No default model path available for: {self.structure}" + ) def estimateNearestNeighborsDistance(self, data): finder = NearestNeighborFinder( ScoreBasedDenoising.numNearestNeigh[self.structure], data ) - match self.structure: - case "SiO2": - idx = np.where( - data.particles["Particle Type"] - == data.particles["Particle Type"].type_by_name("Si").id - )[0] - case "FCC" | "BCC" | "HCP": - idx = None - case _: - raise NotImplementedError + if self.structure == "SiO2": + idx = np.where( + data.particles["Particle Type"] + == data.particles["Particle Type"].type_by_name("Si").id + )[0] + elif ( + self.structure == "FCC" + or self.structure == "BCC" + or self.structure == "HCP" + ): + idx = None + else: + raise NotImplementedError _, neighVec = finder.find_all(idx) return np.mean(np.linalg.norm(neighVec, axis=2)) @@ -180,15 +186,18 @@ def writeTable(data, y, ylabel, title): table.y = table.create_property(ylabel, data=y) def _modify(self, data, frame, **kwargs): - match self.structure: - case "SiO2": - model = self.setupSiO2model(data) - case "FCC" | "BCC" | "HCP": - model = self.setupFccBccHcpModel(data) - case "Custom": - model = self.setupCustomModel() - case _: - raise NotImplementedError + if self.structure == "SiO2": + model = self.setupSiO2model(data) + elif ( + self.structure == "FCC" + or self.structure == "BCC" + or self.structure == "HCP" + ): + model = self.setupFccBccHcpModel(data) + elif self.structure == "Custom": + model = self.setupCustomModel() + else: + raise NotImplementedError model = model.to(self.device) model.eval() @@ -214,15 +223,18 @@ def _modify(self, data, frame, **kwargs): ) data.particles_["Position_"][...] = denoised_atoms - match self.structure: - case "SiO2": - pass - case "FCC" | "BCC" | "HCP": - self.teardownFccBccHcpModel(data) - case "Custom": - pass - case _: - raise NotImplementedError + if self.structure == "SiO2": + pass + elif ( + self.structure == "FCC" + or self.structure == "BCC" + or self.structure == "HCP" + ): + self.teardownFccBccHcpModel(data) + elif self.structure == "Custom": + pass + else: + raise NotImplementedError ScoreBasedDenoising.writeTable(data, convergence, "Convergence", "Convergence") ScoreBasedDenoising.writeTable(