Skip to content

Commit

Permalink
auto switch to finitediff for inplace problems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 20, 2023
1 parent 7e26d18 commit 83c0723
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 52 deletions.
16 changes: 13 additions & 3 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,15 @@ function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip

g₀ = _mutable_zero(u)

autodiff = if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote)
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. Falling back to finite differencing."
AutoFiniteDiff()

Check warning on line 96 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L95-L96

Added lines #L95 - L96 were not covered by tests
else
ls.autodiff
end

function g!(u, fu)
op = VecJac((args...) -> f(args..., p), u)
op = VecJac((args...) -> f(args..., p), u; autodiff)
if iip
mul!(g₀, op, fu)
return g₀
Expand Down Expand Up @@ -134,13 +141,16 @@ function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip
end

function perform_linesearch!(cache::LineSearchCache, u, du)
cache.ls.method isa Static && return (cache.α, cache.f(u, du, cache.α))
cache.ls.method isa Static && return cache.α

ϕ = cache.ϕ(u, du)
= cache.(u, du)
ϕdϕ = cache.ϕdϕ(u, du)

ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))

return cache.ls.method(ϕ, cache.(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀)
# This case is sometimes possible for large optimization problems
dϕ₀ 0 && return cache.α

return first(cache.ls.method(ϕ, cache.(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀))
end
5 changes: 3 additions & 2 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ function perform_step!(cache::NewtonRaphsonCache{true})
cache.linsolve = linres.cache

# Line Search
α, _ = perform_linesearch!(cache.lscache, u, du)
α = perform_linesearch!(cache.lscache, u, du)
@. u = u - α * du
f(cache.fu1, u, p)

cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
Expand All @@ -118,7 +119,7 @@ function perform_step!(cache::NewtonRaphsonCache{false})
end

# Line Search
α, _fu = perform_linesearch!(cache.lscache, u, cache.du)
α = perform_linesearch!(cache.lscache, u, cache.du)
cache.u = @. u - α * cache.du # `u` might not support mutation
cache.fu1 = f(cache.u, p)

Expand Down
107 changes: 60 additions & 47 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ end

@testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([
1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES())
ad isa AutoZygote && continue
if prec === :Random
prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)
end
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec, linesearch)
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec,
linesearch)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

Expand All @@ -67,25 +69,30 @@ end
end

if VERSION v"1.9"
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
@testset "[OOP] [Immutable AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
end
end

@testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u res_true
@testset "[OOP] [Scalar AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p)
1 / (2 * sqrt(p))
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)
1 / (2 * sqrt(p))
end

if VERSION v"1.9"
Expand Down Expand Up @@ -162,33 +169,34 @@ end
end

if VERSION v"1.9"
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
p in 1.0:0.1:100.0
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p;
radius_update_scheme)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) 1 / (2 * sqrt(p))
end
end
end

@testset "[OOP] [Scalar AD] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p;
res = benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0), p;
radius_update_scheme)
res_true = sqrt(p)
all(res.u . res_true)
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) 1 / (2 * sqrt(p))
oftype(p, 1.0),
p; radius_update_scheme).u, p) 1 / (2 * sqrt(p))
end
end

@testset "[OOP] [Scalar AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
p in 1.0:0.1:100.0

@test begin
res = benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0), p;
radius_update_scheme)
res_true = sqrt(p)
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0),
p; radius_update_scheme).u, p) 1 / (2 * sqrt(p))
end

if VERSION v"1.9"
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
Expand Down Expand Up @@ -316,25 +324,30 @@ end
end

if VERSION v"1.9"
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
@testset "[OOP] [Immutable AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
end
end

@testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u res_true
@testset "[OOP] [Scalar AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p)
1 / (2 * sqrt(p))
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)
1 / (2 * sqrt(p))
end

if VERSION v"1.9"
Expand Down

0 comments on commit 83c0723

Please sign in to comment.