Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]:Pseudo-Transient method #215

Closed
wants to merge 39 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
57e4eca
reworking pt using avik's cleanup
yonatanwesen Sep 22, 2023
55ba281
Documenter 1.0 upgrade
ArnoStrouwen Sep 23, 2023
d1e0762
Merge pull request #220 from ArnoStrouwen/docs1
ChrisRackauckas Sep 23, 2023
6b2a1ec
forget the 1/alpha term when linsolve is nothing
yonatanwesen Sep 25, 2023
081f3b1
fix incorrect gradient and Gauss-Newton Hessian proxy
FHoltorf Sep 14, 2023
f20e897
fix Cauchy point calculation in dogleg step
FHoltorf Sep 14, 2023
6f3556e
expose quadratic form structure explicitly
FHoltorf Sep 14, 2023
146dec9
add NLsolve trust region updating scheme and change GN step to -J\fu …
FHoltorf Sep 14, 2023
884aafc
add Nocedal and Wright trust region updating scheme
FHoltorf Sep 14, 2023
094eb34
add meaningful description for NLsolve and NW trust region updating s…
FHoltorf Sep 25, 2023
0e99655
cache memory for cauchy step to enable non-allocating code
FHoltorf Sep 25, 2023
439415b
parameter types should not be converted to eltype(u). For now, defaul…
FHoltorf Sep 26, 2023
0ba652b
finish rebase to master
FHoltorf Sep 26, 2023
6f7ef29
introduce better variable names (and also ones that are more consiste…
FHoltorf Sep 26, 2023
79ab992
fix consistency test
FHoltorf Sep 26, 2023
b749501
fix oop perform step and Fan scheme initialization
FHoltorf Sep 26, 2023
21d246d
improve comment
FHoltorf Sep 26, 2023
177def2
use less conservative step acceptance policy
FHoltorf Sep 26, 2023
0190764
choose Float64 as default type for trust region adaptation parameters…
FHoltorf Sep 26, 2023
5f298e3
hardcode NLsolve parameters
FHoltorf Sep 26, 2023
b2b5d89
add NLsolve-like trust region initialization
FHoltorf Sep 27, 2023
f67ced5
avoid recomputation of GN step if TR step was rejected. Faster and av…
FHoltorf Sep 27, 2023
34821e5
run SciML formatter
FHoltorf Sep 27, 2023
609f67c
rename NW -> NocedalWright
FHoltorf Sep 27, 2023
32cf735
convergence check for NocedalWright
FHoltorf Sep 27, 2023
e77fa76
test new trust region schemes
FHoltorf Sep 27, 2023
f5c66c4
run formatter
FHoltorf Sep 27, 2023
3a750aa
Merge pull request #210 from FHoltorf/tr
ChrisRackauckas Sep 29, 2023
6607fc6
Update Project.toml
ChrisRackauckas Sep 29, 2023
72a95e5
Update Project.toml
ChrisRackauckas Sep 29, 2023
670a4bf
remove precompile on v1.9
ChrisRackauckas Oct 2, 2023
f93cd7a
Update src/NonlinearSolve.jl
ChrisRackauckas Oct 2, 2023
41cea21
Merge pull request #224 from SciML/precompile
ChrisRackauckas Oct 2, 2023
fd4ae4b
Update Project.toml
ChrisRackauckas Oct 2, 2023
9b6ddf1
added linsolve kwargs so that simple gmres stops
yonatanwesen Oct 4, 2023
47f47bd
reworking pt using avik's cleanup
yonatanwesen Sep 22, 2023
5e2f5a1
forget the 1/alpha term when linsolve is nothing
yonatanwesen Sep 25, 2023
6955074
added linsolve kwargs so that simple gmres stops
yonatanwesen Oct 4, 2023
c26577a
Merge branch 'yd/pseudo-transient' of github.com:yonatanwesen/Nonline…
yonatanwesen Oct 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ include("linesearch.jl")
include("raphson.jl")
include("trustRegion.jl")
include("levenberg.jl")
include("pseudotransient.jl")
include("jacobian.jl")
include("ad.jl")

Expand All @@ -49,7 +50,7 @@ PrecompileTools.@compile_workload begin
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

precompile_algs = if VERSION ≥ v"1.7"
(NewtonRaphson(), TrustRegion(), LevenbergMarquardt())
(NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient())
else
(NewtonRaphson(),)
end
Expand All @@ -68,7 +69,7 @@ end

export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt
export NewtonRaphson, TrustRegion, LevenbergMarquardt, PseudoTransient

export LineSearch

Expand Down
8 changes: 7 additions & 1 deletion src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,13 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
end

du = _mutable_zero(u)
linprob = LinearProblem(J, _vec(fu); u0 = _vec(du))
if alg isa PseudoTransient
alpha = convert(eltype(u), alg.alpha_initial)
J_new = J - (1 / alpha) * I
linprob = LinearProblem(J_new, _vec(fu); u0 = _vec(du))
else
linprob = LinearProblem(J, _vec(fu); u0 = _vec(du))
end

weight = similar(u)
recursivefill!(weight, true)
Expand Down
150 changes: 150 additions & 0 deletions src/pseudotransient.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
@concrete struct PseudoTransient{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
alpha_initial
end

concrete_jac(::PseudoTransient{CJ}) where {CJ} = CJ

function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, alpha_initial = 1e-3, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return PseudoTransient{_unwrap_val(concrete_jac)}(ad, linsolve, precs, alpha_initial)
end

@concrete mutable struct PseudoTransientCache{iip}
f
alg
u
fu1
fu2
du
p
alpha
res_norm
uf
linsolve
J
jac_cache
force_stop
maxiters::Int
internalnorm
retcode::ReturnCode.T
abstol
prob
stats::NLStats
end

isinplace(::PseudoTransientCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::PseudoTransient, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
if iip
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu1, u, p)
else
fu1 = _mutable(f(u, p))
end
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
alpha = convert(eltype(u), alg.alpha_initial)
res_norm = internalnorm(fu1)

return PseudoTransientCache{iip}(f, alg, u, fu1, fu2, du, p, alpha, res_norm, uf,
linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::PseudoTransientCache{true})
@unpack u, fu1, f, p, alg, J, linsolve, du, alpha = cache
jacobian!!(J, cache)
J_new = J - (1 / alpha) * I

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J_new, b = _vec(fu1), linu = _vec(du),
p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(fu1, u, p)

new_norm = cache.internalnorm(fu1)
cache.alpha *= cache.res_norm / new_norm
cache.res_norm = new_norm

new_norm < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function perform_step!(cache::PseudoTransientCache{false})
@unpack u, fu1, f, p, alg, linsolve, alpha = cache

cache.J = jacobian!!(cache.J, cache)
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
yonatanwesen marked this conversation as resolved.
Show resolved Hide resolved
else
linres = dolinsolve(alg.precs, linsolve; A = cache.J - (1 / alpha) * I,
b = _vec(fu1),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu1 = f(cache.u, p)

new_norm = cache.internalnorm(fu1)
cache.alpha *= cache.res_norm / new_norm
cache.res_norm = new_norm

new_norm < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function SciMLBase.solve!(cache::PseudoTransientCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end

function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu1, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end
cache.alpha = convert(eltype(u), 1e-3)
cache.res_norm = internalnorm(cache.fu1)
cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end
57 changes: 57 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,60 @@ end
end
end
end

# --- PseudoTransient tests ---

@testset "PseudoTransient" begin

#iip
function test_f!(du, u, p)
du[1] = 2 - 2u[1]
du[2] = u[1] - 4u[2]
return du
end

#oop
simple_test(u, p) = 0.9f0 .* u .- u

#test jacobian free PseudoTransient method
function f!(res, u, p)
@. res = u * u - p
end

@testset "[IIP] u0: $(typeof(u0))" for u0 in (zeros(2),)
probN = NonlinearProblem{true}(test_f!, u0)
sol = solve(probN, PseudoTransient())
@test sol.retcode == ReturnCode.Success
du = zeros(2)
test_f!(du, sol.u, 4.0)
@test du≈[0, 0] atol=1e-7
end

@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0f0],)
probN = NonlinearProblem{false}(simple_test, u0)
sol = solve(probN, PseudoTransient(alpha_initial = 1.0))
@test sol.retcode == ReturnCode.Success
@test abs(sol.u[1]) <= 1.0f-4
end

@testset "Jacobian Free PT" begin
u0 = [1.0, 1.0]
p = 2.0
probN = NonlinearProblem(f!, u0, p)
linsolve = LinearSolve.KrylovJL_GMRES()
sol = solve(probN,
PseudoTransient(alpha_initial = 10.0, linsolve = linsolve),
reltol = 1e-9)
@test sol.retcode == ReturnCode.Success
end

@testset "NewtonRaphson Fails" begin
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
probN = NonlinearProblem{false}(newton_fails, u0, p)
sol = solve(probN, PseudoTransient(alpha_initial = 1.0), abstol = 1e-10)
@test all(abs.(newton_fails(sol.u, p)) .< 1e-10)
end
end

# Test that `PseudoTransient` passes a test that `NewtonRaphson` fails on.