Skip to content

Commit

Permalink
Start custom model support
Browse files Browse the repository at this point in the history
  • Loading branch information
nnn911 committed Aug 18, 2023
1 parent f43c8de commit 1a7a0c1
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "scoreBasedDenoising"
version = "2023.7"
version = "2023.8"
description = "Score-based denoising for atomic structure identification - Iteratively subtract thermal noises or perturbations from atomic positions."
keywords = ["ovito", "ovito-extension"]
authors = [{ name = "Daniel Utt", email = "[email protected]" }]
Expand Down
67 changes: 41 additions & 26 deletions src/scoreBasedDenoising/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from ovito.pipeline import ModifierInterface
from sklearn.preprocessing import LabelEncoder
from torch_geometric.data import Data
from traits.api import Bool, Enum, Float, Int, Union
from traits.api import Bool, Enum, Float, Int, Union, Str
from pathlib import Path


# Make InitialEmbedding visible to torch.load() (pre-trained models expect Initial embedding to be part of main)
setattr(sys.modules["__main__"], "InitialEmbedding", InitialEmbedding)
Expand All @@ -47,9 +49,17 @@ class ScoreBasedDenoising(ModifierInterface):
scale = Union(None, Float, label="Nearest neighbor distance")

structure = Enum(
"None", "FCC", "BCC", "HCP", "SiO2", label="Crystal structure / material system"
"None",
"FCC",
"BCC",
"HCP",
"SiO2",
"Custom",
label="Crystal structure / material system",
)

modelPath = Union(None, Str, label="Model file path")

if torch.cuda.is_available():
device = Enum("cpu", "cuda", label="Device")
elif torch.backends.mps.is_available():
Expand All @@ -63,23 +73,6 @@ class ScoreBasedDenoising(ModifierInterface):
def getRadiusGraph():
return PeriodicRadiusGraph(cutoff=ScoreBasedDenoising.cutoff)

@staticmethod
def getModel(numSpecies):
return NequIP(
init_embed=InitialEmbedding(
num_species=numSpecies,
cutoff=ScoreBasedDenoising.cutoff,
),
irreps_node_x="8x0e",
irreps_node_z="8x0e",
irreps_hidden="8x0e + 8x1e + 4x2e",
irreps_edge="1x0e + 1x1e + 1x2e",
irreps_out="1x1e",
num_convs=3,
radial_neurons=[16, 64],
num_neighbors=12,
)

@torch.no_grad()
def denoise_snapshot(self, atoms, model, scale):
x = LabelEncoder().fit_transform(atoms.numbers)
Expand Down Expand Up @@ -110,6 +103,23 @@ def denoise_snapshot(self, atoms, model, scale):
yield
return data.pos.to("cpu").numpy() / scale, convergence

def getModelPath(self):
if self.modelPath is not None:
path = Path(self.modelPath)
if not path.exists():
raise FileNotFoundError(f"{path} does not exist.")
return Path(self.modelPath)
modelDir = impRes.files("graphite.pretrained_models.denoiser")
match self.structure:
case "SiO2":
return modelDir.joinpath("SiO2-denoiser.pt")
case "FCC" | "BCC" | "HCP":
return modelDir.joinpath("Cu-denoiser.pt")
case _:
raise NotImplementedError(
f"No default model path available for: {self.structure}"
)

def estimateNearestNeighborsDistance(self, data):
finder = NearestNeighborFinder(
ScoreBasedDenoising.numNearestNeigh[self.structure], data
Expand All @@ -128,8 +138,7 @@ def estimateNearestNeighborsDistance(self, data):
return np.mean(np.linalg.norm(neighVec, axis=2))

def setupSiO2model(self, data):
modelDir = impRes.files("graphite.pretrained_models.denoiser")
model = torch.load(modelDir.joinpath("SiO2-denoiser.pt"))
model = torch.load(self.getModelPath(), map_location=torch.device(self.device))
cts = {"Si": 0, "O": 0}
for uni in np.unique(data.particles["Particle Type"]):
name = data.particles["Particle Type"].type_by_id(uni).name
Expand All @@ -146,14 +155,16 @@ def setupSiO2model(self, data):
return model

def setupFccBccHcpModel(self, data):
modelDir = impRes.files("graphite.pretrained_models.denoiser")
model = torch.load(modelDir.joinpath("Cu-denoiser.pt"))
model = torch.load(self.getModelPath(), map_location=torch.device(self.device))
data.particles_.create_property(
"Particle Type Backup", data=data.particles["Particle Type"]
)
data.particles_["Particle Type_"][...] = 1
return model

def setupCustomModel(self):
return torch.load(self.getModelPath(), map_location=torch.device(self.device))

def teardownFccBccHcpModel(self, data):
data.particles_["Particle Type_"][...] = data.particles["Particle Type Backup"]
del data.particles_["Particle Type Backup"]
Expand All @@ -168,12 +179,14 @@ def writeTable(data, y, ylabel, title):
table.x = table.create_property("Step", data=np.arange(len(y)))
table.y = table.create_property(ylabel, data=y)

def run(self, data, frame, **kwargs):
def _modify(self, data, frame, **kwargs):
match self.structure:
case "SiO2":
model = self.setupSiO2model(data)
case "FCC" | "BCC" | "HCP":
model = self.setupFccBccHcpModel(data)
case "Custom":
model = self.setupCustomModel(data)
case _:
raise NotImplementedError

Expand All @@ -199,6 +212,8 @@ def run(self, data, frame, **kwargs):
pass
case "FCC" | "BCC" | "HCP":
self.teardownFccBccHcpModel(data)
case "Custom":
pass
case _:
raise NotImplementedError

Expand Down Expand Up @@ -226,7 +241,7 @@ def modify(self, data, frame, **kwargs):
data_clone.apply(InvertSelectionModifier())
data_clone.apply(DeleteSelectedModifier())

yield from self.run(data_clone, frame, **kwargs)
yield from self._modify(data_clone, frame, **kwargs)

data.particles_["Position_"][
data.particles["Selection"] == 1
Expand All @@ -236,4 +251,4 @@ def modify(self, data, frame, **kwargs):
for t in data_clone.tables:
data.objects.append(data_clone.tables[t])
else:
yield from self.run(data, frame, **kwargs)
yield from self._modify(data, frame, **kwargs)

0 comments on commit 1a7a0c1

Please sign in to comment.