add Spline decay support for large off-diagonal elements #221

Open: wants to merge 3 commits into main
17 changes: 14 additions & 3 deletions dptb/data/dataset/_default_dataset.py
@@ -422,6 +422,10 @@ def E3statistics(self, model: torch.nn.Module=None, decay=False):
edge_scales = stats["edge"]["norm_ave"]
edge_scales[:,scalar_mask] = stats["edge"]["scalar_std"]
model.node_prediction_h.set_scale_shift(scales=node_scales, shifts=node_shifts)

if decay:
edge_shifts = model.edge_prediction_h.fit_radialdpdt_shift(stats["edge"]["decay"], self.type_mapper)
edge_scales = None
model.edge_prediction_h.set_scale_shift(scales=edge_scales, shifts=edge_shifts)

return stats
@@ -447,12 +451,14 @@ def _E3edgespecies_stat(self, typed_dataset, decay):

# calculate norm & mean
typed_norm = {}
typed_scalar = {}
typed_norm_ave = torch.ones(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps)
typed_norm_std = torch.zeros(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps)
typed_scalar_ave = torch.ones(len(idp.bond_to_type), n_scalar)
typed_scalar_std = torch.zeros(len(idp.bond_to_type), n_scalar)
for bt, tp in idp.bond_to_type.items():
norms_per_irrep = []
scalar_per_irrep = []
count_scalar = 0
for ir, s in enumerate(irrep_slices):
sub_tensor = typed_hopping[bt][:, s]
@@ -472,13 +478,17 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
norms = torch.ones_like(sub_tensor[:, 0])

if decay:
norms_per_irrep.append(norms)
if not torch.isnan(sub_tensor).all():
if sub_tensor.shape[-1] == 1: # is scalar
scalar_per_irrep.append(sub_tensor.squeeze(-1))
norms_per_irrep.append(norms)

assert count_scalar <= n_scalar
# shape of typed_norm: (n_irreps, n_edges)

if decay:
typed_norm[bt] = torch.stack(norms_per_irrep)
typed_scalar[bt] = torch.stack(scalar_per_irrep) # [n_scalar, n_edge]
typed_norm[bt] = torch.stack(norms_per_irrep) # [n_irreps, n_edge]

edge_stats = {
"norm_ave": typed_norm_ave,
@@ -495,11 +505,12 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
lengths_bt = typed_dataset["edge_lengths"][typed_dataset["edge_type"].flatten().eq(tp)]
sorted_lengths, indices = lengths_bt.sort() # from small to large
# sort the norms by irrep l
sorted_norms = typed_norm[bt][idp.orbpair_irreps.sort().inv, :]
sorted_norms = typed_norm[bt]
# sort the norms by edge length
sorted_norms = sorted_norms[:, indices]
decay_bt["edge_length"] = sorted_lengths
decay_bt["norm_decay"] = sorted_norms
decay_bt["scalar_decay"] = typed_scalar[bt][:, indices]
decay[bt] = decay_bt

edge_stats["decay"] = decay
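For reference, below is a minimal sketch (toy tensors, a hypothetical bond type; shapes follow the diff above, values are illustrative only) of the per-bond-type decay statistics that _E3edgespecies_stat assembles when decay=True. Each bond type maps to its sorted edge lengths together with the per-irrep norms and the raw scalar (0e) values at those lengths; stats["edge"]["decay"] in this form is what fit_radialdpdt_shift later fits against, after E3statistics drops the constant per-bond-type edge scales (edge_scales = None).

import torch

# Toy decay entry for one hypothetical bond type "B-N":
n_edge, n_irreps, n_scalar = 200, 5, 2
lengths, _ = torch.sort(torch.rand(n_edge) * 4.5 + 1.5)   # sorted edge lengths
decay = {
    "B-N": {
        "edge_length": lengths,                            # [n_edge], ascending
        "norm_decay": torch.rand(n_irreps, n_edge),        # per-irrep norms, sorted by length
        "scalar_decay": torch.rand(n_scalar, n_edge),      # raw 0e values, sorted by length
    }
}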
3 changes: 2 additions & 1 deletion dptb/entrypoints/train.py
@@ -192,7 +192,8 @@ def train(
# build_model will handle the init-model cases where the provided model options are not equal to the ones in the checkpoint.
checkpoint = init_model if init_model else None
model = build_model(checkpoint=checkpoint, model_options=jdata["model_options"], common_options=jdata["common_options"])
train_datasets.E3statistics(model=model)
decay = jdata["model_options"].get("prediction", {}).get("decay", False)
train_datasets.E3statistics(model=model, decay=decay)
trainer = Trainer(
train_options=jdata["train_options"],
common_options=jdata["common_options"],
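For context, a minimal sketch of how the new flag might be enabled in the training input, written as the Python-dict equivalent of the relevant JSON fragment (only the keys touched here are shown; the "method" key is an assumption, while "decay" and the lookup mirror argcheck.py and the line above):

jdata = {
    "model_options": {
        "prediction": {
            "method": "e3tb",   # assumed: decay is registered under e3tb_prediction()
            "decay": True,      # switch on the radial decay fit of the scalar edge shifts
        }
    }
}
decay = jdata["model_options"].get("prediction", {}).get("decay", False)   # as in train.py above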
38 changes: 26 additions & 12 deletions dptb/nn/deeptb.py
@@ -10,7 +10,7 @@
from dptb.nn.nnsk import NNSK
from dptb.nn.dftbsk import DFTBSK
from e3nn.o3 import Linear
from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift
from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift, E3PerEdgeSpeciesRadialDpdtScaleShift
import logging

log = logging.getLogger(__name__)
@@ -180,18 +180,32 @@ def __init__(
device=self.device,
**prediction_copy,
)

if prediction_copy.get("decay"):
self.edge_prediction_h = E3PerEdgeSpeciesRadialDpdtScaleShift(
field=AtomicDataDict.EDGE_FEATURES_KEY,
num_types=n_species,
irreps_in=self.embedding.out_edge_irreps,
out_field = AtomicDataDict.EDGE_FEATURES_KEY,
shifts=0.,
scales=1.,
dtype=self.dtype,
device=self.device,
**prediction_copy,
)
else:
self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
field=AtomicDataDict.EDGE_FEATURES_KEY,
num_types=n_species,
irreps_in=self.embedding.out_edge_irreps,
out_field = AtomicDataDict.EDGE_FEATURES_KEY,
shifts=0.,
scales=1.,
dtype=self.dtype,
device=self.device,
**prediction_copy,
)

self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
field=AtomicDataDict.EDGE_FEATURES_KEY,
num_types=n_species,
irreps_in=self.embedding.out_edge_irreps,
out_field = AtomicDataDict.EDGE_FEATURES_KEY,
shifts=0.,
scales=1.,
dtype=self.dtype,
device=self.device,
**prediction_copy,
)

if overlap:
self.idp_sk = OrbitalMapper(self.idp.basis, method="sktb", device=self.device)
217 changes: 216 additions & 1 deletion dptb/nn/rescale.py
@@ -6,10 +6,14 @@
from typing import Optional, List, Union
import torch.nn.functional
from e3nn.o3 import Linear
from dptb.nn.sktb import HoppingFormula, bond_length_list
from e3nn.util.jit import compile_mode
from dptb.data import AtomicDataDict
from dptb.utils.constants import atomic_num_dict
import e3nn.o3 as o3

log = logging.getLogger(__name__)

class PerSpeciesScaleShift(torch.nn.Module):
"""Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters.

@@ -523,4 +527,215 @@ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]=None):
else:
x = x

return x
return x


class E3PerEdgeSpeciesRadialDpdtScaleShift(torch.nn.Module):
"""Sum edgewise energies.

Includes optional per-species-pair edgewise energy scales.
"""

field: str
out_field: str
scales_trainable: bool
shifts_trainable: bool
has_scales: bool
has_shifts: bool

def __init__(
self,
field: str,
num_types: int,
irreps_in,
shifts: Optional[torch.Tensor],
scales: Optional[torch.Tensor],
out_field: Optional[str] = None,
scales_trainable: bool = False,
shifts_trainable: bool = False,
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
**kwargs,
):
"""Sum edges into nodes."""
super(E3PerEdgeSpeciesRadialDpdtScaleShift, self).__init__()
self.num_types = num_types
self.field = field
self.out_field = f"shifted_{field}" if out_field is None else out_field
self.irreps_in = irreps_in
self.num_scalar = 0
self.device = device
self.dtype = dtype
self.shift_index = []
self.scale_index = []

start = 0
start_scalar = 0
for mul, ir in irreps_in:
if str(ir) == "0e":
self.num_scalar += mul
self.shift_index += list(range(start_scalar, start_scalar + mul))
start_scalar += mul
else:
self.shift_index += [-1] * mul * ir.dim

for _ in range(mul):
self.scale_index += [start] * ir.dim
start += 1

self.shift_index = torch.as_tensor(self.shift_index, dtype=torch.long, device=device)
self.scale_index = torch.as_tensor(self.scale_index, dtype=torch.long, device=device)

self.has_shifts = shifts is not None
self.has_scales = scales is not None
if scales is not None:
scales = torch.as_tensor(scales, dtype=self.dtype, device=device)
if len(scales.reshape(-1)) == 1:
scales = scales * torch.ones(num_types*num_types, self.irreps_in.num_irreps, dtype=self.dtype, device=self.device)
assert scales.shape == (num_types*num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}"
self.scales_trainable = scales_trainable
if scales_trainable:
self.scales = torch.nn.Parameter(scales)
else:
self.register_buffer("scales", scales)

if shifts is not None:
shifts = torch.as_tensor(shifts, dtype=self.dtype, device=device)
if len(shifts.reshape(-1)) == 1:
shifts = shifts * torch.ones(num_types*num_types, self.num_scalar, dtype=self.dtype, device=self.device)
assert shifts.shape == (num_types*num_types, self.num_scalar), f"Invalid shape of shifts {shifts}"
self.shifts_trainable = shifts_trainable
if shifts_trainable:
self.shifts = torch.nn.Parameter(shifts)
else:
self.register_buffer("shifts", shifts)

self.r0 = [] # initialize r0

def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
self.has_scales = scales is not None or self.has_scales
if scales is not None:
assert scales.shape == (self.num_types*self.num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}"
if self.scales_trainable:
self.scales = torch.nn.Parameter(scales)
else:
self.register_buffer("scales", scales)

self.has_shifts = shifts is not None or self.has_shifts
if shifts is not None:
assert shifts.shape == (self.num_types*self.num_types, self.num_scalar, 7), f"Invalid shape of shifts {shifts}"
if self.shifts_trainable:
self.shifts = torch.nn.Parameter(shifts)
else:
self.register_buffer("shifts", shifts)

def fit_radialdpdt_shift(self, decay, idp):
shifts = torch.randn(self.num_types*self.num_types, self.num_scalar, 7, dtype=self.dtype, device=self.device)
shifts.requires_grad_()
optimizer = torch.optim.Adam([shifts], lr=0.01)
lrsch = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=3000, threshold=1e-5, eps=1e-5, verbose=True)
bond_sym = list(decay.keys())
bsz = 128

for sym in idp.type_names:
self.r0.append(bond_length_list[atomic_num_dict[sym]-1])
self.r0 = torch.tensor(self.r0, device=self.device, dtype=self.dtype)

#TODO: check whether some bond type does not have enough values; this may appear in sparsely doped systems.
#TODO: check whether any bond type fails to cover the range between the equilibrium r0 and r_cut. This may appear in some heterogeneous systems.
n_edge_length = []
edge_lengths = {}
scalar_decays = {}
for bsym in decay:
n_edge_length.append(len(decay[bsym]["edge_length"]))
edge_lengths[bsym] = decay[bsym]["edge_length"].type(self.dtype).to(self.device)
scalar_decays[bsym] = decay[bsym]["scalar_decay"].type(self.dtype).to(self.device)


if min(n_edge_length) <= bsz:
log.warning("There exist edge that does not have enough values for fitting edge decaying behaviour, please use decay == False.")

edge_number = idp._index_to_ZZ.T
for i in range(40000):
optimizer.zero_grad()
rs = [None] * len(bond_sym)
frs = [None] * len(bond_sym)
# construct the dataset
for bsym in decay:
bt = idp.bond_to_type[bsym]
random_index = torch.randint(0, len(edge_lengths[bsym]), (bsz,))
rs[bt] = edge_lengths[bsym][random_index]
frs[bt] = scalar_decays[bsym][:,random_index].T # [bsz, n_scalar]
rs = torch.cat(rs, dim=0)
frs = torch.cat(frs, dim=0)
r0 = 0.5*bond_length_list.type(self.dtype).to(self.device)[edge_number-1].sum(0)
r0 = r0.unsqueeze(1).repeat(1, bsz).reshape(-1)

paraArray=shifts.reshape(-1, 1, self.num_scalar, 7).repeat(1,bsz,1,1).reshape(-1, self.num_scalar, 7)

fr_ = self.poly5pow(
rij=rs,
paraArray=paraArray,
r0 = r0,
)

loss = (fr_ - frs).pow(2).mean()

log.info("Decaying function fitting Step {}, loss: {:.4f}, lr: {:.5f}".format(i, loss.item(), lrsch.get_last_lr()[0]))
loss.backward()
optimizer.step()
lrsch.step(loss.item())

return shifts.detach()



def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

if not (self.has_scales or self.has_shifts):
return data

edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]

species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()
edge_atom_type = data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[data[AtomicDataDict.EDGE_INDEX_KEY]]
in_field = data[self.field]

assert len(in_field) == len(
edge_center
), "in_field doesnt seem to have correct per-edge shape"

if self.has_scales:
in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
if self.has_shifts:
shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar, 7)
r0 = self.r0[edge_atom_type].sum(0) * 0.5
shifts = self.poly5pow(
rij=data[AtomicDataDict.EDGE_LENGTH_KEY],
r0=r0,
paraArray=shifts
) # [n_edge, n_scalar]
in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]

data[self.out_field] = in_field

return data

def poly5pow(self, rij, paraArray, r0:torch.Tensor):
"""> This function calculates SK integrals without the environment dependence of the form of powerlaw

$$ h(rij) = alpha_1 * (rij / r_ij0)^(lambda + alpha_2) $$
"""

#alpha1, alpha2, alpha3, alpha4 = paraArray[:, 0], paraArray[:, 1]**2, paraArray[:, 2]**2, paraArray[:, 3]**2
alpha1, alpha2, alpha3, alpha4, alpha5, alpha6, alpha7 = paraArray[..., 0], paraArray[..., 1], paraArray[..., 2], paraArray[..., 3], paraArray[..., 4], paraArray[..., 5], paraArray[..., 6].abs()
#[N, n_op]
shape = [-1]+[1] * (len(alpha1.shape)-1)
# [-1, 1]
rij = rij.reshape(shape)
r0 = r0.reshape(shape)

r0 = r0 / 1.8897259886 # 1.8897259886 Bohr per Angstrom; convert r0 (presumably tabulated in Bohr) to Angstrom

return (alpha1 + alpha2 * (rij-r0) + 0.5 * alpha3 * (rij - r0)**2 + 1/6 * alpha4 * (rij-r0)**3 + 1./24 * alpha5 * (rij-r0)**4 + 1./120 * alpha6 * (rij-r0)**5) * (r0/rij)**(1 + alpha7)

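For clarity, alpha_1 to alpha_6 act as Taylor coefficients of the scalar shift around the reference bond length r0, and |alpha_7| adds to the power-law exponent. Below is a self-contained sketch of that functional form under simplified assumptions: it omits the Bohr-to-Angstrom conversion applied to r0 inside the module, and the r0 and parameter values are purely illustrative.

import torch

def poly5pow_ref(rij: torch.Tensor, para: torch.Tensor, r0: torch.Tensor) -> torch.Tensor:
    # para[..., 0:6] are the Taylor coefficients alpha_1..alpha_6; para[..., 6] sets the extra exponent.
    a1, a2, a3, a4, a5, a6 = (para[..., k] for k in range(6))
    a7 = para[..., 6].abs()
    dr = rij - r0
    poly = a1 + a2*dr + a3*dr**2/2 + a4*dr**3/6 + a5*dr**4/24 + a6*dr**5/120
    return poly * (r0 / rij) ** (1 + a7)

r0 = torch.tensor(2.35)                                   # hypothetical equilibrium bond length
para = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])  # only alpha_1 and alpha_7 nonzero
rij = torch.linspace(2.0, 6.0, 5)
print(poly5pow_ref(rij, para, r0))                        # reduces to (r0/rij)**2 for this choice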
2 changes: 1 addition & 1 deletion dptb/nn/tensor_product.py
@@ -7,7 +7,7 @@



_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=True)

def wigner_D(l, alpha, beta, gamma):
if not l < len(_Jd):
2 changes: 2 additions & 0 deletions dptb/utils/argcheck.py
@@ -605,13 +605,15 @@ def sktb_prediction():
def e3tb_prediction():
doc_scales_trainable = "whether to scale the trianing target."
doc_shifts_trainable = "whether to shift the training target."
doc_decay = "whether the edge normalization takes into account the decaying behaviour of the edge irreps"
doc_neurons = "neurons in the neural network."
doc_activation = "activation function."
doc_if_batch_normalized = "if to turn on batch normalization"

nn = [
Argument("scales_trainable", bool, optional=True, default=False, doc=doc_scales_trainable),
Argument("shifts_trainable", bool, optional=True, default=False, doc=doc_shifts_trainable),
Argument("decay", bool, optional=True, default=False, doc=doc_decay),
Argument("neurons", list, optional=True, default=None, doc=doc_neurons),
Argument("activation", str, optional=True, default="tanh", doc=doc_activation),
Argument("if_batch_normalized", bool, optional=True, default=False, doc=doc_if_batch_normalized),