Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Towards a cleaner and more maintainable internals of NonlinearSolve.jl #203

Merged
merged 19 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
style = "sciml"
format_markdown = true
format_markdown = true
annotate_untyped_fields_with_any = false
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ Manifest.toml
docs/src/assets/Project.toml

.vscode
wip
13 changes: 10 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "1.10.0"
version = "2.0.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Expand All @@ -25,30 +27,35 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
ArrayInterface = "6.0.24, 7"
DiffEqBase = "6"
EnumX = "1"
Enzyme = "0.11"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LinearSolve = "2"
PrecompileTools = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "1.92.4"
SciMLBase = "1.97"
SimpleNonlinearSolve = "0.1"
SparseDiffTools = "1, 2"
StaticArraysCore = "1.4"
UnPack = "1.0"
Zygote = "0.6"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools"]
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
BenchmarkTools = "1"
Documenter = "0.27"
LinearSolve = "2"
NonlinearSolve = "1"
NonlinearSolve = "1, 2"
NonlinearSolveMINPACK = "0.1"
SciMLNLSolve = "0.1"
SimpleNonlinearSolve = "0.1.5"
Expand Down
61 changes: 31 additions & 30 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
module NonlinearSolve
if isdefined(Base, :Experimental) &&
isdefined(Base.Experimental, Symbol("@max_methods"))

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods"))
@eval Base.Experimental.@max_methods 1
end
using Reexport
using UnPack: @unpack
using FiniteDiff, ForwardDiff
using ForwardDiff: Dual
using LinearAlgebra
using StaticArraysCore
using RecursiveArrayTools
import EnumX
import ArrayInterface
import LinearSolve
using DiffEqBase
using SparseDiffTools

@reexport using SciMLBase
using SciMLBase: NLStats
@reexport using SimpleNonlinearSolve

import SciMLBase: _unwrap_val

abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
AbstractNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::AbstractNonlinearSolveAlgorithm, args...;
kwargs...)

using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools
import ForwardDiff

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable
import ConcreteStructs: @concrete
import EnumX: @enumx
import ForwardDiff: Dual
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
import RecursiveArrayTools: ArrayPartition,
AbstractVectorOfArray, recursivecopy!, recursivefill!
import Reexport: @reexport
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
import StaticArraysCore: StaticArray, SVector, SArray, MArray
import UnPack: @unpack

@reexport using ADTypes, SciMLBase, SimpleNonlinearSolve

const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
sol = solve!(cache)
return solve!(cache)
end

include("utils.jl")
Expand All @@ -46,7 +47,7 @@ PrecompileTools.@compile_workload begin
for T in (Float32, Float64)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

precompile_algs = if VERSION >= v"1.7"
precompile_algs = if VERSION v"1.7"
(NewtonRaphson(), TrustRegion(), LevenbergMarquardt())
else
(NewtonRaphson(),)
Expand Down
65 changes: 41 additions & 24 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,61 @@
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
f = prob.f
p = value(prob.p)

u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
if p isa Number
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
else
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
end
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)

z_arr = -inv(f_x) * f_p

f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
pp = prob.p
sumfun = let f_x′ = -f_x
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
sumfun = ((z, p),) -> [zᵢ * ForwardDiff.partials(p) for zᵢ in z]
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))
end
partials = sum(sumfun, zip(f_p, pp))

return sol, partials
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:Dual{T, V, P}},
alg::AbstractNewtonAlgorithm,
args...; kwargs...) where {iip, T, V, P}
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
iip, <:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractNewtonAlgorithm,
args...;

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end

function scalar_nlsolve_∂f_∂p(f, u, p)
ff = p isa Number ? ForwardDiff.derivative :
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
return ff(Base.Fix1(f, u), p)
end

function scalar_nlsolve_∂f_∂u(f, u, p)
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
return ff(Base.Fix2(f, p), u)
end

function scalar_nlsolve_dual_soln(u::Number, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return Dual{T, V, P}(u, partials[1])
end

function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
end
Loading