diff --git a/src/terms/pairwise.jl b/src/terms/pairwise.jl index e0f28f1d8..77e05c35f 100644 --- a/src/terms/pairwise.jl +++ b/src/terms/pairwise.jl @@ -1,8 +1,7 @@ # We cannot use `LinearAlgebra.norm` with complex numbers due to the need to use its # analytic continuation function norm_cplx(x) - # TODO: ForwardDiff bug (https://github.com/JuliaDiff/ForwardDiff.jl/issues/324) - sqrt(sum(x.*x)) + sqrt(sum(x.^2)) end struct PairwisePotential @@ -17,10 +16,10 @@ Lennard—Jones terms. The potential is dependent on the distance between to atomic positions and the pairwise atomic types: For a distance `d` between to atoms `A` and `B`, the potential is `V(d, params[(A, B)])`. -The parameters `max_radius` is of `1000` by default, and gives the maximum (Cartesian) +The parameters `max_radius` is of `100` by default, and gives the maximum (Cartesian) distance between nuclei for which we consider interactions. """ -function PairwisePotential(V, params; max_radius=1000) +function PairwisePotential(V, params; max_radius=100) params = Dict(minmax(key[1], key[2]) => value for (key, value) in params) PairwisePotential(V, params, max_radius) end @@ -43,8 +42,7 @@ end @timing "forces: Pairwise" function compute_forces(term::TermPairwisePotential, basis::PlaneWaveBasis{T}, ψ, occ; kwargs...) where {T} - TT = promote_type(T, eltype(basis.model.positions[1])) - forces = zero(TT, basis.model.positions) + forces = zero(basis.model.positions) energy_pairwise(basis.model, term.V, term.params; max_radius=term.max_radius, forces=forces, kwargs...) forces @@ -65,15 +63,13 @@ end # This could be factorised with Ewald, but the use of `symbols` would slow down the # computationally intensive Ewald sums. So we leave it as it for now. -# TODO: *Beware* of using ForwardDiff to derive this function with complex numbers, use -# multiplications and not powers (https://github.com/JuliaDiff/ForwardDiff.jl/issues/324). # `q` is the phonon `q`-point (`Vec3`), and `ph_disp` a list of `Vec3` displacements to # compute the Fourier transform of the force constant matrix. function energy_pairwise(lattice, symbols, positions, V, params; - max_radius=1000, forces=nothing, ph_disp=nothing, q=nothing) + max_radius=100, forces=nothing, ph_disp=nothing, q=nothing) @assert length(symbols) == length(positions) - T = eltype(lattice) + T = eltype(positions[1]) if ph_disp !== nothing @assert q !== nothing T = promote_type(complex(T), eltype(ph_disp[1])) @@ -135,9 +131,8 @@ function energy_pairwise(lattice, symbols, positions, V, params; sum_pairwise += energy_contribution if forces !== nothing dE_ddist = ForwardDiff.derivative(real(zero(eltype(dist)))) do ε - res = V(dist + ε, param_ij) - [real(res), imag(res)] - end |> x -> complex(x...) + V(dist + ε, param_ij) + end dE_dti = lattice' * ((dE_ddist / dist) * Δr) # We need to "break" the symmetry for phonons; at equilibrium, expect # the forces to be zero at machine precision. diff --git a/src/workarounds/forwarddiff_rules.jl b/src/workarounds/forwarddiff_rules.jl index 0a1ec21c4..0acfe74c1 100644 --- a/src/workarounds/forwarddiff_rules.jl +++ b/src/workarounds/forwarddiff_rules.jl @@ -230,3 +230,32 @@ function Smearing.occupation(S::Smearing.FermiDirac, d::ForwardDiff.Dual{T}) whe end ForwardDiff.Dual{T}(Smearing.occupation(S, x), ∂occ * ForwardDiff.partials(d)) end + +# Workarounds for issue https://github.com/JuliaDiff/ForwardDiff.jl/issues/324 +ForwardDiff.derivative(f, x::Complex) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives). Separate real and imaginary parts of the input.")) +@inline ForwardDiff.extract_derivative(::Type{T}, y::Complex) where {T} = zero(y) +@inline function ForwardDiff.extract_derivative(::Type{T}, y::Complex{TD}) where {T, TD <: ForwardDiff.Dual} + complex(ForwardDiff.partials(T, real(y), 1), ForwardDiff.partials(T, imag(y), 1)) +end +function Base.:^(x::Complex{ForwardDiff.Dual{T,V,N}}, y::Complex{ForwardDiff.Dual{T,V,N}}) where {T,V,N} + xx = complex(ForwardDiff.value(real(x)), ForwardDiff.value(imag(x))) + yy = complex(ForwardDiff.value(real(y)), ForwardDiff.value(imag(y))) + dx = complex.(ForwardDiff.partials(real(x)), ForwardDiff.partials(imag(x))) + dy = complex.(ForwardDiff.partials(real(y)), ForwardDiff.partials(imag(y))) + + expv = xx^yy + ∂expv∂x = yy * xx^(yy-1) + ∂expv∂y = log(xx) * expv + dxexpv = ∂expv∂x * dx + # TODO: Fishy and should be checked, but seems to catch most cases + if iszero(xx) && ForwardDiff.isconstant(real(y)) && ForwardDiff.isconstant(imag(y)) && imag(y) === zero(imag(y)) && real(y) > 0 + dexpv = zero(expv) + elseif iszero(xx) + throw(DomainError(x, "mantissa cannot be zero for complex exponentiation")) + else + dyexpv = ∂expv∂y * dy + dexpv = dxexpv + dyexpv + end + complex(ForwardDiff.Dual{T,V,N}(real(expv), ForwardDiff.Partials{N,V}(tuple(real(dexpv)...))), + ForwardDiff.Dual{T,V,N}(imag(expv), ForwardDiff.Partials{N,V}(tuple(imag(dexpv)...)))) +end