Skip to content

Commit

Permalink
fix: forwarddiff support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2024
1 parent 693949b commit cdde6e1
Show file tree
Hide file tree
Showing 11 changed files with 203 additions and 161 deletions.
6 changes: 3 additions & 3 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ function SciMLBase.__solve(
)
end

linsolve = alg.ls === :qr ? LSO.QR() :
(alg.ls === :cholesky ? LSO.Cholesky() :
(alg.ls === :lsmr ? LSO.LSMR() : nothing))
linsolve = alg.linsolve === :qr ? LSO.QR() :
(alg.linsolve === :cholesky ? LSO.Cholesky() :
(alg.linsolve === :lsmr ? LSO.LSMR() : nothing))

lso_solver = if alg.alg === :lm
LSO.LevenbergMarquardt(linsolve)
Expand Down
18 changes: 15 additions & 3 deletions ext/NonlinearSolveSundialsExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
module NonlinearSolveSundialsExt

using Sundials: KINSOL

using CommonSolve: CommonSolve
using NonlinearSolveBase: NonlinearSolveBase, nonlinearsolve_forwarddiff_solve,
nonlinearsolve_dual_solution
using NonlinearSolve: DualNonlinearProblem
using NonlinearSolve: NonlinearSolve, DualNonlinearProblem
using SciMLBase: SciMLBase
using Sundials: KINSOL

function SciMLBase.__solve(prob::DualNonlinearProblem, alg::KINSOL, args...; kwargs...)
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

function SciMLBase.__init(prob::DualNonlinearProblem, alg::KINSOL, args...; kwargs...)
p = NonlinearSolveBase.nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
cache = CommonSolve.init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end

end
4 changes: 1 addition & 3 deletions lib/NonlinearSolveBase/src/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
error("Default Julia Backsolve Operator `\\` doesn't support Preconditioners")
return NativeJLLinearSolveCache(A, b, stats)
elseif no_preconditioner && linsolve === nothing
# Non-allocating linear solve exists in StaticArrays.jl
if (A isa SMatrix || A isa WrappedArray{<:Any, <:SMatrix}) &&
Core.Compiler.return_type(\, Tuple{typeof(A), typeof(b)}) <: SArray
if (A isa SMatrix || A isa WrappedArray{<:Any, <:SMatrix})
return NativeJLLinearSolveCache(A, b, stats)
end
end
Expand Down
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ function maybe_pinv!!(workspace, A::StridedMatrix)
!issingular && return LinearAlgebra.tril!(parent(inv(A_)))
else
F = LinearAlgebra.lu(A; check = false)
if issuccess(F)
if LinearAlgebra.issuccess(F)
Ai = LinearAlgebra.inv!(F)
return convert(typeof(parent(Ai)), Ai)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Setfield: @set!

using ADTypes: ADTypes
using ArrayInterface: ArrayInterface
using LinearAlgebra: LinearAlgebra, Diagonal, dot
using LinearAlgebra: LinearAlgebra, Diagonal, dot, diagind
using StaticArraysCore: SArray

using CommonSolve: CommonSolve
Expand Down
12 changes: 9 additions & 3 deletions lib/NonlinearSolveQuasiNewton/test/core_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,14 @@ end
LiFukushimaLineSearch()
)
@testset "[OOP] u0: $(typeof(u0))" for u0 in (ones(32), @SVector(ones(2)), 1.0)
broken = Sys.iswindows() && u0 isa Vector{Float64} &&
linesearch isa BackTracking && ad isa AutoFiniteDiff

solver = LimitedMemoryBroyden(; linesearch)
sol = solve_oop(quadratic_f, u0; solver)
@test SciMLBase.successful_retcode(sol)
@test SciMLBase.successful_retcode(sol) broken=broken
err = maximum(abs, quadratic_f(sol.u, 2.0))
@test err < 1e-9
@test err<1e-9 broken=broken

cache = init(
NonlinearProblem{false}(quadratic_f, u0, 2.0), solver, abstol = 1e-9
Expand All @@ -185,11 +188,14 @@ end
@testset "[IIP] u0: $(typeof(u0))" for u0 in (ones(32),)
ad isa AutoZygote && continue

broken = Sys.iswindows() && u0 isa Vector{Float64} &&
linesearch isa BackTracking && ad isa AutoFiniteDiff

solver = LimitedMemoryBroyden(; linesearch)
sol = solve_iip(quadratic_f!, u0; solver)
@test SciMLBase.successful_retcode(sol)
err = maximum(abs, quadratic_f(sol.u, 2.0))
@test err < 1e-9
@test err<1e-9 broken=broken

cache = init(
NonlinearProblem{true}(quadratic_f!, u0, 2.0), solver, abstol = 1e-9
Expand Down
28 changes: 14 additions & 14 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ using NonlinearSolveQuasiNewton: Broyden, Klement
using SimpleNonlinearSolve: SimpleBroyden, SimpleKlement

# Default AD Support
using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff # Default Forward Mode AD
using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD

# Sparse AD Support: Implemented via extensions
using SparseArrays: SparseArrays
Expand All @@ -39,9 +39,9 @@ using SparseMatrixColorings: SparseMatrixColorings
using BracketingNonlinearSolve: BracketingNonlinearSolve
using LineSearch: LineSearch
using LinearSolve: LinearSolve
using NonlinearSolveFirstOrder: NonlinearSolveFirstOrder
using NonlinearSolveQuasiNewton: NonlinearSolveQuasiNewton
using NonlinearSolveSpectralMethods: NonlinearSolveSpectralMethods
using NonlinearSolveFirstOrder: NonlinearSolveFirstOrder, GeneralizedFirstOrderAlgorithm
using NonlinearSolveQuasiNewton: NonlinearSolveQuasiNewton, QuasiNewtonAlgorithm
using NonlinearSolveSpectralMethods: NonlinearSolveSpectralMethods, GeneralizedDFSane
using SimpleNonlinearSolve: SimpleNonlinearSolve

const SII = SymbolicIndexingInterface
Expand All @@ -53,16 +53,16 @@ include("extension_algs.jl")

include("default.jl")

# const ALL_SOLVER_TYPES = [
# Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
# GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
# LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
# SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
# CMINPACK, PETScSNES,
# NonlinearSolvePolyAlgorithm{:NLLS, <:Any}, NonlinearSolvePolyAlgorithm{:NLS, <:Any}
# ]
const ALL_SOLVER_TYPES = [
Nothing, AbstractNonlinearSolveAlgorithm,
GeneralizedDFSane, GeneralizedFirstOrderAlgorithm, QuasiNewtonAlgorithm,
LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
CMINPACK, PETScSNES,
NonlinearSolvePolyAlgorithm
]

# include("internal/forward_diff.jl") # we need to define after the algorithms
include("forward_diff.jl")

@setup_workload begin
include("../common/nonlinear_problem_workloads.jl")
Expand Down
68 changes: 41 additions & 27 deletions src/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
const DualNonlinearProblem = NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

for algType in ALL_SOLVER_TYPES
@eval function SciMLBase.__solve(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end
end

@concrete mutable struct NonlinearSolveForwardDiffCache
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
cache
prob
alg
Expand All @@ -25,36 +33,41 @@ end
partials_p
end

@internal_caches NonlinearSolveForwardDiffCache :cache

function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
p = cache.p, u0 = get_u(cache.cache), kwargs...)
inner_cache = reinit_cache!(cache.cache; p = __value(p), u0 = __value(u0), kwargs...)
function InternalAPI.reinit!(
cache::NonlinearSolveForwardDiffCache, args...;
p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
)
inner_cache = InternalAPI.reinit!(
cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs...
)
cache.cache = inner_cache
cache.p = p
cache.values_p = __value(p)
cache.values_p = nodual_value(p)
cache.partials_p = ForwardDiff.partials(p)
return cache
end

for algType in ALL_SOLVER_TYPES
# XXX: Extend to DualNonlinearLeastSquaresProblem
@eval function SciMLBase.__init(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
p = __value(prob.p)
newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end
end

function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
sol = solve!(cache.cache)
prob = cache.prob

uu = sol.u
Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)

z_arr = -Jᵤ \ Jₚ

Expand All @@ -65,11 +78,12 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
end

dual_soln = nonlinearsolve_dual_solution(sol.u, partials, cache.p)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p)
return SciMLBase.build_solution(
prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

@inline __value(x) = x
@inline __value(x::Dual) = ForwardDiff.value(x)
@inline __value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
9 changes: 6 additions & 3 deletions test/23_test_problems_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,13 @@ end
test_on_library(problems, dicts, alg_ops, broken_tests)
end

@testitem "Broyden" setup=[RobustnessTesting] tags=[:core] begin
alg_ops = (Broyden(), Broyden(; init_jacobian = Val(:true_jacobian)),
@testitem "Broyden" setup=[RobustnessTesting] tags=[:core] retries=3 begin
alg_ops = (
Broyden(),
Broyden(; init_jacobian = Val(:true_jacobian)),
Broyden(; update_rule = Val(:bad_broyden)),
Broyden(; init_jacobian = Val(:true_jacobian), update_rule = Val(:bad_broyden)))
Broyden(; init_jacobian = Val(:true_jacobian), update_rule = Val(:bad_broyden))
)

broken_tests = Dict(alg => Int[] for alg in alg_ops)
if Sys.isapple()
Expand Down
3 changes: 1 addition & 2 deletions test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
SOLVERS = (
NewtonRaphson(),
LevenbergMarquardt(; linsolve = QRFactorization()),
# XXX: Fails currently
# LevenbergMarquardt(; linsolve = KrylovJL_GMRES()),
LevenbergMarquardt(; linsolve = KrylovJL_GMRES()),
PseudoTransient(),
Klement(),
Broyden(; linesearch = LiFukushimaLineSearch()),
Expand Down
Loading

0 comments on commit cdde6e1

Please sign in to comment.