feat: support precision for nequip (#18)
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Oct 3, 2024 · 1 parent 32061d5 · commit 482dcc5
Showing 4 changed files with 15 additions and 31 deletions.
README.md (2 additions, 1 deletion)

````diff
@@ -128,7 +128,8 @@ Below is default values for the MACE model, most of which follows default values
     "irreps_edge_sh": "0e + 1e",
     "feature_irreps_hidden": "32x0o + 32x0e + 32x1o + 32x1e",
     "chemical_embedding_irreps_out": "32x0e",
-    "conv_to_output_hidden_irreps_out": "16x0e"
+    "conv_to_output_hidden_irreps_out": "16x0e",
+    "precision": "float32"
 }
 ```
````
deepmd_gnn/argcheck.py (8 additions, 0 deletions)

```diff
@@ -147,6 +147,7 @@ def nequip_model_args() -> Argument:
     doc_feature_irreps_hidden = "irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster"
     doc_chemical_embedding_irreps_out = "irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer"
     doc_conv_to_output_hidden_irreps_out = "irreps used in hidden layer of output block"
+    doc_precision = "Precision of the model, float32 or float64"
     return Argument(
         "nequip",
         dict,
@@ -269,6 +270,13 @@ def nequip_model_args() -> Argument:
                 default="16x0e",
                 doc=doc_conv_to_output_hidden_irreps_out,
             ),
+            Argument(
+                "precision",
+                str,
+                optional=True,
+                default="float32",
+                doc=doc_precision,
+            ),
         ],
         doc="Nequip model",
     )
```
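To see how the new field behaves at input-checking time, here is a minimal sketch using dargs, the argument-specification library that `nequip_model_args()` builds on. The stripped-down spec and the use of `normalize_value` to fill defaults are assumptions for illustration, not the repository's own code:

```python
# Minimal sketch: validating a nequip config that carries the new
# "precision" key. The spec is a stand-in for the full nequip_model_args();
# normalize_value filling in defaults is assumed from dargs' public API.
from dargs import Argument

nequip_spec = Argument(
    "nequip",
    dict,
    [
        Argument(
            "precision",
            str,
            optional=True,
            default="float32",
            doc="Precision of the model, float32 or float64",
        ),
    ],
    doc="Nequip model",
)

print(nequip_spec.normalize_value({"precision": "float64"}))  # keeps the explicit value
print(nequip_spec.normalize_value({}))                        # falls back to "float32"
```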
deepmd_gnn/nequip.py (3 additions, 30 deletions)

```diff
@@ -1,6 +1,5 @@
 """Nequip model."""
 
-import json
 from copy import deepcopy
 from typing import Any, Optional
 
@@ -118,6 +117,7 @@ def __init__(
         feature_irreps_hidden: str = "32x0o + 32x0e + 32x1o + 32x1e",
         chemical_embedding_irreps_out: str = "32x0e",
         conv_to_output_hidden_irreps_out: str = "16x0e",
+        precision: str = "float32",
         **kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().__init__(**kwargs)
@@ -140,6 +140,7 @@ def __init__(
             "feature_irreps_hidden": feature_irreps_hidden,
             "chemical_embedding_irreps_out": chemical_embedding_irreps_out,
             "conv_to_output_hidden_irreps_out": conv_to_output_hidden_irreps_out,
+            "precision": precision,
         }
         self.type_map = type_map
         self.ntypes = len(type_map)
@@ -178,6 +179,7 @@ def __init__(
                    "feature_irreps_hidden": feature_irreps_hidden,
                    "chemical_embedding_irreps_out": chemical_embedding_irreps_out,
                    "conv_to_output_hidden_irreps_out": conv_to_output_hidden_irreps_out,
+                    "model_dtype": precision,
                },
            ),
        )
@@ -694,32 +696,3 @@ def translated_output_def(self) -> dict[str, Any]:
     def model_output_def(self) -> ModelOutputDef:
         """Get the output def for the model."""
         return ModelOutputDef(self.fitting_output_def())
-
-    @classmethod
-    def get_model(cls, model_params: dict) -> "NequipModel":
-        """Get the model by the parameters.
-
-        Parameters
-        ----------
-        model_params : dict
-            The model parameters
-
-        Returns
-        -------
-        BaseBaseModel
-            The model
-        """
-        model_params_old = model_params.copy()
-        model_params = model_params.copy()
-        model_params.pop("type", None)
-        precision = model_params.pop("precision", "float32")
-        if precision == "float32":
-            torch.set_default_dtype(torch.float32)
-        elif precision == "float64":
-            torch.set_default_dtype(torch.float64)
-        else:
-            msg = f"precision {precision} not supported"
-            raise ValueError(msg)
-        model = cls(**model_params)
-        model.model_def_script = json.dumps(model_params_old)
-        return model
```
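The deleted `get_model` override is the substantive change here: it translated `precision` into a process-wide torch default dtype before constructing the model. Below is a sketch of that mapping in isolation; `set_default_precision` is a hypothetical helper name, and this diff does not show where the equivalent logic now lives, while the model itself now simply forwards `precision` to nequip's config as `model_dtype`:

```python
# Sketch of the precision handling the removed get_model performed.
# set_default_precision is a hypothetical name, not a function from
# this repository; the dtype table and the error message mirror the
# deleted code above.
import torch

_PRECISION_TO_DTYPE = {
    "float32": torch.float32,
    "float64": torch.float64,
}


def set_default_precision(precision: str) -> None:
    """Set torch's default dtype from a precision string,
    rejecting anything other than float32 or float64."""
    try:
        torch.set_default_dtype(_PRECISION_TO_DTYPE[precision])
    except KeyError as e:
        msg = f"precision {precision} not supported"
        raise ValueError(msg) from e
```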
tests/test_model.py (2 additions, 0 deletions)

```diff
@@ -1065,6 +1065,7 @@ def setUpClass(cls) -> None:
         cls.module = MaceModel(
             type_map=cls.expected_type_map,
             sel=138,
+            precision="float64",
         )
         with torch.jit.optimized_execution(should_optimize=False):
             cls._script_module = torch.jit.script(cls.module)
@@ -1105,6 +1106,7 @@ def setUpClass(cls) -> None:
             sel=138,
             r_max=cls.expected_rcut,
             num_layers=2,
+            precision="float64",
         )
         with torch.jit.optimized_execution(should_optimize=False):
             cls._script_module = torch.jit.script(cls.module)
```
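For context, a hedged usage sketch of what the updated tests now exercise: building a model in double precision and compiling it with TorchScript. The constructor values below are illustrative placeholders for the suite's class-level fixtures (`expected_type_map`, `expected_rcut`), not the actual test data:

```python
# Illustrative only: mirrors the test setup above with placeholder
# values standing in for the class-level fixtures.
import torch

from deepmd_gnn.nequip import NequipModel

module = NequipModel(
    type_map=["O", "H"],  # placeholder species list
    sel=138,
    r_max=6.0,  # placeholder cutoff radius
    num_layers=2,
    precision="float64",
)
with torch.jit.optimized_execution(should_optimize=False):
    script_module = torch.jit.script(module)
```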
