diff --git a/dptb/data/dataset/_default_dataset.py b/dptb/data/dataset/_default_dataset.py
index 94d34b43..f804d284 100644
--- a/dptb/data/dataset/_default_dataset.py
+++ b/dptb/data/dataset/_default_dataset.py
@@ -422,6 +422,10 @@ def E3statistics(self, model: torch.nn.Module=None, decay=False):
         edge_scales = stats["edge"]["norm_ave"]
         edge_scales[:,scalar_mask] = stats["edge"]["scalar_std"]
         model.node_prediction_h.set_scale_shift(scales=node_scales, shifts=node_shifts)
+
+        if decay:
+            edge_shifts = model.edge_prediction_h.fit_radialdpdt_shift(stats["edge"]["decay"], self.type_mapper)
+            edge_scales = None
         model.edge_prediction_h.set_scale_shift(scales=edge_scales, shifts=edge_shifts)
 
         return stats
@@ -447,12 +451,14 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
 
         # calculate norm & mean
         typed_norm = {}
+        typed_scalar = {}
         typed_norm_ave = torch.ones(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps)
         typed_norm_std = torch.zeros(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps)
         typed_scalar_ave = torch.ones(len(idp.bond_to_type), n_scalar)
         typed_scalar_std = torch.zeros(len(idp.bond_to_type), n_scalar)
         for bt, tp in idp.bond_to_type.items():
             norms_per_irrep = []
+            scalar_per_irrep = []
             count_scalar = 0
             for ir, s in enumerate(irrep_slices):
                 sub_tensor = typed_hopping[bt][:, s]
@@ -472,13 +478,17 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
                     norms = torch.ones_like(sub_tensor[:, 0])
 
                 if decay:
-                    norms_per_irrep.append(norms)
+                    if not torch.isnan(sub_tensor).all():
+                        if sub_tensor.shape[-1] == 1: # is scalar
+                            scalar_per_irrep.append(sub_tensor.squeeze(-1))
+                        norms_per_irrep.append(norms)
 
             assert count_scalar <= n_scalar
 
             # shape of typed_norm: (n_irreps, n_edges)
             if decay:
-                typed_norm[bt] = torch.stack(norms_per_irrep)
+                typed_scalar[bt] = torch.stack(scalar_per_irrep) # [n_scalar, n_edge]
+                typed_norm[bt] = torch.stack(norms_per_irrep) # [n_irreps, n_edge]
 
         edge_stats = {
             "norm_ave": typed_norm_ave,
@@ -495,11 +505,12 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
                 lengths_bt = typed_dataset["edge_lengths"][typed_dataset["edge_type"].flatten().eq(tp)]
                 sorted_lengths, indices = lengths_bt.sort() # from small to large
                 # sort the norms by irrep l
-                sorted_norms = typed_norm[bt][idp.orbpair_irreps.sort().inv, :]
+                sorted_norms = typed_norm[bt]
                 # sort the norms by edge length
                 sorted_norms = sorted_norms[:, indices]
                 decay_bt["edge_length"] = sorted_lengths
                 decay_bt["norm_decay"] = sorted_norms
+                decay_bt["scalar_decay"] = typed_scalar[bt][:, indices]
                 decay[bt] = decay_bt
 
         edge_stats["decay"] = decay
diff --git a/dptb/entrypoints/train.py b/dptb/entrypoints/train.py
index d2bc3484..2ea194d8 100644
--- a/dptb/entrypoints/train.py
+++ b/dptb/entrypoints/train.py
@@ -192,7 +192,8 @@ def train(
         # build model will handle the init model cases where the model options provided is not equals to the ones in checkpoint.
         checkpoint = init_model if init_model else None
         model = build_model(checkpoint=checkpoint, model_options=jdata["model_options"], common_options=jdata["common_options"])
-        train_datasets.E3statistics(model=model)
+        decay = jdata["model_options"].get("prediction", {}).get("decay", False)
+        train_datasets.E3statistics(model=model, decay=decay)
         trainer = Trainer(
             train_options=jdata["train_options"],
             common_options=jdata["common_options"],
diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py
index 0b102dd6..9c548564 100644
--- a/dptb/nn/deeptb.py
+++ b/dptb/nn/deeptb.py
@@ -10,7 +10,7 @@ from dptb.nn.nnsk import NNSK
 from dptb.nn.dftbsk import DFTBSK
 from e3nn.o3 import Linear
 
-from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift
+from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift, E3PerEdgeSpeciesRadialDpdtScaleShift
 import logging
 
 log = logging.getLogger(__name__)
@@ -180,18 +180,32 @@ def __init__(
             device=self.device,
             **prediction_copy,
         )
+
+        if prediction_copy.get("decay"):
+            self.edge_prediction_h = E3PerEdgeSpeciesRadialDpdtScaleShift(
+                field=AtomicDataDict.EDGE_FEATURES_KEY,
+                num_types=n_species,
+                irreps_in=self.embedding.out_edge_irreps,
+                out_field = AtomicDataDict.EDGE_FEATURES_KEY,
+                shifts=0.,
+                scales=1.,
+                dtype=self.dtype,
+                device=self.device,
+                **prediction_copy,
+            )
+        else:
+            self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
+                field=AtomicDataDict.EDGE_FEATURES_KEY,
+                num_types=n_species,
+                irreps_in=self.embedding.out_edge_irreps,
+                out_field = AtomicDataDict.EDGE_FEATURES_KEY,
+                shifts=0.,
+                scales=1.,
+                dtype=self.dtype,
+                device=self.device,
+                **prediction_copy,
+            )
 
-        self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
-            field=AtomicDataDict.EDGE_FEATURES_KEY,
-            num_types=n_species,
-            irreps_in=self.embedding.out_edge_irreps,
-            out_field = AtomicDataDict.EDGE_FEATURES_KEY,
-            shifts=0.,
-            scales=1.,
-            dtype=self.dtype,
-            device=self.device,
-            **prediction_copy,
-        )
 
         if overlap:
             self.idp_sk = OrbitalMapper(self.idp.basis, method="sktb", device=self.device)
diff --git a/dptb/nn/rescale.py b/dptb/nn/rescale.py
index 168e8351..83e5e1d2 100644
--- a/dptb/nn/rescale.py
+++ b/dptb/nn/rescale.py
@@ -6,10 +6,14 @@ from typing import Optional, List, Union
 import torch.nn.functional
 from e3nn.o3 import Linear
+from dptb.nn.sktb import HoppingFormula, bond_length_list
 from e3nn.util.jit import compile_mode
 from dptb.data import AtomicDataDict
+from dptb.utils.constants import atomic_num_dict
 import e3nn.o3 as o3
 
+log = logging.getLogger(__name__)
+
 class PerSpeciesScaleShift(torch.nn.Module):
     """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters.
 
 
@@ -523,4 +527,215 @@ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]=None):
         else:
             x = x
 
-        return x
\ No newline at end of file
+        return x
+
+
+class E3PerEdgeSpeciesRadialDpdtScaleShift(torch.nn.Module):
+    """Per-edge-species scale/shift with a radially dependent shift for the scalar channels.
+
+    Includes optional per-species-pair edgewise scales; the per-species-pair shifts of the scalar (0e) irreps are fitted functions of the bond length.
+ """ + + field: str + out_field: str + scales_trainble: bool + shifts_trainable: bool + has_scales: bool + has_shifts: bool + + def __init__( + self, + field: str, + num_types: int, + irreps_in, + shifts: Optional[torch.Tensor], + scales: Optional[torch.Tensor], + out_field: Optional[str] = None, + scales_trainable: bool = False, + shifts_trainable: bool = False, + dtype: Union[str, torch.dtype] = torch.float32, + device: Union[str, torch.device] = torch.device("cpu"), + **kwargs, + ): + """Sum edges into nodes.""" + super(E3PerEdgeSpeciesRadialDpdtScaleShift, self).__init__() + self.num_types = num_types + self.field = field + self.out_field = f"shifted_{field}" if out_field is None else out_field + self.irreps_in = irreps_in + self.num_scalar = 0 + self.device = device + self.dtype = dtype + self.shift_index = [] + self.scale_index = [] + + start = 0 + start_scalar = 0 + for mul, ir in irreps_in: + if str(ir) == "0e": + self.num_scalar += mul + self.shift_index += list(range(start_scalar, start_scalar + mul)) + start_scalar += mul + else: + self.shift_index += [-1] * mul * ir.dim + + for _ in range(mul): + self.scale_index += [start] * ir.dim + start += 1 + + self.shift_index = torch.as_tensor(self.shift_index, dtype=torch.long, device=device) + self.scale_index = torch.as_tensor(self.scale_index, dtype=torch.long, device=device) + + self.has_shifts = shifts is not None + self.has_scales = scales is not None + if scales is not None: + scales = torch.as_tensor(scales, dtype=self.dtype, device=device) + if len(scales.reshape(-1)) == 1: + scales = scales * torch.ones(num_types*num_types, self.irreps_in.num_irreps, dtype=self.dtype, device=self.device) + assert scales.shape == (num_types*num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}" + self.scales_trainable = scales_trainable + if scales_trainable: + self.scales = torch.nn.Parameter(scales) + else: + self.register_buffer("scales", scales) + + if shifts is not None: + shifts = torch.as_tensor(shifts, dtype=self.dtype, device=device) + if len(shifts.reshape(-1)) == 1: + shifts = shifts * torch.ones(num_types*num_types, self.num_scalar, dtype=self.dtype, device=self.device) + assert shifts.shape == (num_types*num_types, self.num_scalar), f"Invalid shape of shifts {shifts}" + self.shifts_trainable = shifts_trainable + if shifts_trainable: + self.shifts = torch.nn.Parameter(shifts) + else: + self.register_buffer("shifts", shifts) + + self.r0 = [] # initilize r0 + + def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None): + self.has_scales = scales is not None or self.has_scales + if scales is not None: + assert scales.shape == (self.num_types*self.num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}" + if self.scales_trainable: + self.scales = torch.nn.Parameter(scales) + else: + self.register_buffer("scales", scales) + + self.has_shifts = shifts is not None or self.has_shifts + if shifts is not None: + assert shifts.shape == (self.num_types*self.num_types, self.num_scalar, 7), f"Invalid shape of shifts {shifts}" + if self.shifts_trainable: + self.shifts = torch.nn.Parameter(shifts) + else: + self.register_buffer("shifts", shifts) + + def fit_radialdpdt_shift(self, decay, idp): + shifts = torch.randn(self.num_types*self.num_types, self.num_scalar, 7, dtype=self.dtype, device=self.device) + shifts.requires_grad_() + optimizer = torch.optim.Adam([shifts], lr=0.01) + lrsch = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=3000, 
+        bond_sym = list(decay.keys())
+        bsz = 128
+
+        for sym in idp.type_names:
+            self.r0.append(bond_length_list[atomic_num_dict[sym]-1])
+        self.r0 = torch.tensor(self.r0, device=self.device, dtype=self.dtype)
+
+        # TODO: check whether some bond type has too few samples for the fit; this may appear in sparsely doped systems.
+        # TODO: check whether some bond type fails to cover the range between the equilibrium r0 and r_cut; this may appear in heterogeneous systems.
+        n_edge_length = []
+        edge_lengths = {}
+        scalar_decays = {}
+        for bsym in decay:
+            n_edge_length.append(len(decay[bsym]["edge_length"]))
+            edge_lengths[bsym] = decay[bsym]["edge_length"].type(self.dtype).to(self.device)
+            scalar_decays[bsym] = decay[bsym]["scalar_decay"].type(self.dtype).to(self.device)
+
+
+        if min(n_edge_length) <= bsz:
+            log.warning("Some edge types do not have enough samples to fit the edge decay behaviour, please use decay == False.")
+
+        edge_number = idp._index_to_ZZ.T
+        for i in range(40000):
+            optimizer.zero_grad()
+            rs = [None] * len(bond_sym)
+            frs = [None] * len(bond_sym)
+            # construct the dataset
+            for bsym in decay:
+                bt = idp.bond_to_type[bsym]
+                random_index = torch.randint(0, len(edge_lengths[bsym]), (bsz,))
+                rs[bt] = edge_lengths[bsym][random_index]
+                frs[bt] = scalar_decays[bsym][:,random_index].T # [bsz, n_scalar]
+            rs = torch.cat(rs, dim=0)
+            frs = torch.cat(frs, dim=0)
+            r0 = 0.5*bond_length_list.type(self.dtype).to(self.device)[edge_number-1].sum(0)
+            r0 = r0.unsqueeze(1).repeat(1, bsz).reshape(-1)
+
+            paraArray=shifts.reshape(-1, 1, self.num_scalar, 7).repeat(1,bsz,1,1).reshape(-1, self.num_scalar, 7)
+
+            fr_ = self.poly5pow(
+                rij=rs,
+                paraArray=paraArray,
+                r0 = r0,
+            )
+
+            loss = (fr_ - frs).pow(2).mean()
+
+            log.info("Decay function fitting step {}, loss: {:.4f}, lr: {:.5f}".format(i, loss.item(), lrsch.get_last_lr()[0]))
+            loss.backward()
+            optimizer.step()
+            lrsch.step(loss.item())
+
+        return shifts.detach()
+
+
+
+    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+
+        if not (self.has_scales or self.has_shifts):
+            return data
+
+        edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]
+
+        species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()
+        edge_atom_type = data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[data[AtomicDataDict.EDGE_INDEX_KEY]]
+        in_field = data[self.field]
+
+        assert len(in_field) == len(
+            edge_center
+        ), "in_field doesn't seem to have the correct per-edge shape"
+
+        if self.has_scales:
+            in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
+        if self.has_shifts:
+            shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar, 7)
+            r0 = self.r0[edge_atom_type].sum(0) * 0.5
+            shifts = self.poly5pow(
+                rij=data[AtomicDataDict.EDGE_LENGTH_KEY],
+                r0=r0,
+                paraArray=shifts
+            ) # [n_edge, n_scalar]
+            in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
+
+        data[self.out_field] = in_field
+
+        return data
+
+    def poly5pow(self, rij, paraArray, r0:torch.Tensor):
+        """Radial profile for the scalar shifts: a fifth-order Taylor polynomial in (rij - r0) damped by a power law,
+
+        $$ h(rij) = \sum_{k=0}^{5} alpha_{k+1} / k! * (rij - r0)^k * (r0 / rij)^(1 + |alpha_7|) $$
+        """
+
+        #alpha1, alpha2, alpha3, alpha4 = paraArray[:, 0], paraArray[:, 1]**2, paraArray[:, 2]**2, paraArray[:, 3]**2
+        alpha1, alpha2, alpha3, alpha4, alpha5, alpha6, alpha7 = paraArray[..., 0], paraArray[..., 1], paraArray[..., 2], paraArray[..., 3], paraArray[..., 4], paraArray[..., 5], paraArray[..., 6].abs()
+        #[N, n_op]
+        shape = [-1]+[1] * (len(alpha1.shape)-1)
+        # [-1, 1]
+        rij = rij.reshape(shape)
+        r0 = r0.reshape(shape)
+
+        r0 = r0 / 1.8897259886 # 1 Angstrom = 1.8897259886 Bohr
+
+        return (alpha1 + alpha2 * (rij-r0) + 0.5 * alpha3 * (rij - r0)**2 + 1/6 * alpha4 * (rij-r0)**3 + 1./24 * alpha5 * (rij-r0)**4 + 1./120 * alpha6 * (rij-r0)**5) * (r0/rij)**(1 + alpha7)
+
\ No newline at end of file
diff --git a/dptb/nn/tensor_product.py b/dptb/nn/tensor_product.py
index c36fc68a..0c76f708 100644
--- a/dptb/nn/tensor_product.py
+++ b/dptb/nn/tensor_product.py
@@ -7,7 +7,7 @@
 
 
 
-_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
+_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=True)
 
 def wigner_D(l, alpha, beta, gamma):
     if not l < len(_Jd):
diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py
index f4fb1708..a1bd4082 100644
--- a/dptb/utils/argcheck.py
+++ b/dptb/utils/argcheck.py
@@ -605,6 +605,7 @@ def sktb_prediction():
 def e3tb_prediction():
     doc_scales_trainable = "whether to scale the trianing target."
     doc_shifts_trainable = "whether to shift the training target."
+    doc_decay = "whether the edge rescaling takes into account the decay of the edge irreps with bond length (a distance-dependent shift is fitted for the scalar channels)."
     doc_neurons = "neurons in the neural network."
     doc_activation = "activation function."
     doc_if_batch_normalized = "if to turn on batch normalization"
@@ -612,6 +613,7 @@ def e3tb_prediction():
     nn = [
         Argument("scales_trainable", bool, optional=True, default=False, doc=doc_scales_trainable),
         Argument("shifts_trainable", bool, optional=True, default=False, doc=doc_shifts_trainable),
+        Argument("decay", bool, optional=True, default=False, doc=doc_decay),
         Argument("neurons", list, optional=True, default=None, doc=doc_neurons),
         Argument("activation", str, optional=True, default="tanh", doc=doc_activation),
         Argument("if_batch_normalized", bool, optional=True, default=False, doc=doc_if_batch_normalized),
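
Editor's note (not part of the patch): a minimal sketch of how the new option is switched on from the training input, assuming the usual dptb jdata layout read by dptb/entrypoints/train.py; only the "decay" key is introduced by this change, and the remaining prediction options are omitted here for brevity.

    # Hypothetical fragment of the training input dict (jdata); other model_options
    # and prediction keys (method, neurons, scales_trainable, ...) stay as before.
    jdata = {
        "model_options": {
            "prediction": {
                "decay": True,   # fit a distance-dependent shift for the scalar edge features
            }
        }
    }

With decay set to True, train.py forwards the flag to train_datasets.E3statistics, which fits the distance-dependent scalar shifts via fit_radialdpdt_shift and installs them on model.edge_prediction_h (an E3PerEdgeSpeciesRadialDpdtScaleShift); with the default False, the previous E3PerEdgeSpeciesScaleShift path is kept.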