Skip to content

Commit

Permalink
feat: nequip model in batch
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 7, 2024
1 parent f4c6b22 commit 87f687d
Showing 1 changed file with 100 additions and 100 deletions.
200 changes: 100 additions & 100 deletions deepmd_gnn/nequip.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
build_neighbor_list,
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.region import (
phys2inter,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
Expand Down Expand Up @@ -487,111 +484,114 @@ def forward_lower_common(
extended_atype = extended_atype.to(torch.int64)
nall = extended_coord.shape[1]

# loop on nf
energies = []
forces = []
virials = []
atom_energies = []
atomic_virials = []
for ff in range(nf):
extended_coord_ff = extended_coord[ff]
extended_atype_ff = extended_atype[ff]
nlist_ff = nlist[ff]
edge_index = torch.ops.deepmd_gnn.edge_index(
nlist_ff,
extended_atype_ff,
torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
)
edge_index = edge_index.T
# Nequip and MACE have different defination for edge_index
edge_index = edge_index[[1, 0]]

# nequip can convert dtype by itself
default_dtype = torch.float64
extended_coord_ff = extended_coord_ff.to(default_dtype)
extended_coord_ff.requires_grad_(True) # noqa: FBT003

input_dict = {
"pos": extended_coord_ff,
"edge_index": edge_index,
"atom_types": extended_atype_ff,
}
if box is not None and mapping is not None:
# pass box, map edge index to real
box_ff = box[ff].to(extended_coord_ff.device)
input_dict["cell"] = box_ff
input_dict["pbc"] = torch.zeros(
3,
dtype=torch.bool,
device=box_ff.device,
)
shifts_atoms = extended_coord_ff - extended_coord_ff[mapping[ff]]
shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]]
edge_index = mapping[ff][edge_index]
input_dict["edge_index"] = edge_index
edge_cell_shift = phys2inter(shifts, box_ff.view(3, 3))
input_dict["edge_cell_shift"] = edge_cell_shift

ret = self.model.forward(
input_dict,
)

atom_energy = ret["atomic_energy"]
if atom_energy is None:
msg = "atom_energy is None"
raise ValueError(msg)
atom_energy = atom_energy.view(1, nall).to(extended_coord_.dtype)[:, :nloc]
# adds e0
atom_energy = atom_energy + self.e0[extended_atype_ff[:nloc]].view(
1,
nloc,
).to(
atom_energy.dtype,
# fake as one frame
extended_coord_ff = extended_coord.view(nf * nall, 3)
extended_atype_ff = extended_atype.view(nf * nall)
edge_index = torch.ops.deepmd_gnn.edge_index(
nlist,
extended_atype,
torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
)
edge_index = edge_index.T
# Nequip and MACE have different defination for edge_index
edge_index = edge_index[[1, 0]]

# nequip can convert dtype by itself
default_dtype = torch.float64
extended_coord_ff = extended_coord_ff.to(default_dtype)
extended_coord_ff.requires_grad_(True) # noqa: FBT003

input_dict = {
"pos": extended_coord_ff,
"edge_index": edge_index,
"atom_types": extended_atype_ff,
}
if box is not None and mapping is not None:
# pass box, map edge index to real
box_ff = box.to(extended_coord_ff.device)
input_dict["cell"] = box_ff
input_dict["pbc"] = torch.zeros(
3,
dtype=torch.bool,
device=box_ff.device,
)
energy = torch.sum(atom_energy, dim=1).view(1, 1).to(extended_coord_.dtype)
grad_outputs: list[Optional[torch.Tensor]] = [
torch.ones_like(energy),
]
force = torch.autograd.grad(
outputs=[energy],
inputs=[extended_coord_ff],
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=self.training,
)[0]
if force is None:
msg = "force is None"
raise ValueError(msg)
force = -force
atomic_virial = force.unsqueeze(-1).to(
extended_coord_.dtype,
) @ extended_coord_ff.unsqueeze(-2).to(
extended_coord_.dtype,
batch = torch.arange(nf, device=box_ff.device).repeat(nall)
input_dict["batch"] = batch
ptr = torch.arange(
start=0,
end=nf * nall + 1,
step=nall,
dtype=torch.int64,
device=batch.device,
)
force = force.view(1, nall, 3).to(extended_coord_.dtype)
virial = (
torch.sum(atomic_virial, dim=0).view(1, 9).to(extended_coord_.dtype)
input_dict["ptr"] = ptr
mapping_ff = mapping.view(nf * nall) + torch.arange(
0,
nf * nall,
nall,
dtype=mapping.dtype,
device=mapping.device,
).unsqueeze(-1).expand(nf, nall).reshape(-1)
shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff]
shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]]
edge_index = mapping_ff[edge_index]
input_dict["edge_index"] = edge_index
rec_cell, _ = torch.linalg.inv_ex(box.view(nf, 3, 3))
edge_cell_shift = torch.einsum(
"ni,nij->nj",
shifts,
rec_cell[batch[edge_index[0]]],
)
input_dict["edge_cell_shift"] = edge_cell_shift

ret = self.model.forward(
input_dict,
)

energies.append(energy)
forces.append(force)
virials.append(virial)
atom_energies.append(atom_energy)
atomic_virials.append(atomic_virial)
energies_t = torch.cat(energies, dim=0)
forces_t = torch.cat(forces, dim=0)
virials_t = torch.cat(virials, dim=0)
atom_energies_t = torch.cat(atom_energies, dim=0)
atomic_virials_t = torch.cat(atomic_virials, dim=0)
atom_energy = ret["atomic_energy"]
if atom_energy is None:
msg = "atom_energy is None"
raise ValueError(msg)
atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc]
# adds e0
atom_energy = atom_energy + self.e0[extended_atype[:, :nloc]].view(
nf,
nloc,
).to(
atom_energy.dtype,
)
energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype)
grad_outputs: list[Optional[torch.Tensor]] = [
torch.ones_like(energy),
]
force = torch.autograd.grad(
outputs=[energy],
inputs=[extended_coord_ff],
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=self.training,
)[0]
if force is None:
msg = "force is None"
raise ValueError(msg)
force = -force
atomic_virial = force.unsqueeze(-1).to(
extended_coord_.dtype,
) @ extended_coord_ff.unsqueeze(-2).to(
extended_coord_.dtype,
)
force = force.view(nf, nall, 3).to(extended_coord_.dtype)
atomic_virial = atomic_virial.view(nf, nall, 1, 9)
virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype)

return {
"energy_redu": energies_t.view(nf, 1),
"energy_derv_r": forces_t.view(nf, nall, 1, 3),
"energy_derv_c_redu": virials_t.view(nf, 1, 9),
"energy_redu": energy.view(nf, 1),
"energy_derv_r": force.view(nf, nall, 1, 3),
"energy_derv_c_redu": virial.view(nf, 1, 9),
# take the first nloc atoms to match other models
"energy": atom_energies_t.view(nf, nloc, 1),
"energy": atom_energy.view(nf, nloc, 1),
# fake atom_virial
"energy_derv_c": atomic_virials_t.view(nf, nall, 1, 9),
"energy_derv_c": atomic_virial.view(nf, nall, 1, 9),
}

def serialize(self) -> dict:
Expand Down

0 comments on commit 87f687d

Please sign in to comment.