Skip to content

Commit

Permalink
test: NLLS forwarddiff rules testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 7, 2024
1 parent 2dcdbbd commit e032a13
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 6 deletions.
6 changes: 4 additions & 2 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ function CommonSolve.solve(
end

function CommonSolve.solve(
prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
Expand All @@ -86,7 +87,8 @@ function CommonSolve.solve(
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function simplenonlinearsolve_solve_up(prob::ImmutableNonlinearProblem, sensealg, u0,
function simplenonlinearsolve_solve_up(
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
(u0_changed || p_changed) && (prob = remake(prob; u0, p))
return SciMLBase.__solve(prob, alg, args...; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function SciMLBase.__solve(

@bb xo = similar(x)
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
safe_similar(fx) : nothing
safe_similar(fx) : fx
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)

Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/trust_region.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi

@bb xo = copy(x)
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
safe_similar(fx) : nothing
safe_similar(fx) : fx
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)

Expand Down
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
end
if extras isa AnalyticJacobian
if SciMLBase.isinplace(prob)
prob.jac(J, x, prob.p)
prob.f.jac(J, x, prob.p)
return J
else
return prob.jac(x, prob.p)
return prob.f.jac(x, prob.p)
end
end
if SciMLBase.isinplace(prob)
Expand Down
114 changes: 114 additions & 0 deletions lib/SimpleNonlinearSolve/test/core/forward_diff_tests.jl
Original file line number Diff line number Diff line change
@@ -1 +1,115 @@
@testitem "ForwardDiff.jl Integration NonlinearLeastSquaresProblem" tags=[:core] begin
using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
Zygote, ReverseDiff
using DifferentiationInterface

const DI = DifferentiationInterface

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])

θ_true = [1.0, 0.1, 2.0, 0.5]
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
y_target = true_function(x, θ_true)

loss_function(θ, p) = true_function(p, θ) .- y_target

loss_function_jac(θ, p) = ForwardDiff.jacobian(Base.Fix2(loss_function, p), θ)

loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))

function loss_function!(resid, θ, p)
= true_function(p, θ)
@. resid =- y_target
return
end

function loss_function_jac!(J, θ, p)
J .= ForwardDiff.jacobian-> loss_function(θ, p), θ)
return
end

function loss_function_vjp!(vJ, v, θ, p)
vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
return
end

θ_init = θ_true .+ 0.1

@testset for alg in (
SimpleGaussNewton(),
SimpleGaussNewton(; autodiff = AutoForwardDiff()),
SimpleGaussNewton(; autodiff = AutoFiniteDiff()),
SimpleGaussNewton(; autodiff = AutoReverseDiff())
)
function obj_1(p)
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

function obj_2(p)
ff = NonlinearFunction{false}(
loss_function; resid_prototype = zeros(length(y_target)))
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

function obj_3(p)
ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

finitediff = DI.gradient(obj_1, AutoFiniteDiff(), x)

fdiff1 = DI.gradient(obj_1, AutoForwardDiff(), x)
fdiff2 = DI.gradient(obj_2, AutoForwardDiff(), x)
fdiff3 = DI.gradient(obj_3, AutoForwardDiff(), x)

@test finitedifffdiff1 atol=1e-5
@test finitedifffdiff2 atol=1e-5
@test finitedifffdiff3 atol=1e-5
@test fdiff1 fdiff2 fdiff3

function obj_4(p)
prob_iip = NonlinearLeastSquaresProblem(
NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target))),
θ_init,
p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

function obj_5(p)
ff = NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target)),
jac = loss_function_jac!)
prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

function obj_6(p)
ff = NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target)),
vjp = loss_function_vjp!)
prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

finitediff = DI.gradient(obj_4, AutoFiniteDiff(), x)

fdiff4 = DI.gradient(obj_4, AutoForwardDiff(), x)
fdiff5 = DI.gradient(obj_5, AutoForwardDiff(), x)
fdiff6 = DI.gradient(obj_6, AutoForwardDiff(), x)

@test finitedifffdiff4 atol=1e-5
@test finitedifffdiff5 atol=1e-5
@test finitedifffdiff6 atol=1e-5
@test fdiff4 fdiff5 fdiff6
end
end

0 comments on commit e032a13

Please sign in to comment.