Skip to content

Commit

Permalink
test: move tests for QuasiNewton solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 30, 2024
1 parent c73690e commit 5b1d845
Show file tree
Hide file tree
Showing 16 changed files with 440 additions and 276 deletions.
71 changes: 71 additions & 0 deletions .github/workflows/CI_NonlinearSolveQuasiNewton.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
name: CI (NonlinearSolveQuasiNewton)

on:
pull_request:
branches:
- master
paths:
- "lib/NonlinearSolveQuasiNewton/**"
- ".github/workflows/CI_NonlinearSolveQuasiNewton.yml"
- "lib/NonlinearSolveBase/**"
- "lib/SciMLJacobianOperators/**"
push:
branches:
- master

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- "lts"
- "1"
os:
- ubuntu-latest
- macos-latest
- windows-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
- uses: actions/cache@v4
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- name: "Install Dependencies and Run Tests"
run: |
import Pkg
Pkg.Registry.update()
# Install packages present in subdirectories
dev_pks = Pkg.PackageSpec[]
for path in ("lib/SciMLJacobianOperators", "lib/NonlinearSolveBase")
push!(dev_pks, Pkg.PackageSpec(; path))
end
Pkg.develop(dev_pks)
Pkg.instantiate()
Pkg.test(; coverage="user")
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/NonlinearSolveQuasiNewton {0}
- uses: julia-actions/julia-processcoverage@v1
with:
directories: lib/NonlinearSolveQuasiNewton/src,lib/NonlinearSolveBase/src,lib/NonlinearSolveBase/ext,lib/SciMLJacobianOperators/src
- uses: codecov/codecov-action@v4
with:
file: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
verbose: true
fail_ci_if_error: true
13 changes: 8 additions & 5 deletions common/nlls_problem_workloads.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
using SciMLBase: NonlinearLeastSquaresProblem, NonlinearFunction
using SciMLBase: NonlinearLeastSquaresProblem, NonlinearFunction, NoSpecialize

nonlinear_functions = (
(NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
(NonlinearFunction{false, NoSpecialize}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
(
NonlinearFunction{true}(
NonlinearFunction{false, NoSpecialize}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
[0.1, 0.1]
),
(
NonlinearFunction{true, NoSpecialize}(
(du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)
),
[0.1, 0.0]
),
(
NonlinearFunction{true}(
NonlinearFunction{true, NoSpecialize}(
(du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), resid_prototype = zeros(4)
),
[0.1, 0.1]
Expand Down
8 changes: 4 additions & 4 deletions common/nonlinear_problem_workloads.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using SciMLBase: NonlinearProblem, NonlinearFunction
using SciMLBase: NonlinearProblem, NonlinearFunction, NoSpecialize

nonlinear_functions = (
(NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
(NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1])
(NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1),
(NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), [0.1]),
(NonlinearFunction{true, NoSpecialize}((du, u, p) -> du .= u .* u .- p), [0.1])
)

nonlinear_problems = NonlinearProblem[]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ function (cache::LinearSolveJLCache)(;
if cache.precs === nothing
Pl, Pr = nothing, nothing
else
Pl, Pr = cache.precs(cache.lincache.A, du, linu, p, nothing,
A !== nothing, Plprev, Prprev, cachedata)
Pl, Pr = cache.precs(
cache.lincache.A, du, linu, p, nothing,
A !== nothing, Plprev, Prprev, cachedata
)
end

if Pl !== nothing || Pr !== nothing
Expand Down
11 changes: 7 additions & 4 deletions lib/NonlinearSolveBase/src/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,13 @@ function update_trace!(
if show_now || store_now
entry = if trace.trace_level.trace_mode isa Val{:minimal}
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, missing, missing)
elseif trace.trace_level.trace_mode isa Val{:condition_number}
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, J, missing)
else
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, J, u)
J = convert(AbstractArray, J)
if trace.trace_level.trace_mode isa Val{:condition_number}
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, J, missing)
else
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, J, u)
end
end
show_now && show(stdout, MIME"text/plain"(), entry)
store_now && push!(trace.history, entry)
Expand All @@ -224,7 +227,7 @@ function update_trace!(cache, α = true; uses_jac_inverse = Val(false))
trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, cache.du, α
)
else
J = uses_jac_inverse isa Val{true} ? pinv(J) : J
J = uses_jac_inverse isa Val{true} ? Utils.Pinv(cache.J) : cache.J
update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, cache.du, α)
end
end
18 changes: 17 additions & 1 deletion lib/NonlinearSolveBase/src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Utils

using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra, Diagonal, Symmetric, norm, dot, cond, diagind, pinv
using MaybeInplace: @bb
Expand All @@ -15,6 +16,17 @@ is_extension_loaded(::Val) = false

fast_scalar_indexing(xs...) = all(ArrayInterface.fast_scalar_indexing, xs)

@concrete struct Pinv
J
end

function Base.convert(::Type{AbstractArray}, A::Pinv)
hasmethod(pinv, Tuple{typeof(A.J)}) && return pinv(A.J)
@warn "`pinv` not defined for $(typeof(A.J)). Jacobian will not be inverted when \
tracing." maxlog=1
return A.J
end

function nonallocating_isapprox(x::Number, y::Number; atol = false,
rtol = atol > 0 ? false : sqrt(eps(promote_type(typeof(x), typeof(y)))))
return isapprox(x, y; atol, rtol)
Expand Down Expand Up @@ -223,7 +235,11 @@ end

make_identity!!(::T, α) where {T <: Number} = T(α)
function make_identity!!(A::AbstractVector{T}, α) where {T}
@bb @. A = T(α)
if ArrayInterface.can_setindex(A)
@. A = α
else
A = one.(A) .* α
end
return A
end
function make_identity!!(::SMatrix{S1, S2, T, L}, α) where {S1, S2, T, L}
Expand Down
22 changes: 19 additions & 3 deletions lib/NonlinearSolveQuasiNewton/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
Expand All @@ -20,16 +19,22 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[compat]
ADTypes = "1.9.0"
Aqua = "0.8"
ArrayInterface = "7.16.0"
BenchmarkTools = "1.5.0"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.155.3"
Enzyme = "0.13.12"
ExplicitImports = "1.5"
FiniteDiff = "2.26.0"
ForwardDiff = "0.10.36"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LineSearch = "0.1.4"
LinearAlgebra = "1.11.0"
LineSearches = "7.3.0"
LinearAlgebra = "1.10"
LinearSolve = "2.36.1"
MaybeInplace = "0.1.4"
NonlinearProblemLibrary = "0.1.2"
Expand All @@ -41,20 +46,31 @@ Reexport = "1"
SciMLBase = "2.54"
SciMLOperators = "0.3.11"
StableRNGs = "1"
StaticArrays = "1.9.8"
StaticArraysCore = "1.4.3"
Test = "1.10"
Zygote = "0.6.72"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "Test"]
test = ["ADTypes", "Aqua", "BenchmarkTools", "Enzyme", "ExplicitImports", "FiniteDiff", "ForwardDiff", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test", "Zygote"]
14 changes: 6 additions & 8 deletions lib/NonlinearSolveQuasiNewton/src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,14 @@ function InternalAPI.solve!(
@bb @. cache.dfu = fu - cache.dfu
J⁻¹_diag = Utils.restructure(cache.dfu, diag(J⁻¹))
if cache.rule isa GoodBroydenUpdateRule
@bb @. J⁻¹_diag = J⁻¹_diag * cache.dfu * du
denom = sum(J⁻¹_diag)
@bb @. J⁻¹_diag = J⁻¹_diag +
(du - J⁻¹_diag * cache.dfu) * du * J⁻¹_diag /
ifelse(iszero(denom), T(1e-5), denom)
@bb @. cache.J⁻¹dfu = J⁻¹_diag * cache.dfu * du
denom = sum(cache.J⁻¹dfu)
@bb @. J⁻¹_diag += (du - cache.J⁻¹dfu) * du * J⁻¹_diag /
ifelse(iszero(denom), T(1e-5), denom)
else
denom = cache.internalnorm(cache.dfu)^2
@bb @. J⁻¹_diag = J⁻¹_diag +
(du - J⁻¹_diag * cache.dfu) * cache.dfu /
ifelse(iszero(denom), T(1e-5), denom)
@bb @. J⁻¹_diag += (du - J⁻¹_diag * cache.dfu) * cache.dfu /
ifelse(iszero(denom), T(1e-5), denom)
end
@bb copyto!(cache.dfu, fu)
return Diagonal(vec(J⁻¹_diag))
Expand Down
12 changes: 8 additions & 4 deletions lib/NonlinearSolveQuasiNewton/src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ function InternalAPI.init(
J = BroydenLowRankJacobian(fu, u; alg.threshold, alpha = α)
else
threshold = min(Utils.unwrap_val(alg.threshold), maxiters)
if threshold > length(u)
@warn "`threshold` is larger than the size of the state, which may cause \
numerical instability. Consider reducing `threshold`."
end
J = BroydenLowRankJacobian(fu, u; threshold, alpha = α)
end
return InitializedApproximateJacobianCache(
Expand Down Expand Up @@ -240,9 +244,9 @@ function LinearAlgebra.mul!(y::AbstractVector, J::BroydenLowRankJacobian, x::Abs
@. y = -J.alpha * x
return y
end
_, U, Vᵀ = get_components(J)
cache, U, Vᵀ = get_components(J)
@bb cache = Vᵀ × x
mul!(y, U, cache)
LinearAlgebra.mul!(y, U, cache)
@bb @. y -= J.alpha * x
return y
end
Expand All @@ -258,9 +262,9 @@ function LinearAlgebra.mul!(y::AbstractVector, x::AbstractVector, J::BroydenLowR
@. y = -J.alpha * x
return y
end
_, U, Vᵀ = get_components(J)
cache, U, Vᵀ = get_components(J)
@bb cache = transpose(U) × x
mul!(y, transpose(Vᵀ), cache)
LinearAlgebra.mul!(y, transpose(Vᵀ), cache)
@bb @. y -= J.alpha * x
return y
end
Expand Down
2 changes: 1 addition & 1 deletion lib/NonlinearSolveQuasiNewton/src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function InternalAPI.solve!(
T = eltype(u)
J = Utils.restructure(u, diag(J))
@bb @. cache.Jdu = (J^2) * (du^2)
@bb @. J += ((fu - cache.fu_cache - cache.Jdu) /
@bb @. J += ((fu - cache.fu_cache - J * du) /
ifelse(iszero(cache.Jdu), T(1e-5), cache.Jdu)) * du * (J^2)
@bb copyto!(cache.fu_cache, fu)
return Diagonal(vec(J))
Expand Down
53 changes: 25 additions & 28 deletions lib/NonlinearSolveQuasiNewton/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,34 +95,31 @@ end
kwargs
end

# XXX: Implement
# function __reinit_internal!(cache::QuasiNewtonCache{INV, GB, iip},
# args...; p = cache.p, u0 = cache.u, alias_u0::Bool = false,
# maxiters = 1000, maxtime = nothing, kwargs...) where {INV, GB, iip}
# if iip
# recursivecopy!(cache.u, u0)
# cache.prob.f(cache.fu, cache.u, p)
# else
# cache.u = __maybe_unaliased(u0, alias_u0)
# set_fu!(cache, cache.prob.f(cache.u, p))
# end
# cache.p = p

# __reinit_internal!(cache.stats)
# cache.nsteps = 0
# cache.nresets = 0
# cache.steps_since_last_reset = 0
# cache.maxiters = maxiters
# cache.maxtime = maxtime
# cache.total_time = 0.0
# cache.force_stop = false
# cache.force_reinit = false
# cache.retcode = ReturnCode.Default

# reset!(cache.trace)
# reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
# reset_timer!(cache.timer)
# end
function InternalAPI.reinit_self!(
cache::QuasiNewtonCache, args...; p = cache.p, u0 = cache.u,
alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...
)
Utils.reinit_common!(cache, u0, p, alias_u0)

InternalAPI.reinit!(cache.stats)
cache.nsteps = 0
cache.nresets = 0
cache.steps_since_last_reset = 0
cache.maxiters = maxiters
cache.maxtime = maxtime
cache.total_time = 0.0
cache.force_stop = false
cache.force_reinit = false
cache.retcode = ReturnCode.Default

NonlinearSolveBase.reset!(cache.trace)
SciMLBase.reinit!(
cache.termination_cache, NonlinearSolveBase.get_fu(cache),
NonlinearSolveBase.get_u(cache); kwargs...
)
NonlinearSolveBase.reset_timer!(cache.timer)
return
end

NonlinearSolveBase.@internal_caches(QuasiNewtonCache,
:initialization_cache, :descent_cache, :linesearch_cache, :trustregion_cache,
Expand Down
Loading

0 comments on commit 5b1d845

Please sign in to comment.