Skip to content

Commit

Permalink
add NLsolve trust region updating scheme and change GN step to -J\fu …
Browse files Browse the repository at this point in the history
…to avoid growing ill-conditioning
  • Loading branch information
FHoltorf committed Sep 14, 2023
1 parent 332a3bf commit 065aee6
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ EnumX.@enumx RadiusUpdateSchemes begin
"""
Simple

"""
`RadiusUpdateSchemes.NLsolve`
The same updating rule as in NLsolve's trust region implementation
"""
NLsolve

"""
`RadiusUpdateSchemes.Hei`
Expand Down Expand Up @@ -177,8 +184,8 @@ function TrustRegion(; chunk_size = Val{0}(),
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10000,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_threshold::Real = 1 // 10, #1 // 4,
expand_threshold::Real = 9 // 10, #3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
Expand Down Expand Up @@ -340,7 +347,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
p3 = convert(eltype(u), 0.0)
p4 = convert(eltype(u), 0.0)
ϵ = convert(eltype(u), 1.0e-8)
if radius_update_scheme === RadiusUpdateSchemes.Hei
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
p1 = convert(eltype(u), 0.5)
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
step_threshold = convert(eltype(u), 0.0)
shrink_threshold = convert(eltype(u), 0.25)
expand_threshold = convert(eltype(u), 0.25)
Expand Down Expand Up @@ -407,7 +416,7 @@ function perform_step!(cache::TrustRegionCache{true})
cache.stats.njacs += 1
end

linres = dolinsolve(alg.precs, linsolve, A = cache.H, b = _vec(cache.g),
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu), # cache.H, b = _vec(cache.g),
linu = _vec(u_tmp),
p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
Expand Down Expand Up @@ -479,7 +488,7 @@ function trust_region_step!(cache::TrustRegionCache)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (dot(step_size, g) + dot(step_size, H, step_size) / 2)
@unpack r = cache
@unpack r = cache

if radius_update_scheme === RadiusUpdateSchemes.Simple
# Update the trust region radius.
Expand Down Expand Up @@ -508,6 +517,30 @@ function trust_region_step!(cache::TrustRegionCache)
cache.force_stop = true
end

elseif radius_update_scheme === RadiusUpdateSchemes.NLsolve
# accept/reject decision
if r > cache.step_threshold # accept
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
else # reject
cache.make_new_J = false
end

# trust region update
if r < cache.shrink_threshold # default 1 // 10
cache.trust_r *= cache.shrink_factor # default 1 // 2
elseif r >= cache.expand_threshold # default 9 // 10
cache.trust_r = cache.expand_factor * norm(cache.step_size) # default 2
elseif r >= cache.p1 # default 1 // 2
cache.trust_r = max(cache.trust_r, cache.expand_factor * norm(cache.step_size))
end

# convergence test
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end

elseif radius_update_scheme === RadiusUpdateSchemes.Hei
if r > cache.step_threshold
take_step!(cache)
Expand Down

0 comments on commit 065aee6

Please sign in to comment.