diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py
index 2c6e8fee57..443a2a66f1 100644
--- a/deepmd/dpmodel/descriptor/descriptor.py
+++ b/deepmd/dpmodel/descriptor/descriptor.py
@@ -123,6 +123,7 @@ def call(
         extended_atype: np.ndarray,
         extended_atype_embd: Optional[np.ndarray] = None,
         mapping: Optional[np.ndarray] = None,
+        type_embedding: Optional[np.ndarray] = None,
     ):
         """Calculate DescriptorBlock."""
         pass
diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py
index 62ab2a5a9a..20a758b170 100644
--- a/deepmd/dpmodel/descriptor/dpa1.py
+++ b/deepmd/dpmodel/descriptor/dpa1.py
@@ -494,9 +494,10 @@ def call(
         xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
         nf, nloc, nnei = nlist.shape
         nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
+        type_embedding = self.type_embedding.call()
         # nf x nall x tebd_dim
         atype_embd_ext = xp.reshape(
-            xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
+            xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
             (nf, nall, self.tebd_dim),
         )
         # nfnl x tebd_dim
@@ -507,6 +508,7 @@ def call(
             atype_ext,
             atype_embd_ext,
             mapping=None,
+            type_embedding=type_embedding,
         )
         # nf x nloc x (ng x ng1 + tebd_dim)
         if self.concat_output_tebd:
@@ -874,10 +876,6 @@ def cal_g_strip(
         embedding_idx,
     ):
         assert self.embeddings_strip is not None
-        xp = array_api_compat.array_namespace(ss)
-        nfnl, nnei = ss.shape[0:2]
-        shape2 = math.prod(ss.shape[2:])
-        ss = xp.reshape(ss, (nfnl, nnei, shape2))
         # nfnl x nnei x ng
         gg = self.embeddings_strip[embedding_idx].call(ss)
         return gg
@@ -889,6 +887,7 @@ def call(
         atype_ext: np.ndarray,
         atype_embd_ext: Optional[np.ndarray] = None,
         mapping: Optional[np.ndarray] = None,
+        type_embedding: Optional[np.ndarray] = None,
     ):
         xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
         # nf x nloc x nnei x 4
@@ -896,6 +895,7 @@ def call(
             coord_ext, atype_ext, nlist, self.mean, self.stddev
         )
         nf, nloc, nnei, _ = dmatrix.shape
+        atype = atype_ext[:, :nloc]
         exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
         # nfnl x nnei
         exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
@@ -906,28 +906,33 @@ def call(
         dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
         # nfnl x nnei x 1
         sw = xp.reshape(sw, (nf * nloc, nnei, 1))
-        # nfnl x tebd_dim
-        atype_embd = xp.reshape(atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim))
-        # nfnl x nnei x tebd_dim
-        atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1))
         # nfnl x nnei
         nlist_mask = nlist != -1
         # nfnl x nnei x 1
         sw = xp.where(nlist_mask[:, :, None], sw, xp.full_like(sw, 0.0))
         nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
-        index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
-        # nfnl x nnei x tebd_dim
-        atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
-        atype_embd_nlist = xp.reshape(
-            atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
-        )
         ng = self.neuron[-1]
+        nt = self.tebd_dim
         # nfnl x nnei x 4
         rr = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
         rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype)
         # nfnl x nnei x 1
         ss = rr[..., 0:1]
         if self.tebd_input_mode in ["concat"]:
+            # nfnl x tebd_dim
+            atype_embd = xp.reshape(
+                atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim)
+            )
+            # nfnl x nnei x tebd_dim
+            atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1))
+            index = xp.tile(
+                xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)
+            )
+            # nfnl x nnei x tebd_dim
+            atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
+            atype_embd_nlist = xp.reshape(
+                atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
+            )
             if not self.type_one_side:
                 # nfnl x nnei x (1 + 2 * tebd_dim)
                 ss = xp.concat([ss, atype_embd_nlist, atype_embd_nnei], axis=-1)
@@ -941,14 +946,48 @@ def call(
             # nfnl x nnei x ng
             gg_s = self.cal_g(ss, 0)
             assert self.embeddings_strip is not None
-            if not self.type_one_side:
-                # nfnl x nnei x (tebd_dim * 2)
-                tt = xp.concat([atype_embd_nlist, atype_embd_nnei], axis=-1)
+            assert type_embedding is not None
+            ntypes_with_padding = type_embedding.shape[0]
+            # nf x (nl x nnei)
+            nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei))
+            # nf x (nl x nnei)
+            nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1)
+            # (nf x nl x nnei) x ng
+            nei_type_index = xp.tile(xp.reshape(nei_type, (-1, 1)), (1, ng))
+            if self.type_one_side:
+                tt_full = self.cal_g_strip(type_embedding, 0)
+                # (nf x nl x nnei) x ng
+                gg_t = xp_take_along_axis(tt_full, nei_type_index, axis=0)
             else:
-                # nfnl x nnei x tebd_dim
-                tt = atype_embd_nlist
-            # nfnl x nnei x ng
-            gg_t = self.cal_g_strip(tt, 0)
+                idx_i = xp.reshape(
+                    xp.tile(
+                        (xp.reshape(atype, (-1, 1)) * ntypes_with_padding), (1, nnei)
+                    ),
+                    (-1),
+                )
+                idx_j = xp.reshape(nei_type, (-1,))
+                # (nf x nl x nnei) x ng
+                idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng))
+                # (ntypes) * ntypes * nt
+                type_embedding_nei = xp.tile(
+                    xp.reshape(type_embedding, (1, ntypes_with_padding, nt)),
+                    (ntypes_with_padding, 1, 1),
+                )
+                # ntypes * (ntypes) * nt
+                type_embedding_center = xp.tile(
+                    xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)),
+                    (1, ntypes_with_padding, 1),
+                )
+                # (ntypes * ntypes) * (nt+nt)
+                two_side_type_embedding = xp.reshape(
+                    xp.concat([type_embedding_nei, type_embedding_center], axis=-1),
+                    (-1, nt * 2),
+                )
+                tt_full = self.cal_g_strip(two_side_type_embedding, 0)
+                # (nf x nl x nnei) x ng
+                gg_t = xp_take_along_axis(tt_full, idx, axis=0)
+            # (nf x nl) x nnei x ng
+            gg_t = xp.reshape(gg_t, (nf * nloc, nnei, ng))
             if self.smooth:
                 gg_t = gg_t * xp.reshape(sw, (-1, self.nnei, 1))
             # nfnl x nnei x ng
diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py
index eb6bfa4766..e4cadb7b36 100644
--- a/deepmd/dpmodel/descriptor/dpa2.py
+++ b/deepmd/dpmodel/descriptor/dpa2.py
@@ -811,9 +811,10 @@ def call(
             self.rcut_list,
             self.nsel_list,
         )
+        type_embedding = self.type_embedding.call()
         # repinit
         g1_ext = xp.reshape(
-            xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
+            xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
             (nframes, nall, self.tebd_dim),
         )
         g1_inp = g1_ext[:, :nloc, :]
@@ -825,6 +826,7 @@ def call(
             atype_ext,
             g1_ext,
             mapping,
+            type_embedding=type_embedding,
         )
         if use_three_body:
             assert self.repinit_three_body is not None
@@ -839,6 +841,7 @@ def call(
                 atype_ext,
                 g1_ext,
                 mapping,
+                type_embedding=type_embedding,
             )
             g1 = xp.concat([g1, g1_three_body], axis=-1)
         # linear to change shape
diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py
index 34d9d8f6bc..ae6b5de511 100644
--- a/deepmd/dpmodel/descriptor/repformers.py
+++ b/deepmd/dpmodel/descriptor/repformers.py
@@ -389,6 +389,7 @@ def call(
         atype_ext: np.ndarray,
         atype_embd_ext: Optional[np.ndarray] = None,
         mapping: Optional[np.ndarray] = None,
+        type_embedding: Optional[np.ndarray] = None,
     ):
         xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
         exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
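The strip branches added above replace per-neighbor evaluations of the strip network with a single evaluation over all (center type, neighbor type) pairs, followed by a row gather using the flattened pair index `center * ntypes_with_padding + neighbor`. The sketch below illustrates that index arithmetic with plain NumPy; `toy_strip_net`, the weight matrix `W`, and the toy sizes are hypothetical stand-ins for `cal_g_strip` and the real dimensions, not part of this patch.

import numpy as np

rng = np.random.default_rng(0)
ntypes_with_padding, nt, ng = 3, 4, 5  # types incl. padding, tebd dim, output dim
nf, nloc, nnei = 1, 2, 3               # frames, local atoms, neighbors

type_embedding = rng.normal(size=(ntypes_with_padding, nt))
atype = rng.integers(0, ntypes_with_padding, size=(nf, nloc))           # center types
nei_type = rng.integers(0, ntypes_with_padding, size=(nf, nloc, nnei))  # neighbor types

W = rng.normal(size=(2 * nt, ng))


def toy_strip_net(x):
    """Stand-in for the strip embedding network (cal_g_strip)."""
    return np.tanh(x @ W)


# Evaluate the network once per (center, neighbor) type pair: row i * N + j
# holds the pair (center type i, neighbor type j), fed as [emb(j), emb(i)].
emb_nei = np.tile(type_embedding[None, :, :], (ntypes_with_padding, 1, 1))
emb_center = np.tile(type_embedding[:, None, :], (1, ntypes_with_padding, 1))
tt_full = toy_strip_net(
    np.concatenate([emb_nei, emb_center], axis=-1).reshape(-1, 2 * nt)
)  # (ntypes_with_padding**2) x ng

# Gather per-neighbor rows with the flattened pair index.
idx = (atype[..., None] * ntypes_with_padding + nei_type).reshape(-1)
gg_t = tt_full[idx].reshape(nf * nloc, nnei, ng)

# Reference: run the network on every neighbor's pair embedding directly.
center_emb = np.broadcast_to(
    type_embedding[atype][:, :, None, :], (nf, nloc, nnei, nt)
)
pair_in = np.concatenate([type_embedding[nei_type], center_emb], axis=-1)
ref = toy_strip_net(pair_in.reshape(-1, 2 * nt)).reshape(nf * nloc, nnei, ng)
assert np.allclose(gg_t, ref)

The saving comes from evaluating the network on `ntypes_with_padding ** 2` rows instead of once per neighbor, which is the point of building `tt_full` and gathering from it in the code above.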
diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py
index 1efa991047..c350e3eb47 100644
--- a/deepmd/dpmodel/descriptor/se_t_tebd.py
+++ b/deepmd/dpmodel/descriptor/se_t_tebd.py
@@ -332,9 +332,10 @@ def call(
         del mapping
         nf, nloc, nnei = nlist.shape
         nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
+        type_embedding = self.type_embedding.call()
         # nf x nall x tebd_dim
         atype_embd_ext = xp.reshape(
-            xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
+            xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
             (nf, nall, self.tebd_dim),
         )
         # nfnl x tebd_dim
@@ -345,6 +346,7 @@ def call(
             atype_ext,
             atype_embd_ext,
             mapping=None,
+            type_embedding=type_embedding,
         )
         # nf x nloc x (ng + tebd_dim)
         if self.concat_output_tebd:
@@ -667,6 +669,7 @@ def call(
         atype_ext: np.ndarray,
         atype_embd_ext: Optional[np.ndarray] = None,
         mapping: Optional[np.ndarray] = None,
+        type_embedding: Optional[np.ndarray] = None,
     ):
         xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
         # nf x nloc x nnei x 4
@@ -703,20 +706,26 @@ def call(
         env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
         # nfnl x nt_i x nt_j x 1
         ss = env_ij[..., None]
-        nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
-        index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
-        # nfnl x nnei x tebd_dim
-        atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
-        atype_embd_nlist = xp.reshape(
-            atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
-        )
-        # nfnl x nt_i x nt_j x tebd_dim
-        nlist_tebd_i = xp.tile(atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1))
-        nlist_tebd_j = xp.tile(atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1))
         ng = self.neuron[-1]
+        nt = self.tebd_dim
         if self.tebd_input_mode in ["concat"]:
+            index = xp.tile(
+                xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)
+            )
+            # nfnl x nnei x tebd_dim
+            atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
+            atype_embd_nlist = xp.reshape(
+                atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
+            )
+            # nfnl x nt_i x nt_j x tebd_dim
+            nlist_tebd_i = xp.tile(
+                atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1)
+            )
+            nlist_tebd_j = xp.tile(
+                atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1)
+            )
             # nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
             ss = xp.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
             # nfnl x nt_i x nt_j x ng
@@ -725,10 +734,48 @@ def call(
             # nfnl x nt_i x nt_j x ng
             gg_s = self.cal_g(ss, 0)
             assert self.embeddings_strip is not None
-            # nfnl x nt_i x nt_j x (tebd_dim * 2)
-            tt = xp.concat([nlist_tebd_i, nlist_tebd_j], axis=-1)
-            # nfnl x nt_i x nt_j x ng
-            gg_t = self.cal_g_strip(tt, 0)
+            assert type_embedding is not None
+            ntypes_with_padding = type_embedding.shape[0]
+            # nf x (nl x nnei)
+            nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei))
+            # nf x (nl x nnei)
+            nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1)
+            # nfnl x nnei
+            nei_type = xp.reshape(nei_type, (nf * nloc, nnei))
+
+            # nfnl x nnei x nnei
+            nei_type_i = xp.tile(nei_type[:, :, np.newaxis], (1, 1, nnei))
+            nei_type_j = xp.tile(nei_type[:, np.newaxis, :], (1, nnei, 1))
+
+            idx_i = nei_type_i * ntypes_with_padding
+            idx_j = nei_type_j
+
+            # (nf x nl x nt_i x nt_j) x ng
+            idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng))
+
+            # ntypes * (ntypes) * nt
+            type_embedding_i = xp.tile(
+                xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)),
+                (1, ntypes_with_padding, 1),
+            )
+
+            # (ntypes) * ntypes * nt
+            type_embedding_j = xp.tile(
+                xp.reshape(type_embedding, (1, ntypes_with_padding, nt)),
+                (ntypes_with_padding, 1, 1),
+            )
+
+            # (ntypes * ntypes) * (nt+nt)
+            two_side_type_embedding = xp.reshape(
+                xp.concat([type_embedding_i, type_embedding_j], axis=-1), (-1, nt * 2)
+            )
+            tt_full = self.cal_g_strip(two_side_type_embedding, 0)
+
+            # (nfnl x nt_i x nt_j) x ng
+            gg_t = xp_take_along_axis(tt_full, idx, axis=0)
+
+            # (nfnl x nt_i x nt_j) x ng
+            gg_t = xp.reshape(gg_t, (nf * nloc, nnei, nnei, ng))
             if self.smooth:
                 gg_t = (
                     gg_t
diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py
index 3c738d59e3..d97f8964c4 100644
--- a/deepmd/pt/model/descriptor/descriptor.py
+++ b/deepmd/pt/model/descriptor/descriptor.py
@@ -174,6 +174,7 @@ def forward(
         extended_atype: torch.Tensor,
         extended_atype_embd: Optional[torch.Tensor] = None,
         mapping: Optional[torch.Tensor] = None,
+        type_embedding: Optional[torch.Tensor] = None,
     ):
         """Calculate DescriptorBlock."""
         pass
diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py
index fcf12be79c..ba2fd1b6c6 100644
--- a/deepmd/pt/model/descriptor/dpa1.py
+++ b/deepmd/pt/model/descriptor/dpa1.py
@@ -687,12 +687,17 @@ def forward(
         nall = extended_coord.view(nframes, -1).shape[1] // 3
         g1_ext = self.type_embedding(extended_atype)
         g1_inp = g1_ext[:, :nloc, :]
+        if self.tebd_input_mode in ["strip"]:
+            type_embedding = self.type_embedding.get_full_embedding(g1_ext.device)
+        else:
+            type_embedding = None
         g1, g2, h2, rot_mat, sw = self.se_atten(
             nlist,
             extended_coord,
             extended_atype,
             g1_ext,
             mapping=None,
+            type_embedding=type_embedding,
         )
         if self.concat_output_tebd:
             g1 = torch.cat([g1, g1_inp], dim=-1)
diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py
index b6c3c5bb4b..81918628a6 100644
--- a/deepmd/pt/model/descriptor/dpa2.py
+++ b/deepmd/pt/model/descriptor/dpa2.py
@@ -162,6 +162,7 @@ def init_subclass_params(sub_data, sub_class):

         self.repinit_args = init_subclass_params(repinit, RepinitArgs)
         self.repformer_args = init_subclass_params(repformer, RepformerArgs)
+        self.tebd_input_mode = self.repinit_args.tebd_input_mode

         self.repinit = DescrptBlockSeAtten(
             self.repinit_args.rcut,
@@ -765,6 +766,10 @@ def forward(
         # repinit
         g1_ext = self.type_embedding(extended_atype)
         g1_inp = g1_ext[:, :nloc, :]
+        if self.tebd_input_mode in ["strip"]:
+            type_embedding = self.type_embedding.get_full_embedding(g1_ext.device)
+        else:
+            type_embedding = None
         g1, _, _, _, _ = self.repinit(
             nlist_dict[
                 get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel())
@@ -773,6 +778,7 @@ def forward(
             extended_atype,
             g1_ext,
             mapping,
+            type_embedding,
         )
         if use_three_body:
             assert self.repinit_three_body is not None
@@ -787,6 +793,7 @@ def forward(
                 extended_atype,
                 g1_ext,
                 mapping,
+                type_embedding,
             )
             g1 = torch.cat([g1, g1_three_body], dim=-1)
         # linear to change shape
@@ -813,7 +820,7 @@ def forward(
             extended_atype,
             g1,
             mapping,
-            comm_dict,
+            comm_dict=comm_dict,
         )
         if self.concat_output_tebd:
             g1 = torch.cat([g1, g1_inp], dim=-1)
diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py
index 32295e1c1f..afc4ed4b92 100644
--- a/deepmd/pt/model/descriptor/repformers.py
+++ b/deepmd/pt/model/descriptor/repformers.py
@@ -389,6 +389,7 @@ def forward(
         extended_atype: torch.Tensor,
         extended_atype_embd: Optional[torch.Tensor] = None,
         mapping: Optional[torch.Tensor] = None,
+        type_embedding: Optional[torch.Tensor] = None,
         comm_dict: Optional[dict[str, torch.Tensor]] = None,
     ):
         if comm_dict is None:
diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py
index 54cf29151b..408cd51d3d 100644
--- a/deepmd/pt/model/descriptor/se_a.py
+++ b/deepmd/pt/model/descriptor/se_a.py
@@ -725,6 +725,7 @@ def forward(
         extended_atype: torch.Tensor,
         extended_atype_embd: Optional[torch.Tensor] = None,
         mapping: Optional[torch.Tensor] = None,
+        type_embedding: Optional[torch.Tensor] = None,
     ):
         """Calculate decoded embedding for each atom.

diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py
index ea97a2f691..1ce6ad4583 100644
--- a/deepmd/pt/model/descriptor/se_atten.py
+++ b/deepmd/pt/model/descriptor/se_atten.py
@@ -447,6 +447,7 @@ def forward(
         extended_atype: torch.Tensor,
         extended_atype_embd: Optional[torch.Tensor] = None,
         mapping: Optional[torch.Tensor] = None,
+        type_embedding: Optional[torch.Tensor] = None,
     ):
         """Compute the descriptor.

@@ -462,6 +463,9 @@ def forward(
             The extended type embedding of atoms. shape: nf x nall
         mapping
             The index mapping, not required by this descriptor.
+        type_embedding
+            Full type embeddings. shape: (ntypes+1) x nt
+            Required for stripped type embeddings.

         Returns
         -------
@@ -502,23 +506,12 @@ def forward(
         nlist_mask = nlist != -1
         nlist = torch.where(nlist == -1, 0, nlist)
         sw = torch.squeeze(sw, -1)
-        # nf x nloc x nt -> nf x nloc x nnei x nt
-        atype_tebd = extended_atype_embd[:, :nloc, :]
-        atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1)  # i
         # nf x nall x nt
         nt = extended_atype_embd.shape[-1]
-        atype_tebd_ext = extended_atype_embd
-        # nb x (nloc x nnei) x nt
-        index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt)
-        # nb x (nloc x nnei) x nt
-        atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index)  # j
-        # nb x nloc x nnei x nt
-        atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt)
         # beyond the cutoff sw should be 0.0
         sw = sw.masked_fill(~nlist_mask, 0.0)
         # (nb x nloc) x nnei
         exclude_mask = exclude_mask.view(nb * nloc, nnei)
-
         # nfnl x nnei x 4
         dmatrix = dmatrix.view(-1, self.nnei, 4)
         nfnl = dmatrix.shape[0]
@@ -526,9 +519,21 @@ def forward(
         rr = dmatrix
         rr = rr * exclude_mask[:, :, None]
         ss = rr[:, :, :1]
-        nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim)
-        atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim)
         if self.tebd_input_mode in ["concat"]:
+            atype_tebd_ext = extended_atype_embd
+            # nb x (nloc x nnei) x nt
+            index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt)
+            # nb x (nloc x nnei) x nt
+            atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index)  # j
+            # nb x nloc x nnei x nt
+            atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt)
+
+            # nf x nloc x nt -> nf x nloc x nnei x nt
+            atype_tebd = extended_atype_embd[:, :nloc, :]
+            atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1)  # i
+
+            nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim)
+            atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim)
             if not self.type_one_side:
                 # nfnl x nnei x (1 + tebd_dim * 2)
                 ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2)
@@ -546,23 +551,56 @@ def forward(
                 # nfnl x 4 x ng
                 xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg)
         elif self.tebd_input_mode in ["strip"]:
+            assert self.filter_layers_strip is not None
+            assert type_embedding is not None
+            ng = self.filter_neuron[-1]
+            ntypes_with_padding = type_embedding.shape[0]
+            # nf x (nl x nnei)
+            nlist_index = nlist.reshape(nb, nloc * nnei)
+            # nf x (nl x nnei)
+            nei_type = torch.gather(extended_atype, dim=1, index=nlist_index)
+            # (nf x nl x nnei) x ng
+            nei_type_index = nei_type.view(-1, 1).expand(-1, ng).type(torch.long)
+            if self.type_one_side:
+                tt_full = self.filter_layers_strip.networks[0](type_embedding)
+                # (nf x nl x nnei) x ng
+                gg_t = torch.gather(tt_full, dim=0, index=nei_type_index)
+            else:
+                idx_i = torch.tile(
+                    atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei]
+                ).view(-1)
+                idx_j = nei_type.view(-1)
+                # (nf x nl x nnei) x ng
+                idx = (
+                    (idx_i + idx_j)
+                    .view(-1, 1)
+                    .expand(-1, ng)
+                    .type(torch.long)
+                    .to(torch.long)
+                )
+                # (ntypes) * ntypes * nt
+                type_embedding_nei = torch.tile(
+                    type_embedding.view(1, ntypes_with_padding, nt),
+                    [ntypes_with_padding, 1, 1],
+                )
+                # ntypes * (ntypes) * nt
+                type_embedding_center = torch.tile(
+                    type_embedding.view(ntypes_with_padding, 1, nt),
+                    [1, ntypes_with_padding, 1],
+                )
+                # (ntypes * ntypes) * (nt+nt)
+                two_side_type_embedding = torch.cat(
+                    [type_embedding_nei, type_embedding_center], -1
+                ).reshape(-1, nt * 2)
+                tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
+                # (nf x nl x nnei) x ng
+                gg_t = torch.gather(tt_full, dim=0, index=idx)
+            # (nf x nl) x nnei x ng
+            gg_t = gg_t.reshape(nfnl, nnei, ng)
+            if self.smooth:
+                gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
             if self.compress:
                 ss = ss.reshape(-1, 1)
-                # nfnl x nnei x ng
-                # gg_s = self.filter_layers.networks[0](ss)
-                assert self.filter_layers_strip is not None
-                if not self.type_one_side:
-                    # nfnl x nnei x (tebd_dim * 2)
-                    tt = torch.concat([nlist_tebd, atype_tebd], dim=2)  # dynamic, index
-                else:
-                    # nfnl x nnei x tebd_dim
-                    tt = nlist_tebd
-                # nfnl x nnei x ng
-                gg_t = self.filter_layers_strip.networks[0](tt)
-                if self.smooth:
-                    gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
-                # nfnl x nnei x ng
-                # gg = gg_s * gg_t + gg_s
                 gg_t = gg_t.reshape(-1, gg_t.size(-1))
                 xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
                     self.compress_data[0].contiguous(),
@@ -585,17 +623,6 @@ def forward(
             else:
                 # nfnl x nnei x ng
                 gg_s = self.filter_layers.networks[0](ss)
-                assert self.filter_layers_strip is not None
-                if not self.type_one_side:
-                    # nfnl x nnei x (tebd_dim * 2)
-                    tt = torch.concat([nlist_tebd, atype_tebd], dim=2)  # dynamic, index
-                else:
-                    # nfnl x nnei x tebd_dim
-                    tt = nlist_tebd
-                # nfnl x nnei x ng
-                gg_t = self.filter_layers_strip.networks[0](tt)
-                if self.smooth:
-                    gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
                 # nfnl x nnei x ng
                 gg = gg_s * gg_t + gg_s
                 input_r = torch.nn.functional.normalize(
diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py
index 5cedf2c2ca..90565300dc 100644
--- a/deepmd/pt/model/descriptor/se_t.py
+++ b/deepmd/pt/model/descriptor/se_t.py
@@ -764,6 +764,7 @@ def forward(
         extended_atype: torch.Tensor,
         extended_atype_embd: Optional[torch.Tensor] = None,
         mapping: Optional[torch.Tensor] = None,
+        type_embedding: Optional[torch.Tensor] = None,
     ):
         """Compute the descriptor.

@@ -779,6 +780,9 @@ def forward(
             The extended type embedding of atoms. shape: nf x nall
         mapping
             The index mapping, not required by this descriptor.
+        type_embedding
+            Full type embeddings. shape: (ntypes+1) x nt
+            Required for stripped type embeddings.

         Returns
         -------
diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py
index 069cee761d..01380a7fdf 100644
--- a/deepmd/pt/model/descriptor/se_t_tebd.py
+++ b/deepmd/pt/model/descriptor/se_t_tebd.py
@@ -175,6 +175,7 @@ def __init__(
             use_tebd_bias=use_tebd_bias,
         )
         self.tebd_dim = tebd_dim
+        self.tebd_input_mode = tebd_input_mode
         self.concat_output_tebd = concat_output_tebd
         self.trainable = trainable
         # set trainable
@@ -449,12 +450,17 @@ def forward(
         nall = extended_coord.view(nframes, -1).shape[1] // 3
         g1_ext = self.type_embedding(extended_atype)
         g1_inp = g1_ext[:, :nloc, :]
+        if self.tebd_input_mode in ["strip"]:
+            type_embedding = self.type_embedding.get_full_embedding(g1_ext.device)
+        else:
+            type_embedding = None
         g1, _, _, _, sw = self.se_ttebd(
             nlist,
             extended_coord,
             extended_atype,
             g1_ext,
             mapping=None,
+            type_embedding=type_embedding,
         )
         if self.concat_output_tebd:
             g1 = torch.cat([g1, g1_inp], dim=-1)
@@ -732,6 +738,7 @@ def forward(
         extended_atype: torch.Tensor,
         extended_atype_embd: Optional[torch.Tensor] = None,
         mapping: Optional[torch.Tensor] = None,
+        type_embedding: Optional[torch.Tensor] = None,
     ):
         """Compute the descriptor.

@@ -747,6 +754,9 @@ def forward(
             The extended type embedding of atoms. shape: nf x nall
         mapping
             The index mapping, not required by this descriptor.
+        type_embedding
+            Full type embeddings. shape: (ntypes+1) x nt
+            Required for stripped type embeddings.

         Returns
         -------
@@ -789,13 +799,6 @@ def forward(
         sw = torch.squeeze(sw, -1)
         # nf x nall x nt
         nt = extended_atype_embd.shape[-1]
-        atype_tebd_ext = extended_atype_embd
-        # nb x (nloc x nnei) x nt
-        index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt)
-        # nb x (nloc x nnei) x nt
-        atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index)
-        # nb x nloc x nnei x nt
-        atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt)
         # beyond the cutoff sw should be 0.0
         sw = sw.masked_fill(~nlist_mask, 0.0)
         # (nb x nloc) x nnei
@@ -816,15 +819,19 @@ def forward(
         env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j)
         # nfnl x nt_i x nt_j x 1
         ss = env_ij.unsqueeze(-1)
-
-        # nfnl x nnei x tebd_dim
-        nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim)
-
-        # nfnl x nt_i x nt_j x tebd_dim
-        nlist_tebd_i = nlist_tebd.unsqueeze(2).expand([-1, -1, self.nnei, -1])
-        nlist_tebd_j = nlist_tebd.unsqueeze(1).expand([-1, self.nnei, -1, -1])
-
         if self.tebd_input_mode in ["concat"]:
+            atype_tebd_ext = extended_atype_embd
+            # nb x (nloc x nnei) x nt
+            index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt)
+            # nb x (nloc x nnei) x nt
+            atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index)
+            # nb x nloc x nnei x nt
+            atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt)
+            # nfnl x nnei x tebd_dim
+            nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim)
+            # nfnl x nt_i x nt_j x tebd_dim
+            nlist_tebd_i = nlist_tebd.unsqueeze(2).expand([-1, -1, self.nnei, -1])
+            nlist_tebd_j = nlist_tebd.unsqueeze(1).expand([-1, self.nnei, -1, -1])
             # nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
             ss = torch.concat([ss, nlist_tebd_i, nlist_tebd_j], dim=-1)
             # nfnl x nt_i x nt_j x ng
@@ -833,10 +840,47 @@ def forward(
             # nfnl x nt_i x nt_j x ng
             gg_s = self.filter_layers.networks[0](ss)
             assert self.filter_layers_strip is not None
-            # nfnl x nt_i x nt_j x (tebd_dim * 2)
-            tt = torch.concat([nlist_tebd_i, nlist_tebd_j], dim=-1)
-            # nfnl x nt_i x nt_j x ng
-            gg_t = self.filter_layers_strip.networks[0](tt)
+            assert type_embedding is not None
+            ng = self.filter_neuron[-1]
+            ntypes_with_padding = type_embedding.shape[0]
+            # nf x (nl x nnei)
+            nlist_index = nlist.reshape(nb, nloc * nnei)
+            # nf x (nl x nnei)
+            nei_type = torch.gather(extended_atype, dim=1, index=nlist_index)
+            # nfnl x nnei
+            nei_type = nei_type.reshape(nfnl, nnei)
+            # nfnl x nnei x nnei
+            nei_type_i = nei_type.unsqueeze(2).expand([-1, -1, nnei])
+            nei_type_j = nei_type.unsqueeze(1).expand([-1, nnei, -1])
+            idx_i = nei_type_i * ntypes_with_padding
+            idx_j = nei_type_j
+            # (nf x nl x nt_i x nt_j) x ng
+            idx = (
+                (idx_i + idx_j)
+                .view(-1, 1)
+                .expand(-1, ng)
+                .type(torch.long)
+                .to(torch.long)
+            )
+            # ntypes * (ntypes) * nt
+            type_embedding_i = torch.tile(
+                type_embedding.view(ntypes_with_padding, 1, nt),
+                [1, ntypes_with_padding, 1],
+            )
+            # (ntypes) * ntypes * nt
+            type_embedding_j = torch.tile(
+                type_embedding.view(1, ntypes_with_padding, nt),
+                [ntypes_with_padding, 1, 1],
+            )
+            # (ntypes * ntypes) * (nt+nt)
+            two_side_type_embedding = torch.cat(
+                [type_embedding_i, type_embedding_j], -1
+            ).reshape(-1, nt * 2)
+            tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
+            # (nfnl x nt_i x nt_j) x ng
+            gg_t = torch.gather(tt_full, dim=0, index=idx)
+            # (nfnl x nt_i x nt_j) x ng
+            gg_t = gg_t.reshape(nfnl, nnei, nnei, ng)
             if self.smooth:
                 gg_t = (
                     gg_t
diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py
index 1998cc0dce..353ed0c063 100644
--- a/deepmd/pt/model/network/network.py
+++ b/deepmd/pt/model/network/network.py
@@ -296,6 +296,23 @@ def forward(self, atype):
         """
         return self.embedding(atype.device)[atype]

+    def get_full_embedding(self, device: torch.device):
+        """
+        Get the type embeddings of all types.
+
+        Parameters
+        ----------
+        device : torch.device
+            The device on which to perform the computation.
+
+        Returns
+        -------
+        type_embedding : torch.Tensor
+            The full type embeddings of all types. The last index corresponds to the zero padding.
+            Shape: (ntypes + 1) x tebd_dim
+        """
+        return self.embedding(device)
+
     def share_params(self, base_class, shared_level, resume=False) -> None:
         """
         Share the parameters of self to the base_class with shared_level during multitask training.
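In the PyTorch strip branches above, `tt_full` holds one row per (center type, neighbor type) pair, and `torch.gather(..., dim=0, ...)` requires an index tensor of the same rank as its input, which is why the flat pair index is expanded to `ng` columns before gathering. A minimal, self-contained sketch of that pattern follows; the sizes and tensors are toy stand-ins, not taken from the patch.

import torch

ntypes_with_padding, ng = 4, 8  # toy sizes
tt_full = torch.randn(ntypes_with_padding * ntypes_with_padding, ng)  # strip-net output stand-in

center = torch.tensor([0, 2, 3])  # toy center-atom types
nei = torch.tensor([1, 1, 0])     # toy neighbor types
idx = center * ntypes_with_padding + nei  # flattened (center, neighbor) pair index

# gather along dim 0 needs an index with the same rank as tt_full,
# hence the view(-1, 1).expand(-1, ng) used in the descriptor code.
gg_t = torch.gather(tt_full, dim=0, index=idx.view(-1, 1).expand(-1, ng))
assert torch.equal(gg_t, tt_full[idx])

Plain advanced indexing `tt_full[idx]` selects the same rows; the expanded-index `gather` form mirrors what the strip branches in this patch use.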