Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling complex numbers for PairwisePotential #655

Merged
merged 31 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8a5bd40
Handling complex numbers for PairwisePotential
epolack Apr 29, 2022
281e161
comments from Antoine
epolack Apr 29, 2022
42b82c0
assert
epolack Apr 29, 2022
4198c1c
bug
epolack May 2, 2022
1b9a3c7
Update pairwise.jl
epolack May 3, 2022
20c8a17
workarounds for complex exponentiation
epolack May 3, 2022
ca7f803
Update pairwise.jl
epolack May 3, 2022
98b65a0
Update pairwise.jl
epolack May 5, 2022
4185299
Update pairwise.jl
epolack May 5, 2022
45da86c
Update pairwise.jl
epolack May 5, 2022
d4d48e9
testing ph_disp
epolack Jun 7, 2022
ef14519
comment
epolack Jun 7, 2022
9135a16
renaming
epolack Jun 7, 2022
3f343ea
factorisation
epolack Jun 7, 2022
0157e80
Move estimate_integer_bounds to structure.jl and support 1D and 2D sy…
niklasschmitz Jun 7, 2022
b4beeee
Rewrite pairwise without shelll_indices
niklasschmitz Jun 7, 2022
cbe2980
trim whitespace
niklasschmitz Jun 7, 2022
de7405c
Fix pairwise bound comments
niklasschmitz Jun 7, 2022
86141e0
Fix comment
niklasschmitz Jun 7, 2022
aa2bf98
first batch of modifications
epolack Jun 8, 2022
afc7c14
some more
epolack Jun 8, 2022
a558373
Merge branch 'nfs/pairwise-bounds' into complex_pairwise
epolack Jun 8, 2022
f88dd86
workaround back with a vengeance
epolack Jun 8, 2022
1d1d790
bugfix
epolack Jun 8, 2022
610d404
Update forwarddiff_rules.jl
epolack Jun 9, 2022
284a915
Antoine's comment
epolack Jun 13, 2022
82c3bda
Merge remote-tracking branch 'origin/master' into complex_pairwise
epolack Jun 13, 2022
74908cf
bugfix
epolack Jun 13, 2022
0927546
Update phonons.jl
epolack Jun 28, 2022
1c3d395
Update pairwise.jl
epolack Jun 28, 2022
9200787
Merge remote-tracking branch 'origin/master' into complex_pairwise
epolack Jun 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions src/terms/pairwise.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# We cannot use `LinearAlgebra.norm` with complex numbers due to the need to use its
# analytic continuation
epolack marked this conversation as resolved.
Show resolved Hide resolved
function norm_cplx(x)
sqrt(sum(x.^2))
end
epolack marked this conversation as resolved.
Show resolved Hide resolved

struct PairwisePotential
V
params
Expand All @@ -10,8 +16,8 @@ Lennard—Jones terms.
The potential is dependent on the distance between to atomic positions and the pairwise
atomic types:
For a distance `d` between to atoms `A` and `B`, the potential is `V(d, params[(A, B)])`.
The parameters `max_radius` is of `100` by default, and gives the maximum (reduced) distance
between nuclei for which we consider interactions.
The parameters `max_radius` is of `100` by default, and gives the maximum distance (in
Cartesian coordinates) between nuclei for which we consider interactions.
"""
function PairwisePotential(V, params; max_radius=100)
params = Dict(minmax(key[1], key[2]) => value for (key, value) in params)
Expand All @@ -37,7 +43,8 @@ end
basis::PlaneWaveBasis{T}, ψ, occ;
kwargs...) where {T}
forces = zero(basis.model.positions)
energy_pairwise(basis.model, term.V, term.params; term.max_radius, forces)
energy_pairwise(basis.model, term.V, term.params; max_radius=term.max_radius,
epolack marked this conversation as resolved.
Show resolved Hide resolved
forces=forces)
forces
end

Expand All @@ -56,11 +63,19 @@ end

# This could be factorised with Ewald, but the use of `symbols` would slow down the
# computationally intensive Ewald sums. So we leave it as it for now.
# `q` is the phonon `q`-point (`Vec3`), and `ph_disp` a list of `Vec3` displacements to
epolack marked this conversation as resolved.
Show resolved Hide resolved
# compute the Fourier transform of the force constant matrix.
function energy_pairwise(lattice, symbols, positions, V, params;
max_radius=100, forces=nothing)
T = eltype(lattice)
max_radius=100, forces=nothing, ph_disp=nothing, q=nothing)
@assert length(symbols) == length(positions)

T = eltype(positions[1])
if ph_disp !== nothing
@assert q !== nothing
epolack marked this conversation as resolved.
Show resolved Hide resolved
T = promote_type(complex(T), eltype(ph_disp[1]))
@assert size(ph_disp) == size(positions)
end

if forces !== nothing
@assert size(forces) == size(positions)
forces_pairwise = copy(forces)
Expand Down Expand Up @@ -95,33 +110,40 @@ function energy_pairwise(lattice, symbols, positions, V, params;
rsh == 0 && i == j && continue

ti = positions[i]
tj = positions[j]
tj = positions[j] + R
# Phonons `q` points
epolack marked this conversation as resolved.
Show resolved Hide resolved
if !isnothing(ph_disp)
ti += ph_disp[i] # * cis(2T(π)*dot(q, zeros(3))) ≡ 1
epolack marked this conversation as resolved.
Show resolved Hide resolved
# as we use the forces at the nuclei in the unit cell
tj += ph_disp[j] * cis(2T(π)*dot(q, R))
end
ai, aj = minmax(symbols[i], symbols[j])
param_ij =params[(ai, aj)]

Δr = lattice * (ti .- tj .- R)
dist = norm(Δr)
Δr = lattice * (ti .- tj)
dist = norm_cplx(Δr)

# the potential decays very quickly, so cut off at some point
dist > max_radius && continue
abs(dist) > max_radius && continue

any_term_contributes = true
energy_contribution = V(dist, param_ij)
sum_pairwise += energy_contribution
if forces !== nothing
# We use ForwardDiff for the forces
dE_ddist = ForwardDiff.derivative(d -> V(d, param_ij), dist)
dE_dti = lattice' * dE_ddist / dist * Δr
dE_ddist = ForwardDiff.derivative(real(zero(eltype(dist)))) do ε
epolack marked this conversation as resolved.
Show resolved Hide resolved
V(dist + ε, param_ij)
end
dE_dti = lattice' * ((dE_ddist / dist) * Δr)
# For the phonons, we compute the forces only in the unit cell.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also quite obscure...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

forces_pairwise[i] -= dE_dti
forces_pairwise[j] += dE_dti
end
end # i,j
end # R
rsh += 1
end
energy = sum_pairwise / 2 # Divide by 2 (because of double counting)
if forces !== nothing
forces .= forces_pairwise ./ 2
forces .= forces_pairwise
epolack marked this conversation as resolved.
Show resolved Hide resolved
end
energy
end
29 changes: 29 additions & 0 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,32 @@ function Smearing.occupation(S::Smearing.FermiDirac, d::ForwardDiff.Dual{T}) whe
end
ForwardDiff.Dual{T}(Smearing.occupation(S, x), ∂occ * ForwardDiff.partials(d))
end

# Workarounds for issue https://github.com/JuliaDiff/ForwardDiff.jl/issues/324
ForwardDiff.derivative(f, x::Complex) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives). Separate real and imaginary parts of the input."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These fixes are not for the issue above. Is the fix not already upstream in a released version? If yes, only define the functions if the version is below the one that has the fix (and then at some point we remove the code and depend on a new version)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This JuliaDiff/ForwardDiff.jl#577 seems to fix the exponentiation problem. So maybe remove it altogether?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it fix JuliaDiff/ForwardDiff.jl#514 (comment)? Since my message was after that PR was merge I don't think so?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I thought is was after. This is strange, something seems to have fixed the issue, because I am not running into it anymore, even for this test:

using ForwardDiff, FiniteDifferences, Random, Test

v = randn()
p, m = randn(ComplexF64), randn(ComplexF64)

for f in (x -> (x*m)^p,
          x -> m^(p*x),
          x -> (x*m)^(p*x))
  @test ≈(ForwardDiff.derivative(f, v), central_fdm(5,1)(f, v), atol=1e-8)
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the code in my comment in the forwarddiff tracker above?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the issue was very specific: when differentiating at a real number in a complex direction.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, didn't test the right thing 🙄.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, the bug was due to something we talked about earlier if the types of x and y differ. Let's keep it that way.

@inline ForwardDiff.extract_derivative(::Type{T}, y::Complex) where {T} = zero(y)
@inline function ForwardDiff.extract_derivative(::Type{T}, y::Complex{TD}) where {T, TD <: ForwardDiff.Dual}
complex(ForwardDiff.partials(T, real(y), 1), ForwardDiff.partials(T, imag(y), 1))
end
function Base.:^(x::Complex{ForwardDiff.Dual{T,V,N}}, y::Complex{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
xx = complex(ForwardDiff.value(real(x)), ForwardDiff.value(imag(x)))
yy = complex(ForwardDiff.value(real(y)), ForwardDiff.value(imag(y)))
dx = complex.(ForwardDiff.partials(real(x)), ForwardDiff.partials(imag(x)))
dy = complex.(ForwardDiff.partials(real(y)), ForwardDiff.partials(imag(y)))

expv = xx^yy
∂expv∂x = yy * xx^(yy-1)
∂expv∂y = log(xx) * expv
dxexpv = ∂expv∂x * dx
# TODO: Fishy and should be checked, but seems to catch most cases
if iszero(xx) && ForwardDiff.isconstant(real(y)) && ForwardDiff.isconstant(imag(y)) && imag(y) === zero(imag(y)) && real(y) > 0
dexpv = zero(expv)
elseif iszero(xx)
throw(DomainError(x, "mantissa cannot be zero for complex exponentiation"))
else
dyexpv = ∂expv∂y * dy
dexpv = dxexpv + dyexpv
end
complex(ForwardDiff.Dual{T,V,N}(real(expv), ForwardDiff.Partials{N,V}(tuple(real(dexpv)...))),
ForwardDiff.Dual{T,V,N}(imag(expv), ForwardDiff.Partials{N,V}(tuple(imag(dexpv)...))))
end
124 changes: 124 additions & 0 deletions test/phonons.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
using Test
epolack marked this conversation as resolved.
Show resolved Hide resolved
using DFTK
using LinearAlgebra
using ForwardDiff
using StaticArrays

# ## Helper functions
epolack marked this conversation as resolved.
Show resolved Hide resolved
# Some functions that will be helpful for this example.
fold(x) = hcat(x...)
unfold(x) = Vec3.(eachcol(x))

const MAX_RADIUS = 1e3
epolack marked this conversation as resolved.
Show resolved Hide resolved
const TOLERANCE = 1e-6
const N_POINTS = 10

function prepare_system(; case=1)
epolack marked this conversation as resolved.
Show resolved Hide resolved
positions = [[0.,0,0]]
for i in 1:case-1
push!(positions, i*ones(3)/case)
end

a = 5. * length(positions)
lattice = a * [[1 0 0.]; [0 0 0.]; [0 0 0.]]

s = DFTK.compute_inverse_lattice(lattice)
if case === 1
epolack marked this conversation as resolved.
Show resolved Hide resolved
directions = [[s * [1,0,0]]]
elseif case === 2
directions = [[s * [1,0,0], s * [0,0,0]],
[s * [0,0,0], s * [1,0,0]]]
elseif case === 3
directions = [[s * [1,0,0], s * [0,0,0], s * [0,0,0]],
[s * [0,0,0], s * [1,0,0], s * [0,0,0]],
[s * [0,0,0], s * [0,0,0], s * [1,0,0]]]
end

params = Dict((:X, :X) => (; ε=1, σ=a / length(positions) /2^(1/6)))
V(x, p) = 4*p.ε * ((p.σ/x)^12 - (p.σ/x)^6)

(positions=positions, lattice=lattice, directions=directions, params=params, V=V)
end

# Compute phonons for a one-dimensional pairwise potential for a set of `q = 0` using
# supercell method
epolack marked this conversation as resolved.
Show resolved Hide resolved
function test_supercell_q0(; N_scell=1)
blob = prepare_system(; case=N_scell)
positions = blob.positions
lattice = blob.lattice
directions = blob.directions
params = blob.params
V = blob.V

s = DFTK.compute_inverse_lattice(lattice)
n_atoms = length(positions)

directions = [reshape(vcat([[i==j, 0.0, 0.0] for i in 1:n_atoms]...), 3, :) for j in 1:n_atoms]

Φ = Array{eltype(positions[1])}(undef, length(directions), n_atoms)
for (i, direction) in enumerate(directions)
Φ[i, :] = - ForwardDiff.derivative(0.0) do ε
new_positions = unfold(fold(positions) .+ ε .* s * direction)
forces = zeros(Vec3{complex(eltype(ε))}, length(positions))
DFTK.energy_pairwise(lattice, [:X for _ in positions], new_positions, V, params;
forces=forces, max_radius=MAX_RADIUS)
[(s * f)[1] for f in forces]
end
end
sqrt.(abs.(eigvals(Φ)))
end

# Compute phonons for a one-dimensional pairwise potential for a set of `q`-points
function test_ph_disp(; case=1)
blob = prepare_system(; case=case)
positions = blob.positions
lattice = blob.lattice
directions = blob.directions
params = blob.params
V = blob.V

pairwise_ph = (q, d; forces=nothing) ->
DFTK.energy_pairwise(lattice, [:X for _ in positions],
positions, V, params; q=[q, 0, 0],
ph_disp=d, forces=forces,
max_radius=MAX_RADIUS)

ph_bands = []
qs = -1/2:1/N_POINTS:1/2
for q in qs
as = ComplexF64[]
for d in directions
res = -ForwardDiff.derivative(0.0) do ε
forces = zeros(Vec3{complex(eltype(ε))}, length(positions))
pairwise_ph(q, ε*d; forces=forces)
[DFTK.compute_inverse_lattice(lattice)' * f for f in forces]
end
[push!(as, r[1]) for r in res]
end
M = reshape(as, length(positions), :)
@assert ≈(norm(imag.(eigvals(M))), 0.0, atol=1e-8)
push!(ph_bands, sqrt.(abs.(real(eigvals(M)))))
end
return ph_bands
end

@testset "Phonon consistency" begin
ph_bands = []
for case in [1, 2, 3]
push!(ph_bands, test_ph_disp(; case=case))
end

# Recover the same extremum for the system whatever case we test
for case in [2, 3]
@test ≈(minimum(fold(ph_bands[1])), minimum(fold(ph_bands[case])), atol=TOLERANCE)
@test ≈(maximum(fold(ph_bands[1])), maximum(fold(ph_bands[case])), atol=TOLERANCE)
end

# Test consistency between supercell method at `q = 0` and direct `q`-points computations
for N_scell in [1, 2, 3]
r_q0 = test_supercell_q0(; N_scell=N_scell)
@assert length(r_q0) == N_scell
ph_band_q0 = ph_bands[N_scell][N_POINTS÷2+1]
@test norm(r_q0 - ph_band_q0) < TOLERANCE
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,10 @@ Random.seed!(0)
include("forwarddiff.jl")
end

# Phonons
if "all" in TAGS
include("phonons.jl")
end

("example" in TAGS) && include("runexamples.jl")
end