Skip to content

Commit

Permalink
Merge pull request #392 from SciML/ap/reliable_aliasing
Browse files Browse the repository at this point in the history
Handle polyalgorithm aliasing correctly
  • Loading branch information
ChrisRackauckas authored Mar 21, 2024
2 parents e1c0528 + 45e0716 commit 3925219
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.8.0"
version = "3.8.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf, SciMLBase,
SimpleNonlinearSolve, SparseArrays, SparseDiffTools

import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing
import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing,
ismutable
import DiffEqBase: AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode,
Expand Down
123 changes: 102 additions & 21 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ end
force_stop::Bool
maxiters::Int
internalnorm
u0
u0_aliased
alias_u0::Bool
end

function Base.show(
Expand All @@ -91,11 +94,24 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
@eval begin
function SciMLBase.__init(
prob::$probType, alg::$algType{N}, args...; maxtime = nothing,
maxiters = 1000, internalnorm = DEFAULT_NORM, kwargs...) where {N}
maxiters = 1000, internalnorm = DEFAULT_NORM,
alias_u0 = false, verbose = true, kwargs...) where {N}
if (alias_u0 && !ismutable(prob.u0))
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing
end
u0 = prob.u0
if alias_u0
u0_aliased = copy(u0)
else
u0_aliased = u0 # Irrelevant
end
alias_u0 && (prob = remake(prob; u0 = u0_aliased))
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(
solver -> SciMLBase.__init(
prob, solver, args...; maxtime, internalnorm, kwargs...),
solver -> SciMLBase.__init(prob, solver, args...; maxtime,
internalnorm, alias_u0, verbose, kwargs...),
alg.algs),
alg,
-1,
Expand All @@ -106,7 +122,10 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
ReturnCode.Default,
false,
maxiters,
internalnorm)
internalnorm,
u0,
u0_aliased,
alias_u0)
end
end
end
Expand All @@ -120,20 +139,30 @@ end

cache_syms = [gensym("cache") for i in 1:N]
sol_syms = [gensym("sol") for i in 1:N]
u_result_syms = [gensym("u_result") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0)
$(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i]))
if SciMLBase.successful_retcode($(sol_syms[i]))
stats = $(sol_syms[i]).stats
u = $(sol_syms[i]).u
if cache.alias_u0
copyto!(cache.u0, $(sol_syms[i]).u)
$(u_result_syms[i]) = cache.u0
else
$(u_result_syms[i]) = $(sol_syms[i]).u
end
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution(
$(sol_syms[i]).prob, cache.alg, u, fu;
retcode = $(sol_syms[i]).retcode, stats,
$(sol_syms[i]).prob, cache.alg, $(u_result_syms[i]),
fu; retcode = $(sol_syms[i]).retcode, stats,
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
elseif cache.alias_u0
# For safety we need to maintain a copy of the solution
$(u_result_syms[i]) = copy($(sol_syms[i]).u)
end
cache.current = $(i + 1)
end
Expand All @@ -144,14 +173,29 @@ end
for (sym, resid) in zip(cache_syms, resids)
push!(calls, :($(resid) = @isdefined($(sym)) ? get_fu($(sym)) : nothing))
end
push!(calls, quote
fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.internalnorm, fus)
stats = __compile_stats(cache.caches[idx])
end)
for i in 1:N
push!(calls, quote
if idx == $(i)
if cache.alias_u0
u = $(u_result_syms[i])
else
u = get_u(cache.caches[$i])
end
end
end)
end
push!(calls,
quote
fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.internalnorm, fus)
stats = __compile_stats(cache.caches[idx])
u = get_u(cache.caches[idx])
retcode = cache.caches[idx].retcode

if cache.alias_u0
copyto!(cache.u0, u)
u = cache.u0
end
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
retcode, stats, cache.caches[idx].trace)
end)
Expand Down Expand Up @@ -200,22 +244,52 @@ end
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
@generated function SciMLBase.__solve(
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
calls = [:(current = alg.start_index)]
@generated function SciMLBase.__solve(prob::$probType, alg::$algType{N}, args...;
alias_u0 = false, verbose = true, kwargs...) where {N}
sol_syms = [gensym("sol") for _ in 1:N]
prob_syms = [gensym("prob") for _ in 1:N]
u_result_syms = [gensym("u_result") for _ in 1:N]
calls = [quote
current = alg.start_index
if (alias_u0 && !ismutable(prob.u0))
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing
end
u0 = prob.u0
if alias_u0
u0_aliased = similar(u0)
else
u0_aliased = u0 # Irrelevant
end
end]
for i in 1:N
cur_sol = sol_syms[i]
push!(calls,
quote
if current == $i
$(cur_sol) = SciMLBase.__solve(
prob, alg.algs[$(i)], args...; kwargs...)
if alias_u0
copyto!(u0_aliased, u0)
$(prob_syms[i]) = remake(prob; u0 = u0_aliased)
else
$(prob_syms[i]) = prob
end
$(cur_sol) = SciMLBase.__solve($(prob_syms[i]), alg.algs[$(i)],
args...; alias_u0, verbose, kwargs...)
if SciMLBase.successful_retcode($(cur_sol))
if alias_u0
copyto!(u0, $(cur_sol).u)
$(u_result_syms[i]) = u0
else
$(u_result_syms[i]) = $(cur_sol).u
end
return SciMLBase.build_solution(
prob, alg, $(cur_sol).u, $(cur_sol).resid;
prob, alg, $(u_result_syms[i]), $(cur_sol).resid;
$(cur_sol).retcode, $(cur_sol).stats,
original = $(cur_sol), trace = $(cur_sol).trace)
elseif alias_u0
# For safety we need to maintain a copy of the solution
$(u_result_syms[i]) = copy($(cur_sol).u)
end
current = $(i + 1)
end
Expand All @@ -236,9 +310,16 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
push!(calls,
quote
if idx == $i
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
$(sol_syms[i]).resid; $(sol_syms[i]).retcode,
$(sol_syms[i]).stats, $(sol_syms[i]).trace)
if alias_u0
copyto!(u0, $(u_result_syms[i]))
$(u_result_syms[i]) = u0
else
$(u_result_syms[i]) = $(sol_syms[i]).u
end
return SciMLBase.build_solution(
prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid;
$(sol_syms[i]).retcode, $(sol_syms[i]).stats,
$(sol_syms[i]).trace, original = $(sol_syms[i]))
end
end)
end
Expand Down
13 changes: 7 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,16 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
@inline __is_complex(::Type{Complex}) = true
@inline __is_complex(::Type{T}) where {T} = false

function __findmin_caches(f, caches)
return __findmin(f get_fu, caches)
end
function __findmin(f, x)
return findmin(x) do xᵢ
@inline __findmin_caches(f, caches) = __findmin(f get_fu, caches)
# FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (atleast according to `isnan`)
@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x)
@inline function __findmin(f, x)
fmin = @closure xᵢ -> begin
xᵢ === nothing && return Inf
fx = f(xᵢ)
return isnan(fx) ? Inf : fx
return ifelse(isnan(fx), Inf, fx)
end
return findmin(fmin, x)
end

@inline __can_setindex(x) = can_setindex(x)
Expand Down
25 changes: 25 additions & 0 deletions test/misc/aliasing_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@testitem "PolyAlgorithm Aliasing" begin
using NonlinearProblemLibrary

# Use a problem that the initial solvers cannot solve and cause the initial value to
# diverge. If we don't alias correctly, all the subsequent algorithms will also fail.
prob = NonlinearProblemLibrary.nlprob_23_testcases["Generalized Rosenbrock function"].prob
u0 = copy(prob.u0)
prob = remake(prob; u0 = copy(u0))

# If aliasing is not handled properly this will diverge
sol = solve(prob; abstol = 1e-6, alias_u0 = true,
termination_condition = AbsNormTerminationMode())

@test sol.u === prob.u0
@test SciMLBase.successful_retcode(sol.retcode)

prob = remake(prob; u0 = copy(u0))

cache = init(prob; abstol = 1e-6, alias_u0 = true,
termination_condition = AbsNormTerminationMode())
sol = solve!(cache)

@test sol.u === prob.u0
@test SciMLBase.successful_retcode(sol.retcode)
end

2 comments on commit 3925219

@avik-pal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/103407

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.8.1 -m "<description of version>" 39252192054635732341742aa14999d275f3d0e9
git push origin v3.8.1

Please sign in to comment.