Skip to content

Commit

Permalink
Implement final version of the physical basis (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Dec 12, 2023
1 parent 7e14342 commit aa4b789
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 263 deletions.
4 changes: 3 additions & 1 deletion examples/alchemical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch_spex.atomic_composition import AtomicComposition
from power_spectrum import PowerSpectrum
from torch_spex.normalize import get_average_number_of_neighbors, normalize_true, normalize_false
from torch_spex.normalize import get_2_mom

from typing import Dict
from metatensor.torch import TensorMap
Expand Down Expand Up @@ -85,7 +86,7 @@ def get_sse(first, second):
"mlp": True,
"type": "physical",
"scale": 3.0,
"E_max": 500,
"E_max": 350,
"normalize": True,
"cost_trade_off": False
}
Expand All @@ -112,6 +113,7 @@ def __init__(self, hypers, all_species, do_forces) -> None:
self.all_species = all_species
self.spherical_expansion_calculator = SphericalExpansion(hypers, all_species, device=device)
n_max = self.spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.n_max_l
print("Radial basis:", n_max)
l_max = len(n_max) - 1
n_feat = sum([n_max[l]**2 * n_pseudo**2 for l in range(l_max+1)])
self.ps_calculator = PowerSpectrum(l_max, all_species)
Expand Down
1 change: 1 addition & 0 deletions torch_spex/physical_LE/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .physical_LE import get_physical_le_spliner
Binary file added torch_spex/physical_LE/eigenvalues.npy
Binary file not shown.
Binary file added torch_spex/physical_LE/eigenvectors.npy
Binary file not shown.
124 changes: 124 additions & 0 deletions torch_spex/physical_LE/physical_LE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import numpy as np
import os
import copy

from ..splines import generate_splines


# All these periodic functions are zeroed for the (unlikely) case where r > 10*r_0
# which is outside the domain where the eigenvalue equation was solved

def s(n, x):
return np.sin(np.pi*(n+1.0)*x/10.0)

def ds(n, x):
return np.pi*(n+1.0)*np.cos(np.pi*(n+1.0)*x/10.0)/10.0

def c(n, x):
return np.cos(np.pi*(n+0.5)*x/10.0)

def dc(n, x):
return -np.pi*(n+0.5)*np.sin(np.pi*(n+0.5)*x/10.0)/10.0


def get_physical_le_spliner(E_max, r_cut, normalize, device, dtype):

l_max = 50
n_max = 50
n_max_big = 200

a = 10.0 # by construction of the files

dir_path = os.path.dirname(os.path.realpath(__file__))

E_ln = np.load(
os.path.join(
dir_path,
"eigenvalues.npy"
)
)
eigenvectors = np.load(
os.path.join(
dir_path,
"eigenvectors.npy"
)
)

E_nl = E_ln.T
l_max_new = np.where(E_nl[0, :] <= E_max)[0][-1]
if l_max_new > l_max:
raise ValueError("l_max too large, try decreasing E_max")
else:
l_max = l_max_new

n_max_l = []
for l in range(l_max+1):
n_max_l.append(np.where(E_nl[:, l] <= E_max)[0][-1] + 1)
if n_max_l[0] > n_max:
raise ValueError("n_max too large, try decreasing E_max")

def function_for_splining(n, l, x):
ret = np.zeros_like(x)
for m in range(n_max_big):
ret += (eigenvectors[l][m, n]*c(m, x) if l%2 == 0 else eigenvectors[l][m, n]*s(m, x))
if normalize:
# normalize by square root of sphere volume, excluding sqrt(4pi) which is included in the SH
ret *= (
np.sqrt( (1/3)*r_cut**3 ) # formally correct value
* 0.45 # hardcoded empirical correction factor, TODO: automate
)
return ret

def function_for_splining_derivative(n, l, x):
ret = np.zeros_like(x)
for m in range(n_max_big):
ret += (eigenvectors[l][m, n]*dc(m, x) if l%2 == 0 else eigenvectors[l][m, n]*ds(m, x))
if normalize:
# normalize by square root of sphere volume, excluding sqrt(4pi) which is included in the SH
ret *= (
np.sqrt( (1/3)*r_cut**3 ) # formally correct value
* 0.45 # hardcoded empirical correction factor, TODO: automate
)
return ret

"""
import matplotlib.pyplot as plt
r = np.linspace(0.01, a-0.001, 1000)
l = 0
for n in range(n_max_l[l]):
plt.plot(r, function_for_splining(n, l, r), label=str(n))
plt.plot([0.0, a], [0.0, 0.0], "black")
plt.xlim(0.0, a)
plt.legend()
plt.savefig("radial-real.pdf")
"""

def index_to_nl(index, n_max_l):
# FIXME: should probably use cumsum
n = copy.deepcopy(index)
for l in range(l_max+1):
n -= n_max_l[l]
if n < 0: break
return n + n_max_l[l], l

def function_for_splining_index(index, r):
n, l = index_to_nl(index, n_max_l)
return function_for_splining(n, l, r)

def function_for_splining_index_derivative(index, r):
n, l = index_to_nl(index, n_max_l)
return function_for_splining_derivative(n, l, r)

spliner = generate_splines(
function_for_splining_index,
function_for_splining_index_derivative,
np.sum(n_max_l),
a,
requested_accuracy=1e-6,
dtype=dtype,
device=device
)
print("Number of spline points:", len(spliner.spline_positions))

n_max_l = [int(n_max) for n_max in n_max_l]
return n_max_l, spliner
Loading

0 comments on commit aa4b789

Please sign in to comment.