Skip to content

Commit

Permalink
Handle polyalgorithm aliasing correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 21, 2024
1 parent e1c0528 commit 45e0716
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.8.0"
version = "3.8.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf, SciMLBase,
SimpleNonlinearSolve, SparseArrays, SparseDiffTools

import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing
import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing,
ismutable
import DiffEqBase: AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode,
Expand Down
123 changes: 102 additions & 21 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ end
force_stop::Bool
maxiters::Int
internalnorm
u0
u0_aliased
alias_u0::Bool
end

function Base.show(
Expand All @@ -91,11 +94,24 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
@eval begin
function SciMLBase.__init(
prob::$probType, alg::$algType{N}, args...; maxtime = nothing,
maxiters = 1000, internalnorm = DEFAULT_NORM, kwargs...) where {N}
maxiters = 1000, internalnorm = DEFAULT_NORM,
alias_u0 = false, verbose = true, kwargs...) where {N}
if (alias_u0 && !ismutable(prob.u0))
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \

Check warning on line 100 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L100

Added line #L100 was not covered by tests
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing

Check warning on line 102 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L102

Added line #L102 was not covered by tests
end
u0 = prob.u0
if alias_u0
u0_aliased = copy(u0)
else
u0_aliased = u0 # Irrelevant
end
alias_u0 && (prob = remake(prob; u0 = u0_aliased))
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(
solver -> SciMLBase.__init(
prob, solver, args...; maxtime, internalnorm, kwargs...),
solver -> SciMLBase.__init(prob, solver, args...; maxtime,
internalnorm, alias_u0, verbose, kwargs...),
alg.algs),
alg,
-1,
Expand All @@ -106,7 +122,10 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
ReturnCode.Default,
false,
maxiters,
internalnorm)
internalnorm,
u0,
u0_aliased,
alias_u0)
end
end
end
Expand All @@ -120,20 +139,30 @@ end

cache_syms = [gensym("cache") for i in 1:N]
sol_syms = [gensym("sol") for i in 1:N]
u_result_syms = [gensym("u_result") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0)
$(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i]))
if SciMLBase.successful_retcode($(sol_syms[i]))
stats = $(sol_syms[i]).stats
u = $(sol_syms[i]).u
if cache.alias_u0
copyto!(cache.u0, $(sol_syms[i]).u)
$(u_result_syms[i]) = cache.u0
else
$(u_result_syms[i]) = $(sol_syms[i]).u
end
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution(
$(sol_syms[i]).prob, cache.alg, u, fu;
retcode = $(sol_syms[i]).retcode, stats,
$(sol_syms[i]).prob, cache.alg, $(u_result_syms[i]),
fu; retcode = $(sol_syms[i]).retcode, stats,
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
elseif cache.alias_u0
# For safety we need to maintain a copy of the solution
$(u_result_syms[i]) = copy($(sol_syms[i]).u)
end
cache.current = $(i + 1)
end
Expand All @@ -144,14 +173,29 @@ end
for (sym, resid) in zip(cache_syms, resids)
push!(calls, :($(resid) = @isdefined($(sym)) ? get_fu($(sym)) : nothing))
end
push!(calls, quote
fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.internalnorm, fus)
stats = __compile_stats(cache.caches[idx])

Check warning on line 179 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L177-L179

Added lines #L177 - L179 were not covered by tests
end)
for i in 1:N
push!(calls, quote
if idx == $(i)
if cache.alias_u0
u = $(u_result_syms[i])

Check warning on line 185 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L183-L185

Added lines #L183 - L185 were not covered by tests
else
u = get_u(cache.caches[$i])

Check warning on line 187 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L187

Added line #L187 was not covered by tests
end
end
end)
end
push!(calls,
quote
fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.internalnorm, fus)
stats = __compile_stats(cache.caches[idx])
u = get_u(cache.caches[idx])
retcode = cache.caches[idx].retcode

if cache.alias_u0
copyto!(cache.u0, u)
u = cache.u0

Check warning on line 197 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L195-L197

Added lines #L195 - L197 were not covered by tests
end
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
retcode, stats, cache.caches[idx].trace)
end)
Expand Down Expand Up @@ -200,22 +244,52 @@ end
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
@generated function SciMLBase.__solve(
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
calls = [:(current = alg.start_index)]
@generated function SciMLBase.__solve(prob::$probType, alg::$algType{N}, args...;
alias_u0 = false, verbose = true, kwargs...) where {N}
sol_syms = [gensym("sol") for _ in 1:N]
prob_syms = [gensym("prob") for _ in 1:N]
u_result_syms = [gensym("u_result") for _ in 1:N]
calls = [quote
current = alg.start_index
if (alias_u0 && !ismutable(prob.u0))
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \

Check warning on line 255 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L255

Added line #L255 was not covered by tests
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing

Check warning on line 257 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L257

Added line #L257 was not covered by tests
end
u0 = prob.u0
if alias_u0
u0_aliased = similar(u0)
else
u0_aliased = u0 # Irrelevant
end
end]
for i in 1:N
cur_sol = sol_syms[i]
push!(calls,
quote
if current == $i
$(cur_sol) = SciMLBase.__solve(
prob, alg.algs[$(i)], args...; kwargs...)
if alias_u0
copyto!(u0_aliased, u0)
$(prob_syms[i]) = remake(prob; u0 = u0_aliased)
else
$(prob_syms[i]) = prob
end
$(cur_sol) = SciMLBase.__solve($(prob_syms[i]), alg.algs[$(i)],
args...; alias_u0, verbose, kwargs...)
if SciMLBase.successful_retcode($(cur_sol))
if alias_u0
copyto!(u0, $(cur_sol).u)
$(u_result_syms[i]) = u0
else
$(u_result_syms[i]) = $(cur_sol).u
end
return SciMLBase.build_solution(
prob, alg, $(cur_sol).u, $(cur_sol).resid;
prob, alg, $(u_result_syms[i]), $(cur_sol).resid;
$(cur_sol).retcode, $(cur_sol).stats,
original = $(cur_sol), trace = $(cur_sol).trace)
elseif alias_u0
# For safety we need to maintain a copy of the solution
$(u_result_syms[i]) = copy($(cur_sol).u)
end
current = $(i + 1)
end
Expand All @@ -236,9 +310,16 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
push!(calls,
quote
if idx == $i
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
$(sol_syms[i]).resid; $(sol_syms[i]).retcode,
$(sol_syms[i]).stats, $(sol_syms[i]).trace)
if alias_u0
copyto!(u0, $(u_result_syms[i]))
$(u_result_syms[i]) = u0

Check warning on line 315 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L314-L315

Added lines #L314 - L315 were not covered by tests
else
$(u_result_syms[i]) = $(sol_syms[i]).u
end
return SciMLBase.build_solution(
prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid;
$(sol_syms[i]).retcode, $(sol_syms[i]).stats,
$(sol_syms[i]).trace, original = $(sol_syms[i]))
end
end)
end
Expand Down
13 changes: 7 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,16 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
@inline __is_complex(::Type{Complex}) = true
@inline __is_complex(::Type{T}) where {T} = false

function __findmin_caches(f, caches)
return __findmin(f get_fu, caches)
end
function __findmin(f, x)
return findmin(x) do xᵢ
@inline __findmin_caches(f, caches) = __findmin(f get_fu, caches)

Check warning on line 97 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L97

Added line #L97 was not covered by tests
# FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (atleast according to `isnan`)
@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x)
@inline function __findmin(f, x)
fmin = @closure xᵢ -> begin
xᵢ === nothing && return Inf
fx = f(xᵢ)
return isnan(fx) ? Inf : fx
return ifelse(isnan(fx), Inf, fx)
end
return findmin(fmin, x)
end

@inline __can_setindex(x) = can_setindex(x)
Expand Down
25 changes: 25 additions & 0 deletions test/misc/aliasing_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@testitem "PolyAlgorithm Aliasing" begin
using NonlinearProblemLibrary

# Use a problem that the initial solvers cannot solve and cause the initial value to
# diverge. If we don't alias correctly, all the subsequent algorithms will also fail.
prob = NonlinearProblemLibrary.nlprob_23_testcases["Generalized Rosenbrock function"].prob
u0 = copy(prob.u0)
prob = remake(prob; u0 = copy(u0))

# If aliasing is not handled properly this will diverge
sol = solve(prob; abstol = 1e-6, alias_u0 = true,
termination_condition = AbsNormTerminationMode())

@test sol.u === prob.u0
@test SciMLBase.successful_retcode(sol.retcode)

prob = remake(prob; u0 = copy(u0))

cache = init(prob; abstol = 1e-6, alias_u0 = true,
termination_condition = AbsNormTerminationMode())
sol = solve!(cache)

@test sol.u === prob.u0
@test SciMLBase.successful_retcode(sol.retcode)
end

0 comments on commit 45e0716

Please sign in to comment.