Skip to content

Commit

Permalink
Modified Notebook according to PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
HowWeiBin committed Apr 9, 2024
1 parent 27cd109 commit 713c963
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 28 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
- gaas-map
- batch-cp2k
- lpr
- dos-align

steps:
- uses: actions/checkout@v4
Expand Down
82 changes: 55 additions & 27 deletions examples/dos-align/dos-align.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


# 1) Construct the DOS using the original reference
# ------------------------------------------------------------------------------------------
# The DOS will first be constructed from the full set of eigenenergies to
# determine the Fermi level of each structure
# determine the Fermi level of each structure. The original reference is the
# Average Hartree Potential in this example.

# To ensure that all the eigenenergies are fully represented after
# gaussian broadening, the energy axis of the DOS extends
Expand All @@ -107,10 +109,10 @@
# Gaussian Smearing for the eDOS, 0.3eV is the appropriate value for this dataset

sigma = torch.tensor(0.3)

energy_interval = 0.05
# energy axis, with a grid interval of 0.05 eV

x_dos = torch.arange(energy_lower_bound, energy_upper_bound, 0.05)
x_dos = torch.arange(energy_lower_bound, energy_upper_bound, energy_interval)
print(
f"The energy axis ranges from {energy_lower_bound:.3} to \
{energy_upper_bound:.3}, consisting of {len(x_dos)} grid points"
Expand Down Expand Up @@ -140,8 +142,12 @@
print(f"The final shape of all the DOS in the dataset is: {list(total_edos.shape)}")

# %%
# Now we calculate the Fermi Level by integrating the DOS
# and then use cubic interpolation and brentq to determine the fermi level
# 2) Calculate the Fermi level from the DOS
# ------------------------------------------------------------------------------------------
# Now we integration the DOS, and then use cubic interpolation and brentq
# to calculate the fermi level. Since only the 4 valence electrons in Silicon
# are represented in this energy range, we take the point where the DOS integrates
# to 4 as the fermi level.

fermi_levels = []
total_i_edos = torch.cumulative_trapezoid(
Expand All @@ -154,9 +160,12 @@
Ef = brentq(
interpolated, x_dos[0] + 0.1, x_dos[-1] - 0.1
) # Fermi Level is the point where the (integrated DOS - 4) = 0
# 0.1 is added and subtracted to prevent brentq from going out of range
fermi_levels.append(Ef)
fermi_levels = torch.tensor(fermi_levels)
# %%
# 3) Build a set of eigenenergies, with the energy reference set to the fermi level
# ------------------------------------------------------------------------------------------
# Using the fermi levels, we are now able to change the energy reference
# of the eigenenergies to the fermi level

Expand All @@ -173,12 +182,25 @@


# %%
# 4) Truncate the DOS energy window so that the DOS is well-defined at each point
# ------------------------------------------------------------------------------------------
# With the fermi levels, we can also truncate the energy window for DOS prediction.
# In this example, we truncate the energy window such that it is 3eV above
# the highest Fermi level in the dataset
# the highest Fermi level in the dataset.

# For the Average Hartree Potential energy reference
x_dos_H = torch.arange(minE - 1.5, max(fermi_levels) + 3, energy_interval)

# For the Fermi Level Energy Reference, all the Fermi levels in the dataset is 0eV
x_dos_Ef = torch.arange(minE_Ef - 1.5, 3, energy_interval)
# %%
# 5) Construct the DOS in the truncated energy window under both references
# ------------------------------------------------------------------------------------------
# Here we construct 2 different targets where they differ in the energy reference
# chosen. These targets will then be treated as different datasets for the model
# to learn on.

# For the Average Hartree Potential energy reference
x_dos_H = torch.arange(minE - 1.5, max(fermi_levels) + 3, 0.05)

total_edos_H = []

Expand All @@ -195,8 +217,7 @@
total_edos_H = (total_edos_H.T * normalization).T


# For the Fermi Level Energy Reference, all the Fermi levels in the dataset is 0eV
x_dos_Ef = torch.arange(minE_Ef - 1.5, 3, 0.05)
# For the Fermi Level Energy Reference

total_edos_Ef = []

Expand All @@ -213,18 +234,15 @@
total_edos_Ef = (total_edos_Ef.T * normalization).T

# %%
# 6) Construct Splines for the DOS to facilitate interpolation during model training
# ------------------------------------------------------------------------------------------
# Building Cubic Hermite Splines on the DOS on the truncated energy window
# to facilitate interpolation during training. Cubic Hermite Splines takes
# in information on the value and derivative of a function at a point to build splines.
# Thus, we will have to compute both the value and derivative at each spline position

total_splines_H = []
# the splines have a higher energy range in case the shift is high
spline_positions_H = torch.arange(minE - 2, max(fermi_levels) + 6, 0.05)

# We need to compute the value and derivative of the DOS at each energy value, x


# Functions to compute the value and derivative of the DOS at each energy value, x
def edos_value(x, eigenenergies, normalization):
e_dos_E = (
torch.sum(
Expand All @@ -249,6 +267,10 @@ def edos_derivative(x, eigenenergies, normalization):
return dfn_dos_E


total_splines_H = []
# the splines have a higher energy range in case the shift is high
spline_positions_H = torch.arange(minE - 2, max(fermi_levels) + 6, energy_interval)

for index, structure_eigenenergies_H in enumerate(total_eigenenergies):
e_dos_H = edos_value(
spline_positions_H, structure_eigenenergies_H, normalization[index]
Expand All @@ -261,9 +283,8 @@ def edos_derivative(x, eigenenergies, normalization):

total_splines_H = torch.stack(total_splines_H)


total_splines_Ef = []
spline_positions_Ef = torch.arange(minE_Ef - 2, 6, 0.05)
spline_positions_Ef = torch.arange(minE_Ef - 2, 6, energy_interval)

for index, structure_eigenenergies_Ef in enumerate(total_eigenenergies_Ef):
e_dos_Ef = edos_value(
Expand Down Expand Up @@ -298,6 +319,7 @@ def evaluate_spline(spline_coefs, spline_positions, x):
x = torch.clamp(
x, min=spline_positions[0], max=spline_positions[-1] - 0.0005
) # restrict x to fall within the spline interval
# 0.0005 is substracted to combat errors arising from precision
indexes = torch.floor(
(x - spline_positions[0]) / interval
).long() # Obtain the index for the appropriate spline coefficients
Expand Down Expand Up @@ -375,9 +397,11 @@ def evaluate_spline(spline_coefs, spline_positions, x):

calculator = SoapPowerSpectrum(**HYPER_PARAMETERS)
R_total_soap = calculator.compute(structures)
# Transform the tensormap to a single block containing a dense representation
R_total_soap.keys_to_samples("species_center")
R_total_soap.keys_to_properties(["species_neighbor_1", "species_neighbor_2"])

# Now we extract the data tensor from the single block
total_atom_soap = []
for structure_i in range(n_structures):
a_i = R_total_soap.block(0).samples["structure"] == structure_i
Expand All @@ -390,13 +414,15 @@ def evaluate_spline(spline_coefs, spline_positions, x):
# ---------------------------------------------------------------------
#
# 1) Split the data into Training, Validation and Test
# 2) Build a dataloader
# 2) Define the dataloader and the Model Architecture
# 3) Define relevant loss functions for training and inference
# 4) Define the training loop
# 5) Evaluate the model
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# 1) Split the data into Training, Validation and Test
# ---------------------------------------------------------------------
# We will first split the data in a 7:1:2 manner, corresponding to train, val and test.
np.random.seed(0)
train_index = np.arange(n_structures)
Expand All @@ -408,7 +434,9 @@ def evaluate_spline(spline_coefs, spline_positions, x):
train_index = train_index[:val_mark]

# %%
# We will then build a dataloader to facillitate training the model batchwise
# 2) Define the dataloader and the Model Architecture
# ---------------------------------------------------------------------
# We will now build a dataloader and dataset to facillitate training the model batchwise


def generate_atomstructure_index(n_atoms_per_structure):
Expand Down Expand Up @@ -469,8 +497,7 @@ def __getitem__(self, idx):
def collate(
batch,
): # Defines how to collate the outputs of the __getitem__ function at each batch
for x, idx, index in batch:
return (x, idx, index)
return tuple(batch)


x_train = torch.flatten(total_soap[train_index], 0, 1).float()
Expand Down Expand Up @@ -510,7 +537,6 @@ def forward(self, x):
result = self.fc1(x)
result = self.silu(result)
result = self.fc2(result)
result = torch.exp(result)
return result


Expand Down Expand Up @@ -552,8 +578,10 @@ def forward(self, x):
# The alignment model takes the fermi level energy reference as the starting point

# %%
# We will now define some loss functions that will be useful
# when we implement the model training loop later
# 3) Define relevant loss functions for training and inference
# ---------------------------------------------------------------------
# We will now define some loss functions that will be useful when we implement
# the model training loop later and during model evaluation on the test set.


def t_get_mse(a, b, xdos):
Expand Down Expand Up @@ -671,7 +699,7 @@ def closure():

# %%
#
# Model Training Loop
# 4) Define the training loop
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We will now define the model training loop, for simplicity we will only
Expand Down Expand Up @@ -827,7 +855,7 @@ def train_model(model_to_train, fixed_DOS, structure_splines, spline_positions,

# %%
#
# Evaluation on Test Set
# 5) Evaluate the model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We will now evaluate the model performance on the test set
# based on the model predictions we obtained previously
Expand Down
1 change: 0 additions & 1 deletion examples/dos-align/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ dependencies:
- pip:
- ase
- matplotlib
- metatensor
- rascaline @ git+https://github.com/Luthaf/rascaline@ca957642f512e141c7570e987aadc05c7ac71983
- torch
- scipy
Expand Down

0 comments on commit 713c963

Please sign in to comment.