Skip to content

Commit

Permalink
Merge pull request #251 from yonatanwesen/yd/pseudotransient
Browse files Browse the repository at this point in the history
Pseudo-Transient Method
  • Loading branch information
ChrisRackauckas authored Oct 20, 2023
2 parents 7fadae1 + 011a815 commit 0256766
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ include("trustRegion.jl")
include("levenberg.jl")
include("gaussnewton.jl")
include("dfsane.jl")
include("pseudotransient.jl")
include("jacobian.jl")
include("ad.jl")
include("default.jl")
Expand Down Expand Up @@ -95,7 +96,7 @@ end

export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient
export LeastSquaresOptimJL, FastLevenbergMarquardtJL
export RobustMultiNewton, FastShortcutNonlinearPolyalg

Expand Down
6 changes: 6 additions & 0 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))

if alg isa PseudoTransient
alpha = convert(eltype(u), alg.alpha_initial)
J_new = J - (1 / alpha) * I
linprob = LinearProblem(J_new, _vec(fu); u0 = _vec(du))
end

weight = similar(u)
recursivefill!(weight, true)

Expand Down
180 changes: 180 additions & 0 deletions src/pseudotransient.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""
PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, alpha_initial = 1e-3, adkwargs...)
An implementation of PseudoTransient method that is used to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping to
integrate an initial value of nonlinear problem until sufficient accuracy in the desired steady-state is achieved to switch over to Newton's method and
gain a rapid convergence. This implementation specifically uses "switched evolution relaxation" SER method. For detail information about the time-stepping and algorithm,
please see the paper: [Coffey, Todd S. and Kelley, C. T. and Keyes, David E. (2003), Pseudotransient Continuation and Differential-Algebraic Equations,
SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X)
### Keyword Arguments
- `alpha_initial` : the initial pseudo time step. it defaults to 1e-3. If it is small, you are going to need more iterations to converge.
"""
@concrete struct PseudoTransient{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
alpha_initial
end

#concrete_jac(::PseudoTransient{CJ}) where {CJ} = CJ
function set_ad(alg::PseudoTransient{CJ}, ad) where {CJ}
return PseudoTransient{CJ}(ad, alg.linsolve, alg.precs, alg.alpha_initial)
end

function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, alpha_initial = 1e-3, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return PseudoTransient{_unwrap_val(concrete_jac)}(ad, linsolve, precs, alpha_initial)
end

@concrete mutable struct PseudoTransientCache{iip}
f
alg
u
fu1
fu2
du
p
alpha
res_norm
uf
linsolve
J
jac_cache
force_stop
maxiters::Int
internalnorm
retcode::ReturnCode.T
abstol
prob
stats::NLStats
end

isinplace(::PseudoTransientCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransient,
args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;),
kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)

@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
if iip
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu1, u, p)
else
fu1 = _mutable(f(u, p))
end
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg,
f,
u,
p,
Val(iip);
linsolve_kwargs)
alpha = convert(eltype(u), alg.alpha_initial)
res_norm = internalnorm(fu1)

return PseudoTransientCache{iip}(f, alg, u, fu1, fu2, du, p, alpha, res_norm, uf,
linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::PseudoTransientCache{true})
@unpack u, fu1, f, p, alg, J, linsolve, du, alpha = cache
jacobian!!(J, cache)
J_new = J - (1 / alpha) * I

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J_new, b = _vec(fu1), linu = _vec(du),
p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(fu1, u, p)

new_norm = cache.internalnorm(fu1)
cache.alpha *= cache.res_norm / new_norm
cache.res_norm = new_norm

new_norm < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function perform_step!(cache::PseudoTransientCache{false})
@unpack u, fu1, f, p, alg, linsolve, alpha = cache

cache.J = jacobian!!(cache.J, cache)
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / (cache.J - (1 / alpha) * I)
else
linres = dolinsolve(alg.precs, linsolve; A = cache.J - (1 / alpha) * I,
b = _vec(fu1),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu1 = f(cache.u, p)

new_norm = cache.internalnorm(fu1)
cache.alpha *= cache.res_norm / new_norm
cache.res_norm = new_norm
new_norm < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function SciMLBase.solve!(cache::PseudoTransientCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end

function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p,
alpha_new,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu1, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end
cache.alpha = convert(eltype(cache.u), alpha_new)
cache.res_norm = cache.internalnorm(cache.fu1)
cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end
121 changes: 121 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,124 @@ end
end
end
end

# --- PseudoTransient tests ---

@testset "PseudoTransient" begin
#these are tests for NewtonRaphson so we should set alpha_initial to be high so that we converge quickly

function benchmark_nlsolve_oop(f, u0, p = 2.0; alpha_initial = 10.0)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, PseudoTransient(; alpha_initial), abstol = 1e-9)
end

function benchmark_nlsolve_iip(f, u0, p = 2.0; linsolve, precs,
alpha_initial = 10.0)
prob = NonlinearProblem{true}(f, u0, p)
return solve(prob, PseudoTransient(; linsolve, precs, alpha_initial), abstol = 1e-9)
end

@testset "PT: alpha_initial = 10.0 PT AD: $(ad)" for ad in (AutoFiniteDiff(),
AutoZygote())
u0s = VERSION v"1.9" ? ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) : ([1.0, 1.0], 1.0)

@testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
sol = benchmark_nlsolve_oop(quadratic_f, u0)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0),
PseudoTransient(alpha_initial = 10.0),
abstol = 1e-9)
@test (@ballocated solve!($cache)) < 200
end

precs = [NonlinearSolve.DEFAULT_PRECS, :Random]

@testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([
1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES())
ad isa AutoZygote && continue
if prec === :Random
prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)
end
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
PseudoTransient(; alpha_initial = 10.0, linsolve, precs = prec),
abstol = 1e-9)
@test (@ballocated solve!($cache)) 64
end
end

if VERSION v"1.9"
@testset "[OOP] [Immutable AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
end
end
end

@testset "[OOP] [Scalar AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p)
1 / (2 * sqrt(p))
end
end

if VERSION v"1.9"
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
p)
ForwardDiff.jacobian(t, p)
end

function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
cache = init(probN,
PseudoTransient(alpha_initial = 10.0);
maxiters = 100,
abstol = 1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha_new = 10.0)
sol = solve!(cache)
sols[i] = iip ? sol.u[1] : sol.u
end
return sols
end
p = range(0.01, 2, length = 200)
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)

@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(probN, PseudoTransient(; alpha_initial = 10.0, autodiff)).u .≈
sqrt(2.0))
end

@testset "NewtonRaphson Fails but PT passes" begin # Test that `PseudoTransient` passes a test that `NewtonRaphson` fails on.
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
probN = NonlinearProblem{false}(newton_fails, u0, p)
sol = solve(probN, PseudoTransient(alpha_initial = 1.0), abstol = 1e-10)
@test all(abs.(newton_fails(sol.u, p)) .< 1e-10)
end
end

0 comments on commit 0256766

Please sign in to comment.