Skip to content

Commit

Permalink
feat: support NLLS forward AD
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 7, 2024
1 parent 7f15b30 commit c91b973
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 8 deletions.
104 changes: 97 additions & 7 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
module NonlinearSolveBaseForwardDiffExt

using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using CommonSolve: solve
using DifferentiationInterface: DifferentiationInterface, Constant
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: mul!
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem,
NonlinearLeastSquaresProblem, remake
NonlinearProblem, NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils

const DI = DifferentiationInterface

function NonlinearSolveBase.additional_incompatible_backend_check(
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
return !ForwardDiff.can_dual(eltype(prob.u0))
Expand Down Expand Up @@ -50,22 +54,108 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
return sol, partials
end

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
p = Utils.value(prob.p)
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
sol = solve(newprob, alg, args...; kwargs...)
uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
J = Utils.safe_similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
# For small problems, nesting ForwardDiff is actually quite fast
autodiff = length(uu) + length(sol.resid) 50 ?
NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) :
AutoForwardDiff()

if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
# Using `Constant` lead to dual ordering issues
ff = @closure (du, u) -> prob.f(du, u, p)
resid2 = copy(resid)
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
@. du *= 2
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
v = prob.f(u, p)
# Using `Constant` lead to dual ordering issues
ff = Base.Fix2(prob.f, p)
res = only(DI.pullback(ff, autodiff, u, (v,)))
ArrayInterface.can_setindex(res) || return 2 .* res
@. res *= 2
return res
end
end
end

Jₚ = nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
Jᵤ = nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)

if uu isa Number
partials = sum(sumfun, zip(z, pp))
elseif p isa Number
partials = sumfun((z, pp))
else
partials = sum(sumfun, zip(eachcol(z), pp))
end

return sol, partials
end

function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
if SciMLBase.isinplace(prob)
f = @closure p -> begin
f2 = @closure p -> begin
du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
f(du, u, p)
return du
end
else
f = Base.Fix1(f, u)
f2 = Base.Fix1(f, u)
end
if p isa Number
return Utils.safe_reshape(ForwardDiff.derivative(f, p), :, 1)
return Utils.safe_reshape(ForwardDiff.derivative(f2, p), :, 1)
elseif u isa Number
return Utils.safe_reshape(ForwardDiff.gradient(f, p), 1, :)
return Utils.safe_reshape(ForwardDiff.gradient(f2, p), 1, :)
else
return ForwardDiff.jacobian(f, p)
return ForwardDiff.jacobian(f2, p)
end
end

Expand Down
13 changes: 13 additions & 0 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ function CommonSolve.solve(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function CommonSolve.solve(
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{
<:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function CommonSolve.solve(
prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
Expand Down
3 changes: 2 additions & 1 deletion lib/SimpleNonlinearSolve/src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ end
const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(
prob::ImmutableNonlinearProblem, alg::SimpleNewtonRaphson, args...;
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = Utils.maybe_unaliased(prob.u0, alias_u0)
Expand Down

0 comments on commit c91b973

Please sign in to comment.