fix(pt/dp): make strip more efficient #4400

Merged · 7 commits · Nov 23, 2024
Changes from all commits
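Summary of the change, as read from the diff: in `strip` mode the strip embedding network previously ran on per-neighbor inputs of shape `(nf * nloc, nnei, 2 * tebd_dim)`, although its output depends only on the (center type, neighbor type) pair. The descriptors now evaluate the type-embedding table once per forward pass (`self.type_embedding.call()` in the dpmodel backend, `get_full_embedding` in the PyTorch backend), push the full `ntypes_with_padding x ntypes_with_padding` table of pair embeddings through the strip network a single time, and gather the resulting rows per neighbor with `xp_take_along_axis`. The numpy sketch below illustrates why the gather is equivalent; the `mlp` stand-in and all concrete shapes are assumptions for illustration, not the project's API.

```python
# Minimal numpy sketch (hypothetical names and shapes): the strip network is
# row-wise, so running it once per (center, neighbor) type pair and gathering
# gives the same result as running it on every neighbor.
import numpy as np

ntypes, nt, ng = 4, 8, 16      # padded type count, tebd_dim, last neuron
nf, nloc, nnei = 2, 32, 20
rng = np.random.default_rng(0)

type_embedding = rng.normal(size=(ntypes, nt))  # one row per atom type
W = rng.normal(size=(2 * nt, ng))               # stand-in for cal_g_strip

def mlp(x):
    return np.tanh(x @ W)

atype = rng.integers(0, ntypes, size=(nf, nloc))           # center types
nei_type = rng.integers(0, ntypes, size=(nf, nloc, nnei))  # neighbor types

# old path: one network evaluation per neighbor, nf * nloc * nnei rows
center = np.broadcast_to(
    type_embedding[atype][:, :, None, :], (nf, nloc, nnei, nt)
)
tt = np.concatenate([type_embedding[nei_type], center], axis=-1)
gg_old = mlp(tt.reshape(-1, 2 * nt))

# new path: one evaluation over all ntypes * ntypes pairs, then a gather
pair = np.concatenate(
    [
        np.tile(type_embedding[None, :, :], (ntypes, 1, 1)),  # neighbor part
        np.tile(type_embedding[:, None, :], (1, ntypes, 1)),  # center part
    ],
    axis=-1,
).reshape(-1, 2 * nt)
tt_full = mlp(pair)                                  # (ntypes * ntypes) x ng
idx = (atype[:, :, None] * ntypes + nei_type).reshape(-1)
gg_new = tt_full[idx]

assert np.allclose(gg_old, gg_new)
```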
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
@@ -123,6 +123,7 @@ def call(
extended_atype: np.ndarray,
extended_atype_embd: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
type_embedding: Optional[np.ndarray] = None,
):
"""Calculate DescriptorBlock."""
pass
83 changes: 61 additions & 22 deletions deepmd/dpmodel/descriptor/dpa1.py
@@ -494,9 +494,10 @@ def call(
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
nf, nloc, nnei = nlist.shape
nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
type_embedding = self.type_embedding.call()
# nf x nall x tebd_dim
atype_embd_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
(nf, nall, self.tebd_dim),
)
# nfnl x tebd_dim
@@ -507,6 +508,7 @@
atype_ext,
atype_embd_ext,
mapping=None,
type_embedding=type_embedding,
)
# nf x nloc x (ng x ng1 + tebd_dim)
if self.concat_output_tebd:
@@ -874,10 +876,6 @@ def cal_g_strip(
embedding_idx,
):
assert self.embeddings_strip is not None
xp = array_api_compat.array_namespace(ss)
nfnl, nnei = ss.shape[0:2]
shape2 = math.prod(ss.shape[2:])
ss = xp.reshape(ss, (nfnl, nnei, shape2))
# nfnl x nnei x ng
gg = self.embeddings_strip[embedding_idx].call(ss)
return gg
@@ -889,13 +887,15 @@ def call(
atype_ext: np.ndarray,
atype_embd_ext: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
type_embedding: Optional[np.ndarray] = None,
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
)
nf, nloc, nnei, _ = dmatrix.shape
atype = atype_ext[:, :nloc]
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# nfnl x nnei
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
@@ -906,28 +906,33 @@ def call(
dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
# nfnl x nnei x 1
sw = xp.reshape(sw, (nf * nloc, nnei, 1))
# nfnl x tebd_dim
atype_embd = xp.reshape(atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim))
# nfnl x nnei x tebd_dim
atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1))
# nfnl x nnei
nlist_mask = nlist != -1
# nfnl x nnei x 1
sw = xp.where(nlist_mask[:, :, None], sw, xp.full_like(sw, 0.0))
nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
ng = self.neuron[-1]
nt = self.tebd_dim
# nfnl x nnei x 4
rr = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype)
# nfnl x nnei x 1
ss = rr[..., 0:1]
if self.tebd_input_mode in ["concat"]:
# nfnl x tebd_dim
atype_embd = xp.reshape(
atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim)
)
# nfnl x nnei x tebd_dim
atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1))
index = xp.tile(
xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)
)
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
if not self.type_one_side:
# nfnl x nnei x (1 + 2 * tebd_dim)
ss = xp.concat([ss, atype_embd_nlist, atype_embd_nnei], axis=-1)
@@ -941,14 +946,48 @@
# nfnl x nnei x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
if not self.type_one_side:
# nfnl x nnei x (tebd_dim * 2)
tt = xp.concat([atype_embd_nlist, atype_embd_nnei], axis=-1)
assert type_embedding is not None
ntypes_with_padding = type_embedding.shape[0]
# nf x (nl x nnei)
nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei))
# nf x (nl x nnei)
nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1)
# (nf x nl x nnei) x ng
nei_type_index = xp.tile(xp.reshape(nei_type, (-1, 1)), (1, ng))
if self.type_one_side:
tt_full = self.cal_g_strip(type_embedding, 0)
# (nf x nl x nnei) x ng
gg_t = xp_take_along_axis(tt_full, nei_type_index, axis=0)
else:
# nfnl x nnei x tebd_dim
tt = atype_embd_nlist
# nfnl x nnei x ng
gg_t = self.cal_g_strip(tt, 0)
idx_i = xp.reshape(
xp.tile(
(xp.reshape(atype, (-1, 1)) * ntypes_with_padding), (1, nnei)
),
(-1),
)
idx_j = xp.reshape(nei_type, (-1,))
# (nf x nl x nnei) x ng
idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng))
# (ntypes) * ntypes * nt
type_embedding_nei = xp.tile(
xp.reshape(type_embedding, (1, ntypes_with_padding, nt)),
(ntypes_with_padding, 1, 1),
)
# ntypes * (ntypes) * nt
type_embedding_center = xp.tile(
xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)),
(1, ntypes_with_padding, 1),
)
# (ntypes * ntypes) * (nt+nt)
two_side_type_embedding = xp.reshape(
xp.concat([type_embedding_nei, type_embedding_center], axis=-1),
(-1, nt * 2),
)
tt_full = self.cal_g_strip(two_side_type_embedding, 0)
# (nf x nl x nnei) x ng
gg_t = xp_take_along_axis(tt_full, idx, axis=0)
# (nf x nl) x nnei x ng
gg_t = xp.reshape(gg_t, (nf * nloc, nnei, ng))
if self.smooth:
gg_t = gg_t * xp.reshape(sw, (-1, self.nnei, 1))
# nfnl x nnei x ng
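The `type_one_side` branch in the hunk above is even cheaper: the strip network sees only the neighbor embedding, so the precomputed table has just `ntypes_with_padding` rows and the per-neighbor work reduces to a pure gather. A hedged numpy sketch; the `mlp` stand-in and shapes are assumptions:

```python
import numpy as np

ntypes, nt, ng, nfnl, nnei = 4, 8, 16, 6, 5
rng = np.random.default_rng(1)
type_embedding = rng.normal(size=(ntypes, nt))
W = rng.normal(size=(nt, ng))

def mlp(x):  # stand-in for self.cal_g_strip(x, 0)
    return np.tanh(x @ W)

nei_type = rng.integers(0, ntypes, size=(nfnl, nnei))

tt_full = mlp(type_embedding)         # ntypes x ng, computed once
gg_t = tt_full[nei_type.reshape(-1)]  # (nfnl * nnei) x ng, pure gather
assert np.allclose(gg_t, mlp(type_embedding[nei_type.reshape(-1)]))
```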
5 changes: 4 additions & 1 deletion deepmd/dpmodel/descriptor/dpa2.py
@@ -811,9 +811,10 @@ def call(
self.rcut_list,
self.nsel_list,
)
type_embedding = self.type_embedding.call()
# repinit
g1_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
(nframes, nall, self.tebd_dim),
)
g1_inp = g1_ext[:, :nloc, :]
@@ -825,6 +826,7 @@
atype_ext,
g1_ext,
mapping,
type_embedding=type_embedding,
)
if use_three_body:
assert self.repinit_three_body is not None
@@ -839,6 +841,7 @@
atype_ext,
g1_ext,
mapping,
type_embedding=type_embedding,
)
g1 = xp.concat([g1, g1_three_body], axis=-1)
# linear to change shape
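In `DescrptDPA2.call` the table is additionally hoisted out of the two `repinit` blocks, so it is built once per forward pass and shared. A toy sketch of the hoisting pattern; the classes here are stand-ins, not the project's API:

```python
import numpy as np

class TypeEmbedding:
    """Stand-in for the network whose .call() builds the embedding table."""
    def __init__(self, ntypes: int, tebd_dim: int) -> None:
        self.table = np.ones((ntypes, tebd_dim))
    def call(self) -> np.ndarray:
        return self.table.copy()  # imagine this being expensive

class Block:
    def call(self, type_embedding: np.ndarray) -> float:
        return float(type_embedding.sum())  # stand-in descriptor math

te = TypeEmbedding(4, 8)
type_embedding = te.call()  # evaluated once per forward pass ...
outputs = [Block().call(type_embedding) for _ in range(2)]  # ... then shared
```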
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/repformers.py
@@ -389,6 +389,7 @@ def call(
atype_ext: np.ndarray,
atype_embd_ext: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
type_embedding: Optional[np.ndarray] = None,
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
77 changes: 62 additions & 15 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
@@ -332,9 +332,10 @@ def call(
del mapping
nf, nloc, nnei = nlist.shape
nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
type_embedding = self.type_embedding.call()
# nf x nall x tebd_dim
atype_embd_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
(nf, nall, self.tebd_dim),
)
# nfnl x tebd_dim
@@ -345,6 +346,7 @@
atype_ext,
atype_embd_ext,
mapping=None,
type_embedding=type_embedding,
)
# nf x nloc x (ng + tebd_dim)
if self.concat_output_tebd:
@@ -667,6 +669,7 @@ def call(
atype_ext: np.ndarray,
atype_embd_ext: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
type_embedding: Optional[np.ndarray] = None,
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
# nf x nloc x nnei x 4
@@ -703,20 +706,26 @@ def call(
env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
# nfnl x nt_i x nt_j x 1
ss = env_ij[..., None]

nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
# nfnl x nt_i x nt_j x tebd_dim
nlist_tebd_i = xp.tile(atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1))
nlist_tebd_j = xp.tile(atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1))
ng = self.neuron[-1]
nt = self.tebd_dim

if self.tebd_input_mode in ["concat"]:
index = xp.tile(
xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)
)
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
# nfnl x nt_i x nt_j x tebd_dim
nlist_tebd_i = xp.tile(
atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1)
)
nlist_tebd_j = xp.tile(
atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1)
)
# nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
ss = xp.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
# nfnl x nt_i x nt_j x ng
@@ -725,10 +734,48 @@
# nfnl x nt_i x nt_j x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
# nfnl x nt_i x nt_j x (tebd_dim * 2)
tt = xp.concat([nlist_tebd_i, nlist_tebd_j], axis=-1)
# nfnl x nt_i x nt_j x ng
gg_t = self.cal_g_strip(tt, 0)
assert type_embedding is not None
ntypes_with_padding = type_embedding.shape[0]
# nf x (nl x nnei)
nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei))
# nf x (nl x nnei)
nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1)
# nfnl x nnei
nei_type = xp.reshape(nei_type, (nf * nloc, nnei))

# nfnl x nnei x nnei
nei_type_i = xp.tile(nei_type[:, :, np.newaxis], (1, 1, nnei))
nei_type_j = xp.tile(nei_type[:, np.newaxis, :], (1, nnei, 1))

idx_i = nei_type_i * ntypes_with_padding
idx_j = nei_type_j

# (nf x nl x nt_i x nt_j) x ng
idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng))

# ntypes * (ntypes) * nt
type_embedding_i = xp.tile(
xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)),
(1, ntypes_with_padding, 1),
)

# (ntypes) * ntypes * nt
type_embedding_j = xp.tile(
xp.reshape(type_embedding, (1, ntypes_with_padding, nt)),
(ntypes_with_padding, 1, 1),
)

# (ntypes * ntypes) * (nt+nt)
two_side_type_embedding = xp.reshape(
xp.concat([type_embedding_i, type_embedding_j], axis=-1), (-1, nt * 2)
)
tt_full = self.cal_g_strip(two_side_type_embedding, 0)

# (nfnl x nt_i x nt_j) x ng
gg_t = xp_take_along_axis(tt_full, idx, axis=0)

# (nfnl x nt_i x nt_j) x ng
gg_t = xp.reshape(gg_t, (nf * nloc, nnei, nnei, ng))
if self.smooth:
gg_t = (
gg_t
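For the three-body `se_t_tebd` block both indices of the pair are neighbor types, so the flattened row id is `nei_type_i * ntypes_with_padding + nei_type_j` over every `(i, j)` neighbor pair. A minimal numpy sketch under the same assumptions as the earlier ones:

```python
import numpy as np

ntypes, nt, ng, nfnl, nnei = 4, 8, 16, 6, 5
rng = np.random.default_rng(2)
type_embedding = rng.normal(size=(ntypes, nt))
W = rng.normal(size=(2 * nt, ng))

def mlp(x):  # stand-in for self.cal_g_strip(x, 0)
    return np.tanh(x @ W)

nei_type = rng.integers(0, ntypes, size=(nfnl, nnei))

# all ntypes * ntypes pair embeddings, concatenated as [emb_i, emb_j]
pair = np.concatenate(
    [
        np.repeat(type_embedding, ntypes, axis=0),  # i varies slowly
        np.tile(type_embedding, (ntypes, 1)),       # j varies fast
    ],
    axis=-1,
)
tt_full = mlp(pair)  # (ntypes * ntypes) x ng

# nfnl x nnei x nnei row ids, then a single gather
idx = nei_type[:, :, None] * ntypes + nei_type[:, None, :]
gg_t = tt_full[idx.reshape(-1)].reshape(nfnl, nnei, nnei, ng)
```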
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/descriptor.py
@@ -174,6 +174,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
type_embedding: Optional[torch.Tensor] = None,
):
"""Calculate DescriptorBlock."""
pass
5 changes: 5 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -687,12 +687,17 @@ def forward(
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
if self.tebd_input_mode in ["strip"]:
type_embedding = self.type_embedding.get_full_embedding(g1_ext.device)
else:
type_embedding = None
g1, g2, h2, rot_mat, sw = self.se_atten(
nlist,
extended_coord,
extended_atype,
g1_ext,
mapping=None,
type_embedding=type_embedding,
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
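On the PyTorch side the cached table comes from `type_embedding.get_full_embedding(device)`, and it is only materialized when `tebd_input_mode == "strip"`, since the other input modes never index the full table. A hedged sketch of that gating, with a stub in place of the real `TypeEmbedNet`:

```python
from typing import Optional

import torch

class TypeEmbedNetStub:
    """Stub for illustration only; the real class lives in deepmd.pt."""
    def __init__(self, ntypes: int, tebd_dim: int) -> None:
        self.table = torch.randn(ntypes, tebd_dim)
    def get_full_embedding(self, device: torch.device) -> torch.Tensor:
        # full (ntypes_with_padding, tebd_dim) table on the target device
        return self.table.to(device)

def maybe_full_table(
    net: TypeEmbedNetStub, device: torch.device, tebd_input_mode: str
) -> Optional[torch.Tensor]:
    # only "strip" mode consumes the full table; other modes embed per atom
    if tebd_input_mode in ["strip"]:
        return net.get_full_embedding(device)
    return None

net = TypeEmbedNetStub(4, 8)
print(maybe_full_table(net, torch.device("cpu"), "strip").shape)  # (4, 8)
print(maybe_full_table(net, torch.device("cpu"), "concat"))       # None
```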
9 changes: 8 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
@@ -162,6 +162,7 @@ def init_subclass_params(sub_data, sub_class):

self.repinit_args = init_subclass_params(repinit, RepinitArgs)
self.repformer_args = init_subclass_params(repformer, RepformerArgs)
self.tebd_input_mode = self.repinit_args.tebd_input_mode

self.repinit = DescrptBlockSeAtten(
self.repinit_args.rcut,
@@ -765,6 +766,10 @@ def forward(
# repinit
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
if self.tebd_input_mode in ["strip"]:
type_embedding = self.type_embedding.get_full_embedding(g1_ext.device)
else:
type_embedding = None
g1, _, _, _, _ = self.repinit(
nlist_dict[
get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel())
@@ -773,6 +778,7 @@
extended_atype,
g1_ext,
mapping,
type_embedding,
)
if use_three_body:
assert self.repinit_three_body is not None
@@ -787,6 +793,7 @@
extended_atype,
g1_ext,
mapping,
type_embedding,
)
g1 = torch.cat([g1, g1_three_body], dim=-1)
# linear to change shape
@@ -813,7 +820,7 @@ def forward(
extended_atype,
g1,
mapping,
comm_dict,
comm_dict=comm_dict,
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
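The `comm_dict=comm_dict` change in the last hunk above is a correctness fix that the new parameter forces: as the repformers.py hunk below shows, `forward` gained `type_embedding` immediately before `comm_dict`, so a positional `comm_dict` would now bind to `type_embedding`. A minimal illustration with a toy signature, not the real one:

```python
def forward(nlist, coord, atype, g1, mapping=None,
            type_embedding=None, comm_dict=None):
    return type_embedding, comm_dict

# positionally, the dict lands in the newly inserted parameter:
assert forward(0, 0, 0, 0, None, {"send": 1}) == ({"send": 1}, None)
# with a keyword it binds where intended, whatever is inserted before it:
assert forward(0, 0, 0, 0, None, comm_dict={"send": 1}) == (None, {"send": 1})
```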
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/repformers.py
@@ -389,6 +389,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
type_embedding: Optional[torch.Tensor] = None,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
if comm_dict is None:
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/se_a.py
@@ -725,6 +725,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
type_embedding: Optional[torch.Tensor] = None,
):
"""Calculate decoded embedding for each atom.