diff --git a/src/linesearch.jl b/src/linesearch.jl index 3890f8230..30861e14b 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -91,8 +91,15 @@ function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip g₀ = _mutable_zero(u) + autodiff = if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote) + @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. Falling back to finite differencing." + AutoFiniteDiff() + else + ls.autodiff + end + function g!(u, fu) - op = VecJac((args...) -> f(args..., p), u) + op = VecJac((args...) -> f(args..., p), u; autodiff) if iip mul!(g₀, op, fu) return g₀ @@ -134,7 +141,7 @@ function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip end function perform_linesearch!(cache::LineSearchCache, u, du) - cache.ls.method isa Static && return (cache.α, cache.f(u, du, cache.α)) + cache.ls.method isa Static && return cache.α ϕ = cache.ϕ(u, du) dϕ = cache.dϕ(u, du) @@ -142,5 +149,8 @@ function perform_linesearch!(cache::LineSearchCache, u, du) ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u))) - return cache.ls.method(ϕ, cache.dϕ(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀) + # This case is sometimes possible for large optimization problems + dϕ₀ ≥ 0 && return cache.α + + return first(cache.ls.method(ϕ, cache.dϕ(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀)) end diff --git a/src/raphson.jl b/src/raphson.jl index d01881dc4..8297f92fe 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -93,8 +93,9 @@ function perform_step!(cache::NewtonRaphsonCache{true}) cache.linsolve = linres.cache # Line Search - α, _ = perform_linesearch!(cache.lscache, u, du) + α = perform_linesearch!(cache.lscache, u, du) @. u = u - α * du + f(cache.fu1, u, p) cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true) cache.stats.nf += 1 @@ -118,7 +119,7 @@ function perform_step!(cache::NewtonRaphsonCache{false}) end # Line Search - α, _fu = perform_linesearch!(cache.lscache, u, cache.du) + α = perform_linesearch!(cache.lscache, u, cache.du) cache.u = @. u - α * cache.du # `u` might not support mutation cache.fu1 = f(cache.u, p) diff --git a/test/basictests.jl b/test/basictests.jl index c31be05fa..54e63e93d 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -53,10 +53,12 @@ end @testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([ 1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES()) + ad isa AutoZygote && continue if prec === :Random prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing) end - sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec, linesearch) + sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec, + linesearch) @test SciMLBase.successful_retcode(sol) @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) @@ -67,25 +69,30 @@ end end if VERSION ≥ v"1.9" - @testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) - res_true = sqrt(p) - all(res.u .≈ res_true) + @testset "[OOP] [Immutable AD]" begin + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) + res_true = sqrt(p) + all(res.u .≈ res_true) + end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) end end - @testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, 1.0, p) - res_true = sqrt(p) - res.u ≈ res_true + @testset "[OOP] [Scalar AD]" begin + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, 1.0, p) + res_true = sqrt(p) + res.u ≈ res_true + end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, + p) ≈ + 1 / (2 * sqrt(p)) end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p) ≈ - 1 / (2 * sqrt(p)) end if VERSION ≥ v"1.9" @@ -162,33 +169,34 @@ end end if VERSION ≥ v"1.9" - @testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes, - p in 1.0:0.1:100.0 + @testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p; + radius_update_scheme) + res_true = sqrt(p) + all(res.u .≈ res_true) + end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) ≈ 1 / (2 * sqrt(p)) + end + end + end + @testset "[OOP] [Scalar AD] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes + for p in 1.0:0.1:100.0 @test begin - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p; + res = benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0), p; radius_update_scheme) res_true = sqrt(p) - all(res.u .≈ res_true) + res.u ≈ res_true end @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) ≈ 1 / (2 * sqrt(p)) + oftype(p, 1.0), + p; radius_update_scheme).u, p) ≈ 1 / (2 * sqrt(p)) end end - @testset "[OOP] [Scalar AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes, - p in 1.0:0.1:100.0 - - @test begin - res = benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0), p; - radius_update_scheme) - res_true = sqrt(p) - res.u ≈ res_true - end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0), - p; radius_update_scheme).u, p) ≈ 1 / (2 * sqrt(p)) - end - if VERSION ≥ v"1.9" t = (p) -> [sqrt(p[2] / p[1])] p = [0.9, 50.0] @@ -316,25 +324,30 @@ end end if VERSION ≥ v"1.9" - @testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) - res_true = sqrt(p) - all(res.u .≈ res_true) + @testset "[OOP] [Immutable AD]" begin + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) + res_true = sqrt(p) + all(res.u .≈ res_true) + end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) end end - @testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, 1.0, p) - res_true = sqrt(p) - res.u ≈ res_true + @testset "[OOP] [Scalar AD]" begin + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, 1.0, p) + res_true = sqrt(p) + res.u ≈ res_true + end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, + p) ≈ + 1 / (2 * sqrt(p)) end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p) ≈ - 1 / (2 * sqrt(p)) end if VERSION ≥ v"1.9"