Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling complex numbers for PairwisePotential #655

Merged
merged 31 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8a5bd40
Handling complex numbers for PairwisePotential
epolack Apr 29, 2022
281e161
comments from Antoine
epolack Apr 29, 2022
42b82c0
assert
epolack Apr 29, 2022
4198c1c
bug
epolack May 2, 2022
1b9a3c7
Update pairwise.jl
epolack May 3, 2022
20c8a17
workarounds for complex exponentiation
epolack May 3, 2022
ca7f803
Update pairwise.jl
epolack May 3, 2022
98b65a0
Update pairwise.jl
epolack May 5, 2022
4185299
Update pairwise.jl
epolack May 5, 2022
45da86c
Update pairwise.jl
epolack May 5, 2022
d4d48e9
testing ph_disp
epolack Jun 7, 2022
ef14519
comment
epolack Jun 7, 2022
9135a16
renaming
epolack Jun 7, 2022
3f343ea
factorisation
epolack Jun 7, 2022
0157e80
Move estimate_integer_bounds to structure.jl and support 1D and 2D sy…
niklasschmitz Jun 7, 2022
b4beeee
Rewrite pairwise without shelll_indices
niklasschmitz Jun 7, 2022
cbe2980
trim whitespace
niklasschmitz Jun 7, 2022
de7405c
Fix pairwise bound comments
niklasschmitz Jun 7, 2022
86141e0
Fix comment
niklasschmitz Jun 7, 2022
aa2bf98
first batch of modifications
epolack Jun 8, 2022
afc7c14
some more
epolack Jun 8, 2022
a558373
Merge branch 'nfs/pairwise-bounds' into complex_pairwise
epolack Jun 8, 2022
f88dd86
workaround back with a vengeance
epolack Jun 8, 2022
1d1d790
bugfix
epolack Jun 8, 2022
610d404
Update forwarddiff_rules.jl
epolack Jun 9, 2022
284a915
Antoine's comment
epolack Jun 13, 2022
82c3bda
Merge remote-tracking branch 'origin/master' into complex_pairwise
epolack Jun 13, 2022
74908cf
bugfix
epolack Jun 13, 2022
0927546
Update phonons.jl
epolack Jun 28, 2022
1c3d395
Update pairwise.jl
epolack Jun 28, 2022
9200787
Merge remote-tracking branch 'origin/master' into complex_pairwise
epolack Jun 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/structure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,15 @@ function diameter(lattice::AbstractMatrix)
end
diam
end

"""
Estimate integer bounds for dense space loops from a given inequality ||Mx|| ≤ δ.
For 1D and 2D systems the limit will be zero in the auxiliary dimensions.
"""
function estimate_integer_lattice_bounds(M, δ, shift=zeros(3))
# As a general statement, with M a lattice matrix, then if ||Mx|| <= δ,
# then xi = <ei, M^-1 Mx> = <M^-T ei, Mx> <= ||M^-T ei|| δ.
inv_lattice_t = compute_inverse_lattice(M')
xlims = [norm(inv_lattice_t[:, i]) * δ + shift[i] for i in 1:3]
ceil.(Int, xlims)
end
16 changes: 4 additions & 12 deletions src/terms/ewald.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,6 @@ function energy_ewald(lattice, charges, positions; η=nothing, forces=nothing)
energy_ewald(lattice, compute_recip_lattice(lattice), charges, positions; η, forces)
end

function estimate_integer_lattice_bounds(M, δ, shift=zeros(3))
# As a general statement, with M a lattice matrix, then if ||Mx|| <= δ,
# then xi = <ei, M^-1 Mx> = <M^-T ei, Mx> <= ||M^-T ei|| δ.
# Below code does not support non-3D systems.
xlims = [norm(inv(M')[:, i]) * δ + shift[i] for i in 1:3]
ceil.(Int, xlims)
end

# This could be factorised with Pairwise, but its use of `atom_types` would slow down this
# computationally intensive Ewald sums. So we leave it as it for now.
function energy_ewald(lattice, recip_lattice, charges, positions; η=nothing, forces=nothing)
Expand All @@ -91,17 +83,17 @@ function energy_ewald(lattice, recip_lattice, charges, positions; η=nothing, fo
max_erfc_arg = sqrt(max_exp_arg) # erfc(x) ~= exp(-x^2)/(sqrt(π)x) for large x

# Precomputing summation bounds from cutoffs.
# In the reciprocal-space term we have exp(-||B G||^2 / 4η^2),
# where B is the reciprocal-space lattice, and
# In the reciprocal-space term we have exp(-||B G||^2 / 4η^2),
# where B is the reciprocal-space lattice, and
# thus use the bound ||B G|| / 2η ≤ sqrt(max_exp_arg)
Glims = estimate_integer_lattice_bounds(recip_lattice, sqrt(max_exp_arg) * 2η)

# In the real-space term we have erfc(η ||A(rj - rk - R)||),
# In the real-space term we have erfc(η ||A(rj - rk - R)||),
# where A is the real-space lattice, rj and rk are atomic positions and
# thus use the bound ||A(rj - rk - R)|| * η ≤ max_erfc_arg
poslims = [maximum(rj[i] - rk[i] for rj in positions for rk in positions) for i in 1:3]
Rlims = estimate_integer_lattice_bounds(lattice, max_erfc_arg / η, poslims)

#
# Reciprocal space sum
#
Expand Down
120 changes: 68 additions & 52 deletions src/terms/pairwise.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
"""
Complex-analytic extension of `LinearAlgebra.norm(x)` from real to complex inputs.
Needed for phonons as we want to perform a matrix-vector product `f'(x)·h`, where `f` is
a real-to-real function and `h` a complex vector. To do this using automatic
differentiation, we can extend analytically f to accept complex inputs, then differentiate
`t -> f(x+t·h)`. This will fail if non-analytic functions like norm are used for complex
inputs, and therefore we have to redefine it.
"""
function norm_cplx(x)
sqrt(sum(x.^2))
end
epolack marked this conversation as resolved.
Show resolved Hide resolved

struct PairwisePotential
V
params
Expand All @@ -10,8 +22,8 @@ Lennard—Jones terms.
The potential is dependent on the distance between to atomic positions and the pairwise
atomic types:
For a distance `d` between to atoms `A` and `B`, the potential is `V(d, params[(A, B)])`.
The parameters `max_radius` is of `100` by default, and gives the maximum (reduced) distance
between nuclei for which we consider interactions.
The parameters `max_radius` is of `100` by default, and gives the maximum distance (in
Cartesian coordinates) between nuclei for which we consider interactions.
"""
function PairwisePotential(V, params; max_radius=100)
params = Dict(minmax(key[1], key[2]) => value for (key, value) in params)
Expand Down Expand Up @@ -54,74 +66,78 @@ function energy_pairwise(model::Model{T}, V, params; kwargs...) where {T}
end


# This could be factorised with Ewald, but the use of `symbols` would slow down the
# computationally intensive Ewald sums. So we leave it as it for now.
"""
This could be factorised with Ewald, but the use of `symbols` would slow down the
computationally intensive Ewald sums. So we leave it as it for now.
epolack marked this conversation as resolved.
Show resolved Hide resolved
`q` is the phonon `q`-point (`Vec3`), and `ph_disp` a list of `Vec3` displacements to
compute the Fourier transform of the force constant matrix. Only the computations of the
forces make sense.
For phonons computations, this gives the forces of particles `ti` in the unit cell w.r.t. to
a displacement of the particles `tj` of the form `ph_disp·e^{i q·R}`.
epolack marked this conversation as resolved.
Show resolved Hide resolved
"""
function energy_pairwise(lattice, symbols, positions, V, params;
max_radius=100, forces=nothing)
T = eltype(lattice)
max_radius=100, forces=nothing, ph_disp=nothing, q=nothing)
isnothing(ph_disp) && @assert isnothing(q)
@assert length(symbols) == length(positions)

T = eltype(positions[1])
if !isnothing(ph_disp)
@assert !isnothing(q) && !isnothing(forces)
T = promote_type(complex(T), eltype(ph_disp[1]))
@assert size(ph_disp) == size(positions)
end

if forces !== nothing
@assert size(forces) == size(positions)
forces_pairwise = copy(forces)
end

# Function to return the indices corresponding
# to a particular shell.
# Not performance critical, so we do not type the function
max_shell(n, trivial) = trivial ? 0 : n
# The potential V(dist) decays very quickly with dist = ||A (rj - rk - R)||,
# so we cut off at some point. We use the bound ||A (rj - rk - R)|| ≤ max_radius
# where A is the real-space lattice, rj and rk are atomic positions.
poslims = [maximum(rj[i] - rk[i] for rj in positions for rk in positions) for i in 1:3]
Rlims = estimate_integer_lattice_bounds(lattice, max_radius, poslims)

# Check if some coordinates are not used.
is_dim_trivial = [norm(lattice[:,i]) == 0 for i=1:3]
function shell_indices(nsh)
ish, jsh, ksh = max_shell.(nsh, is_dim_trivial)
[[i,j,k] for i in -ish:ish for j in -jsh:jsh for k in -ksh:ksh
if maximum(abs.([i,j,k])) == nsh]
end
max_shell(n, trivial) = trivial ? 0 : n
Rlims = max_shell.(Rlims, is_dim_trivial)

#
# Energy loop
#
sum_pairwise::T = zero(T)
# Loop over real-space shells
rsh = 0 # Include R = 0
any_term_contributes = true
while any_term_contributes || rsh <= 1
any_term_contributes = false

# Loop over R vectors for this shell patch
for R in shell_indices(rsh)
for i = 1:length(positions), j = 1:length(positions)
# Avoid self-interaction
rsh == 0 && i == j && continue

ti = positions[i]
tj = positions[j]
ai, aj = minmax(symbols[i], symbols[j])
param_ij =params[(ai, aj)]

Δr = lattice * (ti .- tj .- R)
dist = norm(Δr)

# the potential decays very quickly, so cut off at some point
dist > max_radius && continue

any_term_contributes = true
energy_contribution = V(dist, param_ij)
sum_pairwise += energy_contribution
if forces !== nothing
# We use ForwardDiff for the forces
dE_ddist = ForwardDiff.derivative(d -> V(d, param_ij), dist)
dE_dti = lattice' * dE_ddist / dist * Δr
forces_pairwise[i] -= dE_dti
forces_pairwise[j] += dE_dti
# Loop over real-space
for R1 in -Rlims[1]:Rlims[1], R2 in -Rlims[2]:Rlims[2], R3 in -Rlims[3]:Rlims[3]
R = Vec3(R1, R2, R3)
for i = 1:length(positions), j = 1:length(positions)
# Avoid self-interaction
R == zero(R) && i == j && continue
ai, aj = minmax(symbols[i], symbols[j])
param_ij = params[(ai, aj)]
ti = positions[i]
tj = positions[j] + R
if !isnothing(ph_disp)
ti += ph_disp[i] # * cis2pi(dot(q, zeros(3))) === 1
# as we use the forces at the nuclei in the unit cell
tj += ph_disp[j] * cis2pi(dot(q, R))
end
Δr = lattice * (ti .- tj)
dist = norm_cplx(Δr)
energy_contribution = V(dist, param_ij)
sum_pairwise += energy_contribution
if forces !== nothing
dE_ddist = ForwardDiff.derivative(real(zero(eltype(dist)))) do ε
V(dist + ε, param_ij)
end
end # i,j
end # R
rsh += 1
end
dE_dti = lattice' * dE_ddist / dist * Δr
forces_pairwise[i] -= dE_dti
end
end # i,j
end # R
energy = sum_pairwise / 2 # Divide by 2 (because of double counting)
if forces !== nothing
forces .= forces_pairwise ./ 2
forces .= forces_pairwise
epolack marked this conversation as resolved.
Show resolved Hide resolved
end
energy
end
41 changes: 41 additions & 0 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,44 @@ function Smearing.occupation(S::Smearing.FermiDirac, d::ForwardDiff.Dual{T}) whe
end
ForwardDiff.Dual{T}(Smearing.occupation(S, x), ∂occ * ForwardDiff.partials(d))
end

# Fix for https://github.com/JuliaDiff/ForwardDiff.jl/issues/514
function Base.:^(x::Complex{ForwardDiff.Dual{T,V,N}}, y::Complex{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
xx = complex(ForwardDiff.value(real(x)), ForwardDiff.value(imag(x)))
yy = complex(ForwardDiff.value(real(y)), ForwardDiff.value(imag(y)))
dx = complex.(ForwardDiff.partials(real(x)), ForwardDiff.partials(imag(x)))
dy = complex.(ForwardDiff.partials(real(y)), ForwardDiff.partials(imag(y)))

expv = xx^yy
∂expv∂x = yy * xx^(yy-1)
∂expv∂y = log(xx) * expv
dxexpv = ∂expv∂x * dx
if iszero(xx) && isconstant(real(y)) && isconstant(imag(y)) && imag(y) === zero(imag(y)) && real(y) > 0
dexpv = zero(expv)
elseif iszero(xx)
throw(DomainError(x, "mantissa cannot be zero for complex exponentiation"))
else
dyexpv = ∂expv∂y * dy
dexpv = dxexpv + dyexpv
end
complex(ForwardDiff.Dual{T,V,N}(real(expv), ForwardDiff.Partials{N,V}(tuple(real(dexpv)...))),
ForwardDiff.Dual{T,V,N}(imag(expv), ForwardDiff.Partials{N,V}(tuple(imag(dexpv)...))))
end
function Base.:^(x::Complex{ForwardDiff.Dual{T,V,N}}, y::Int64) where {T,V,N}
xx = complex(ForwardDiff.value(real(x)), ForwardDiff.value(imag(x)))
dx = complex.(ForwardDiff.partials(real(x)), ForwardDiff.partials(imag(x)))

expv = xx^y
∂expv∂x = y * xx^(y-1)
dxexpv = ∂expv∂x * dx
if iszero(xx) && imag(y) === zero(imag(y)) && real(y) > 0
dexpv = zero(expv)
elseif iszero(xx)
throw(DomainError(x, "mantissa cannot be zero for complex exponentiation"))
else
dexpv = dxexpv
end
complex(ForwardDiff.Dual{T,V,N}(real(expv), ForwardDiff.Partials{N,V}(tuple(real(dexpv)...))),
ForwardDiff.Dual{T,V,N}(imag(expv), ForwardDiff.Partials{N,V}(tuple(imag(dexpv)...))))
end

113 changes: 113 additions & 0 deletions test/phonons.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
using Test
epolack marked this conversation as resolved.
Show resolved Hide resolved
using DFTK
using LinearAlgebra
using ForwardDiff
using StaticArrays

# Convert back and forth between Vec3 and columnwise matrix
fold(x) = hcat(x...)
unfold(x) = Vec3.(eachcol(x))

function prepare_system(; n_scell=1)
positions = [[0.,0,0]]
for i in 1:n_scell-1
push!(positions, i*ones(3)/n_scell)
end

a = 5. * length(positions)
lattice = a * [[1 0 0.]; [0 0 0.]; [0 0 0.]]

s = DFTK.compute_inverse_lattice(lattice)
directions = [[s * [i==j,0,0] for i in 1:n_scell] for j in 1:n_scell]

params = Dict((:X, :X) => (; ε=1, σ=a / length(positions) /2^(1/6)))
V(x, p) = 4*p.ε * ((p.σ/x)^12 - (p.σ/x)^6)

(positions=positions, lattice=lattice, directions=directions, params=params, V=V)
end

# Compute phonons for a one-dimensional pairwise potential for a set of `q = 0` using
# supercell method
epolack marked this conversation as resolved.
Show resolved Hide resolved
function test_supercell_q0(; n_scell=1, max_radius=1e3)
blob = prepare_system(; n_scell)
positions = blob.positions
lattice = blob.lattice
directions = blob.directions
params = blob.params
V = blob.V

s = DFTK.compute_inverse_lattice(lattice)
n_atoms = length(positions)

directions = [reshape(vcat([[i==j, 0.0, 0.0] for i in 1:n_atoms]...), 3, :) for j in 1:n_atoms]

Φ = Array{eltype(positions[1])}(undef, length(directions), n_atoms)
for (i, direction) in enumerate(directions)
Φ[i, :] = - ForwardDiff.derivative(0.0) do ε
new_positions = unfold(fold(positions) .+ ε .* s * direction)
forces = zeros(Vec3{complex(eltype(ε))}, length(positions))
DFTK.energy_pairwise(lattice, [:X for _ in positions], new_positions, V, params;
forces, max_radius)
[(s * f)[1] for f in forces]
end
end
sqrt.(abs.(eigvals(Φ)))
end

# Compute phonons for a one-dimensional pairwise potential for a set of `q`-points
function test_ph_disp(; n_scell=1, max_radius=1e3, n_points=2)
blob = prepare_system(; n_scell)
positions = blob.positions
lattice = blob.lattice
directions = blob.directions
params = blob.params
V = blob.V

pairwise_ph = (q, d; forces=nothing) ->
DFTK.energy_pairwise(lattice, [:X for _ in positions],
positions, V, params; q=[q, 0, 0],
ph_disp=d, forces, max_radius)

ph_bands = []
qs = -1/2:1/n_points:1/2
for q in qs
as = ComplexF64[]
for d in directions
res = -ForwardDiff.derivative(0.0) do ε
forces = zeros(Vec3{complex(eltype(ε))}, length(positions))
pairwise_ph(q, ε*d; forces)
[DFTK.compute_inverse_lattice(lattice)' * f for f in forces]
end
[push!(as, r[1]) for r in res]
end
M = reshape(as, length(positions), :)
@assert ≈(norm(imag.(eigvals(M))), 0.0, atol=1e-8)
push!(ph_bands, sqrt.(abs.(real(eigvals(M)))))
end
return ph_bands
end

@testset "Phonon consistency" begin
max_radius = 1e3
tolerance = 1e-6
n_points = 10

ph_bands = []
for n_scell in [1, 2, 3]
push!(ph_bands, test_ph_disp(; n_scell, max_radius, n_points))
end

# Recover the same extremum for the system whatever case we test
for n_scell in [2, 3]
@test ≈(minimum(fold(ph_bands[1])), minimum(fold(ph_bands[n_scell])), atol=tolerance)
@test ≈(maximum(fold(ph_bands[1])), maximum(fold(ph_bands[n_scell])), atol=tolerance)
end

# Test consistency between supercell method at `q = 0` and direct `q`-points computations
for n_scell in [1, 2, 3]
r_q0 = test_supercell_q0(; n_scell, max_radius)
@assert length(r_q0) == n_scell
ph_band_q0 = ph_bands[n_scell][n_points÷2+1]
@test norm(r_q0 - ph_band_q0) < tolerance
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,10 @@ Random.seed!(0)
include("forwarddiff.jl")
end

# Phonons
if "all" in TAGS
include("phonons.jl")
end

("example" in TAGS) && include("runexamples.jl")
end