Skip to content

Commit 517f695

Browse files
committed
Use a descent result
1 parent 9e46f00 commit 517f695

13 files changed

+163
-116
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.5.4"
4+
version = "3.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/devdocs/internal_interfaces.md

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ NonlinearSolve.AbstractNonlinearSolveCache
1313
```@docs
1414
NonlinearSolve.AbstractDescentAlgorithm
1515
NonlinearSolve.AbstractDescentCache
16+
NonlinearSolve.DescentResult
1617
```
1718

1819
## Approximate Jacobian

src/NonlinearSolve.jl

+2
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ include("adtypes.jl")
4444
include("timer_outputs.jl")
4545
include("internal/helpers.jl")
4646

47+
include("descent/common.jl")
4748
include("descent/newton.jl")
4849
include("descent/steepest.jl")
4950
include("descent/dogleg.jl")
5051
include("descent/damped_newton.jl")
5152
include("descent/geodesic_acceleration.jl")
53+
include("descent/multistep.jl")
5254

5355
include("internal/operators.jl")
5456
include("internal/jacobian.jl")

src/abstract_types.jl

+3-10
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ Abstract Type for all Descent Caches.
6565
### `__internal_solve!` specification
6666
6767
```julia
68-
δu, success, intermediates = __internal_solve!(cache::AbstractDescentCache, J, fu, u,
69-
idx::Val; skip_solve::Bool = false, kwargs...)
68+
descent_result = __internal_solve!(cache::AbstractDescentCache, J, fu, u, idx::Val;
69+
skip_solve::Bool = false, kwargs...)
7070
```
7171
7272
- `J`: Jacobian or Inverse Jacobian (if `pre_inverted = Val(true)`).
@@ -78,14 +78,7 @@ Abstract Type for all Descent Caches.
7878
direction was rejected and we want to try with a modified trust region.
7979
- `kwargs`: keyword arguments to pass to the linear solver if there is one.
8080
81-
#### Returned values
82-
83-
- `δu`: the descent direction.
84-
- `success`: Certain Descent Algorithms can reject a descent direction for example
85-
`GeodesicAcceleration`.
86-
- `intermediates`: A named tuple containing intermediates computed during the solve.
87-
For example, `GeodesicAcceleration` returns `NamedTuple{(:v, :a)}` containing the
88-
"velocity" and "acceleration" terms.
81+
Returns a result of type [`DescentResult`](@ref).
8982
9083
### Interface Functions
9184

src/core/approximate_jacobian.jl

+52-42
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
163163

164164
linsolve = get_linear_solver(alg.descent)
165165
initialization_cache = __internal_init(prob, alg.initialization, alg, f, fu, u, p;
166-
linsolve,
167-
maxiters, internalnorm)
166+
linsolve, maxiters, internalnorm)
168167

169168
abstol, reltol, termination_cache = init_termination_cache(abstol, reltol, fu, u,
170169
termination_condition)
@@ -222,9 +221,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
222221
new_jacobian = true
223222
@static_timeit cache.timer "jacobian init/reinit" begin
224223
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
225-
J_init = __internal_solve!(cache.initialization_cache,
226-
cache.fu,
227-
cache.u,
224+
J_init = __internal_solve!(cache.initialization_cache, cache.fu, cache.u,
228225
Val(false))
229226
if INV
230227
if jacobian_initialized_preinverted(cache.initialization_cache.alg)
@@ -283,52 +280,65 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
283280
@static_timeit cache.timer "descent" begin
284281
if cache.trustregion_cache !== nothing &&
285282
hasfield(typeof(cache.trustregion_cache), :trust_region)
286-
δu, descent_success, descent_intermediates = __internal_solve!(cache.descent_cache,
287-
J, cache.fu, cache.u; new_jacobian,
288-
trust_region = cache.trustregion_cache.trust_region)
283+
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
284+
new_jacobian, trust_region = cache.trustregion_cache.trust_region)
289285
else
290-
δu, descent_success, descent_intermediates = __internal_solve!(cache.descent_cache,
291-
J, cache.fu, cache.u; new_jacobian)
286+
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
287+
new_jacobian)
292288
end
293289
end
294290

295-
if descent_success
296-
if GB === :LineSearch
297-
@static_timeit cache.timer "linesearch" begin
298-
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
299-
end
300-
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
301-
cache.force_reinit = true
302-
else
303-
@static_timeit cache.timer "step" begin
304-
@bb axpy!(α, δu, cache.u)
305-
evaluate_f!(cache, cache.u, cache.p)
306-
end
307-
end
308-
elseif GB === :TrustRegion
309-
@static_timeit cache.timer "trustregion" begin
310-
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, J,
311-
cache.fu, cache.u, δu, descent_intermediates)
312-
if tr_accepted
313-
@bb copyto!(cache.u, u_new)
314-
@bb copyto!(cache.fu, fu_new)
315-
end
316-
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
317-
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
318-
cache.retcode = ReturnCode.ShrinkThresholdExceeded
319-
cache.force_stop = true
320-
end
321-
end
322-
α = true
323-
elseif GB === :None
291+
if descent_result.success
292+
if GB === :None
324293
@static_timeit cache.timer "step" begin
325-
@bb axpy!(1, δu, cache.u)
294+
if descent_result.u !== missing
295+
@bb copyto!(cache.u, descent_result.u)
296+
elseif descent_result.δu !== missing
297+
@bb axpy!(1, descent_result.δu, cache.u)
298+
else
299+
error("This shouldn't occur. `$(cache.alg.descent)` is incorrectly \
300+
specified.")
301+
end
326302
evaluate_f!(cache, cache.u, cache.p)
327303
end
328304
α = true
329305
else
330-
error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \
331-
:TrustRegion, :None)")
306+
δu = descent_result.δu
307+
@assert δu!==missing "Descent Supporting LineSearch or TrustRegion must return a `δu`."
308+
309+
if GB === :LineSearch
310+
@static_timeit cache.timer "linesearch" begin
311+
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
312+
end
313+
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
314+
cache.force_reinit = true
315+
else
316+
@static_timeit cache.timer "step" begin
317+
@bb axpy!(α, δu, cache.u)
318+
evaluate_f!(cache, cache.u, cache.p)
319+
end
320+
end
321+
elseif GB === :TrustRegion
322+
@static_timeit cache.timer "trustregion" begin
323+
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache,
324+
J, cache.fu, cache.u, δu, descent_result.extras)
325+
if tr_accepted
326+
@bb copyto!(cache.u, u_new)
327+
@bb copyto!(cache.fu, fu_new)
328+
α = true
329+
else
330+
α = false
331+
end
332+
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
333+
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
334+
cache.retcode = ReturnCode.ShrinkThresholdExceeded
335+
cache.force_stop = true
336+
end
337+
end
338+
else
339+
error("Unknown Globalization Strategy: $(GB). Allowed values are \
340+
(:LineSearch, :TrustRegion, :None)")
341+
end
332342
end
333343
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
334344
else

src/core/generalized_first_order.jl

+49-39
Original file line numberDiff line numberDiff line change
@@ -215,57 +215,67 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
215215
@static_timeit cache.timer "descent" begin
216216
if cache.trustregion_cache !== nothing &&
217217
hasfield(typeof(cache.trustregion_cache), :trust_region)
218-
δu, descent_success, descent_intermediates = __internal_solve!(cache.descent_cache,
219-
J, cache.fu, cache.u; new_jacobian,
220-
trust_region = cache.trustregion_cache.trust_region)
218+
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
219+
new_jacobian, trust_region = cache.trustregion_cache.trust_region)
221220
else
222-
δu, descent_success, descent_intermediates = __internal_solve!(cache.descent_cache,
223-
J, cache.fu, cache.u; new_jacobian)
221+
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
222+
new_jacobian)
224223
end
225224
end
226225

227-
if descent_success
226+
if descent_result.success
228227
cache.make_new_jacobian = true
229-
if GB === :LineSearch
230-
@static_timeit cache.timer "linesearch" begin
231-
linesearch_failed, α = __internal_solve!(cache.linesearch_cache,
232-
cache.u, δu)
233-
end
234-
if linesearch_failed
235-
cache.retcode = ReturnCode.InternalLineSearchFailed
236-
cache.force_stop = true
237-
end
228+
if GB === :None
238229
@static_timeit cache.timer "step" begin
239-
@bb axpy!(α, δu, cache.u)
240-
evaluate_f!(cache, cache.u, cache.p)
241-
end
242-
elseif GB === :TrustRegion
243-
@static_timeit cache.timer "trustregion" begin
244-
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, J,
245-
cache.fu, cache.u, δu, descent_intermediates)
246-
if tr_accepted
247-
@bb copyto!(cache.u, u_new)
248-
@bb copyto!(cache.fu, fu_new)
249-
α = true
230+
if descent_result.u !== missing
231+
@bb copyto!(cache.u, descent_result.u)
232+
elseif descent_result.δu !== missing
233+
@bb axpy!(1, descent_result.δu, cache.u)
250234
else
251-
α = false
252-
cache.make_new_jacobian = false
235+
error("This shouldn't occur. `$(cache.alg.descent)` is incorrectly \
236+
specified.")
253237
end
254-
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
255-
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
256-
cache.retcode = ReturnCode.ShrinkThresholdExceeded
257-
cache.force_stop = true
258-
end
259-
end
260-
elseif GB === :None
261-
@static_timeit cache.timer "step" begin
262-
@bb axpy!(1, δu, cache.u)
263238
evaluate_f!(cache, cache.u, cache.p)
264239
end
265240
α = true
266241
else
267-
error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \
268-
:TrustRegion, :None)")
242+
δu = descent_result.δu
243+
@assert δu!==missing "Descent Supporting LineSearch or TrustRegion must return a `δu`."
244+
245+
if GB === :LineSearch
246+
@static_timeit cache.timer "linesearch" begin
247+
failed, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
248+
end
249+
if failed
250+
cache.retcode = ReturnCode.InternalLineSearchFailed
251+
cache.force_stop = true
252+
else
253+
@static_timeit cache.timer "step" begin
254+
@bb axpy!(α, δu, cache.u)
255+
evaluate_f!(cache, cache.u, cache.p)
256+
end
257+
end
258+
elseif GB === :TrustRegion
259+
@static_timeit cache.timer "trustregion" begin
260+
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache,
261+
J, cache.fu, cache.u, δu, descent_result.extras)
262+
if tr_accepted
263+
@bb copyto!(cache.u, u_new)
264+
@bb copyto!(cache.fu, fu_new)
265+
α = true
266+
else
267+
α = false
268+
end
269+
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
270+
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
271+
cache.retcode = ReturnCode.ShrinkThresholdExceeded
272+
cache.force_stop = true
273+
end
274+
end
275+
else
276+
error("Unknown Globalization Strategy: $(GB). Allowed values are \
277+
(:LineSearch, :TrustRegion, :None)")
278+
end
269279
end
270280
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
271281
else

src/descent/common.jl

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
3+
4+
Construct a `DescentResult` object.
5+
6+
### Keyword Arguments
7+
8+
* `δu`: The descent direction.
9+
* `u`: The new iterate. This is provided only for multi-step methods currently.
10+
* `success`: Certain Descent Algorithms can reject a descent direction for example
11+
[`GeodesicAcceleration`](@ref).
12+
* `extras`: A named tuple containing intermediates computed during the solve.
13+
For example, [`GeodesicAcceleration`](@ref) returns `NamedTuple{(:v, :a)}` containing
14+
the "velocity" and "acceleration" terms.
15+
"""
16+
@concrete struct DescentResult
17+
δu
18+
u
19+
success::Bool
20+
extras
21+
end
22+
23+
function DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
24+
@assert δu !== missing || u !== missing
25+
return DescentResult(δu, u, success, extras)
26+
end

src/descent/damped_newton.jl

+4-5
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu, u,
136136
idx::Val{N} = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true,
137137
kwargs...) where {INV, N, mode}
138138
δu = get_du(cache, idx)
139-
skip_solve && return δu, true, (;)
139+
skip_solve && return DescentResult(; δu)
140140

141141
recompute_A = idx === Val(1)
142142

@@ -201,15 +201,14 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu, u,
201201
end
202202

203203
@static_timeit cache.timer "linear solve" begin
204-
δu = cache.lincache(; A, b,
205-
reuse_A_if_factorization = !new_jacobian && !recompute_A,
206-
kwargs..., linu = _vec(δu))
204+
δu = cache.lincache(; A, b, linu = _vec(δu),
205+
reuse_A_if_factorization = !new_jacobian && !recompute_A, kwargs...)
207206
δu = _restructure(get_du(cache, idx), δu)
208207
end
209208

210209
@bb @. δu *= -1
211210
set_du!(cache, δu, idx)
212-
return δu, true, (;)
211+
return DescentResult(; δu)
213212
end
214213

215214
# Define special concatenation for certain Array combinations

0 commit comments

Comments
 (0)