Skip to content

Commit c461541

Browse files
committed
Make WaterModel run on GPU
1 parent 8e0f5e8 commit c461541

File tree

1 file changed

+23
-31
lines changed

1 file changed

+23
-31
lines changed

examples/water-model/water-model.py

+23-31
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,8 @@ def lennard_jones_pair(
463463
# the system. Thanks to the fact we rely on ``torch`` autodifferentiation mechanism, the
464464
# forces acting on the virtual sites will be automatically split between O and H atoms,
465465
# in a way that is consistent with the definition.
466-
467-
# forces acting on the M sites will be automatically split between O and H atoms, in a
466+
#
467+
# Forces acting on the M sites will be automatically split between O and H atoms, in a
468468
# way that is consistent with the definition.
469469

470470

@@ -583,7 +583,8 @@ def get_molecular_geometry(
583583
)
584584

585585
tensor = TensorMap(
586-
keys=Labels("_", torch.zeros(1, 1, dtype=torch.int32)), blocks=[data]
586+
keys=Labels("_", torch.zeros(1, 1, device=charges.device, dtype=torch.int32)),
587+
blocks=[data],
587588
)
588589

589590
m_system.add_data(name="charges", tensor=tensor)
@@ -637,8 +638,6 @@ def __init__(
637638
hoh_angle_eq: float,
638639
hoh_angle_k: float,
639640
p3m_options: Optional[dict] = None,
640-
dtype: Optional[torch.dtype] = None,
641-
device: Optional[torch.device] = None,
642641
):
643642
super().__init__()
644643

@@ -653,34 +652,24 @@ def __init__(
653652
**p3m_parameters,
654653
prefactor=torchpme.prefactors.kcalmol_A, # consistent units
655654
)
656-
self.p3m_calculator.to(device=device, dtype=dtype)
657655

658656
self.coulomb = torchpme.CoulombPotential()
659-
self.coulomb.to(device=device, dtype=dtype)
660657

661658
# We use a half neighborlist and allow to have pairs farther than cutoff
662659
# (`strict=False`) since this is not problematic for PME and may speed up the
663660
# computation of the neigbors.
664661
self.nlo = NeighborListOptions(cutoff=cutoff, full_list=False, strict=False)
665662

666-
# registers model parameters as buffers
667-
self.dtype = dtype if dtype is not None else torch.get_default_dtype()
668-
self.device = device if device is not None else torch.get_default_device()
669-
670-
self.register_buffer("cutoff", torch.tensor(cutoff, dtype=self.dtype))
671-
self.register_buffer("lj_sigma", torch.tensor(lj_sigma, dtype=self.dtype))
672-
self.register_buffer("lj_epsilon", torch.tensor(lj_epsilon, dtype=self.dtype))
673-
self.register_buffer("m_gamma", torch.tensor(m_gamma, dtype=self.dtype))
674-
self.register_buffer("m_charge", torch.tensor(m_charge, dtype=self.dtype))
675-
self.register_buffer("oh_bond_eq", torch.tensor(oh_bond_eq, dtype=self.dtype))
676-
self.register_buffer("oh_bond_k", torch.tensor(oh_bond_k, dtype=self.dtype))
677-
self.register_buffer(
678-
"oh_bond_alpha", torch.tensor(oh_bond_alpha, dtype=self.dtype)
679-
)
680-
self.register_buffer(
681-
"hoh_angle_eq", torch.tensor(hoh_angle_eq, dtype=self.dtype)
682-
)
683-
self.register_buffer("hoh_angle_k", torch.tensor(hoh_angle_k, dtype=self.dtype))
663+
self.register_buffer("cutoff", torch.tensor(cutoff))
664+
self.register_buffer("lj_sigma", torch.tensor(lj_sigma))
665+
self.register_buffer("lj_epsilon", torch.tensor(lj_epsilon))
666+
self.register_buffer("m_gamma", torch.tensor(m_gamma))
667+
self.register_buffer("m_charge", torch.tensor(m_charge))
668+
self.register_buffer("oh_bond_eq", torch.tensor(oh_bond_eq))
669+
self.register_buffer("oh_bond_k", torch.tensor(oh_bond_k))
670+
self.register_buffer("oh_bond_alpha", torch.tensor(oh_bond_alpha))
671+
self.register_buffer("hoh_angle_eq", torch.tensor(hoh_angle_eq))
672+
self.register_buffer("hoh_angle_k", torch.tensor(hoh_angle_k))
684673

685674
def requested_neighbor_lists(self):
686675
"""Returns the list of neighbor list options that are needed."""
@@ -778,25 +767,28 @@ def forward(
778767

779768
# Rename property label to follow metatensor's convention for an atomistic model
780769
samples = Labels(
781-
["system"], torch.arange(len(systems), device=self.device).reshape(-1, 1)
770+
["system"],
771+
torch.arange(len(systems), device=energy_tot.device).reshape(-1, 1),
782772
)
773+
properties = Labels(["energy"], torch.tensor([[0]], device=energy_tot.device))
774+
783775
block = TensorBlock(
784776
values=torch.sum(energy_tot).reshape(-1, 1),
785777
samples=samples,
786778
components=torch.jit.annotate(List[Labels], []),
787-
properties=Labels(["energy"], torch.tensor([[0]], device=self.device)),
779+
properties=properties,
788780
)
789781
return {
790782
"energy": TensorMap(
791-
Labels("_", torch.tensor([[0]], device=self.device)), [block]
783+
Labels("_", torch.tensor([[0]], device=energy_tot.device)), [block]
792784
),
793785
}
794786

795787

796788
# %%
797789
#
798790
# All this class does is take a ``System`` and return its energy (as a
799-
# :clas:`metatensor.TensorMap``).
791+
# :class:`metatensor.TensorMap``).
800792

801793
qtip4pf_parameters = dict(
802794
cutoff=7.0,
@@ -861,7 +853,7 @@ def forward(
861853
# Model options include a definition of its units, and a description of the quantities
862854
# it can compute.
863855
#
864-
# .. note::
856+
# .. note::
865857
#
866858
# We neeed to specify that the model has infinite interaction range because of the
867859
# presence of a long-range term that means one cannot assume that forces decay to zero
@@ -876,7 +868,7 @@ def forward(
876868
atomic_types=[1, 8],
877869
interaction_range=torch.inf,
878870
length_unit=length_unit,
879-
supported_devices=["cpu", "cuda"],
871+
supported_devices=["cuda", "cpu"],
880872
dtype="float32",
881873
)
882874

0 commit comments

Comments
 (0)