From 45e07169472b7c6b4b25cd9acb8cca1e908b2636 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Mar 2024 16:30:41 -0400 Subject: [PATCH] Handle polyalgorithm aliasing correctly --- Project.toml | 2 +- src/NonlinearSolve.jl | 3 +- src/default.jl | 123 ++++++++++++++++++++++++++++++------ src/utils.jl | 13 ++-- test/misc/aliasing_tests.jl | 25 ++++++++ 5 files changed, 137 insertions(+), 29 deletions(-) create mode 100644 test/misc/aliasing_tests.jl diff --git a/Project.toml b/Project.toml index 1294701b1..51aa68f54 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.8.0" +version = "3.8.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index e7fa9bd3e..0cef448d2 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -12,7 +12,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf, SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools - import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing + import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing, + ismutable import DiffEqBase: AbstractNonlinearTerminationMode, AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode, diff --git a/src/default.jl b/src/default.jl index cdcfb6825..c1db1638b 100644 --- a/src/default.jl +++ b/src/default.jl @@ -65,6 +65,9 @@ end force_stop::Bool maxiters::Int internalnorm + u0 + u0_aliased + alias_u0::Bool end function Base.show( @@ -91,11 +94,24 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb @eval begin function SciMLBase.__init( prob::$probType, alg::$algType{N}, args...; maxtime = nothing, - maxiters = 1000, internalnorm = DEFAULT_NORM, kwargs...) where {N} + maxiters = 1000, internalnorm = DEFAULT_NORM, + alias_u0 = false, verbose = true, kwargs...) where {N} + if (alias_u0 && !ismutable(prob.u0)) + verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \ + immutable (checked using `ArrayInterface.ismutable`)." + alias_u0 = false # If immutable don't care about aliasing + end + u0 = prob.u0 + if alias_u0 + u0_aliased = copy(u0) + else + u0_aliased = u0 # Irrelevant + end + alias_u0 && (prob = remake(prob; u0 = u0_aliased)) return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}( map( - solver -> SciMLBase.__init( - prob, solver, args...; maxtime, internalnorm, kwargs...), + solver -> SciMLBase.__init(prob, solver, args...; maxtime, + internalnorm, alias_u0, verbose, kwargs...), alg.algs), alg, -1, @@ -106,7 +122,10 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb ReturnCode.Default, false, maxiters, - internalnorm) + internalnorm, + u0, + u0_aliased, + alias_u0) end end end @@ -120,20 +139,30 @@ end cache_syms = [gensym("cache") for i in 1:N] sol_syms = [gensym("sol") for i in 1:N] + u_result_syms = [gensym("u_result") for i in 1:N] for i in 1:N push!(calls, quote $(cache_syms[i]) = cache.caches[$(i)] if $(i) == cache.current + cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0) $(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i])) if SciMLBase.successful_retcode($(sol_syms[i])) stats = $(sol_syms[i]).stats - u = $(sol_syms[i]).u + if cache.alias_u0 + copyto!(cache.u0, $(sol_syms[i]).u) + $(u_result_syms[i]) = cache.u0 + else + $(u_result_syms[i]) = $(sol_syms[i]).u + end fu = get_fu($(cache_syms[i])) return SciMLBase.build_solution( - $(sol_syms[i]).prob, cache.alg, u, fu; - retcode = $(sol_syms[i]).retcode, stats, + $(sol_syms[i]).prob, cache.alg, $(u_result_syms[i]), + fu; retcode = $(sol_syms[i]).retcode, stats, original = $(sol_syms[i]), trace = $(sol_syms[i]).trace) + elseif cache.alias_u0 + # For safety we need to maintain a copy of the solution + $(u_result_syms[i]) = copy($(sol_syms[i]).u) end cache.current = $(i + 1) end @@ -144,14 +173,29 @@ end for (sym, resid) in zip(cache_syms, resids) push!(calls, :($(resid) = @isdefined($(sym)) ? get_fu($(sym)) : nothing)) end + push!(calls, quote + fus = tuple($(Tuple(resids)...)) + minfu, idx = __findmin(cache.internalnorm, fus) + stats = __compile_stats(cache.caches[idx]) + end) + for i in 1:N + push!(calls, quote + if idx == $(i) + if cache.alias_u0 + u = $(u_result_syms[i]) + else + u = get_u(cache.caches[$i]) + end + end + end) + end push!(calls, quote - fus = tuple($(Tuple(resids)...)) - minfu, idx = __findmin(cache.internalnorm, fus) - stats = __compile_stats(cache.caches[idx]) - u = get_u(cache.caches[idx]) retcode = cache.caches[idx].retcode - + if cache.alias_u0 + copyto!(cache.u0, u) + u = cache.u0 + end return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx]; retcode, stats, cache.caches[idx].trace) end) @@ -200,22 +244,52 @@ end for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS)) algType = NonlinearSolvePolyAlgorithm{pType} @eval begin - @generated function SciMLBase.__solve( - prob::$probType, alg::$algType{N}, args...; kwargs...) where {N} - calls = [:(current = alg.start_index)] + @generated function SciMLBase.__solve(prob::$probType, alg::$algType{N}, args...; + alias_u0 = false, verbose = true, kwargs...) where {N} sol_syms = [gensym("sol") for _ in 1:N] + prob_syms = [gensym("prob") for _ in 1:N] + u_result_syms = [gensym("u_result") for _ in 1:N] + calls = [quote + current = alg.start_index + if (alias_u0 && !ismutable(prob.u0)) + verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \ + immutable (checked using `ArrayInterface.ismutable`)." + alias_u0 = false # If immutable don't care about aliasing + end + u0 = prob.u0 + if alias_u0 + u0_aliased = similar(u0) + else + u0_aliased = u0 # Irrelevant + end + end] for i in 1:N cur_sol = sol_syms[i] push!(calls, quote if current == $i - $(cur_sol) = SciMLBase.__solve( - prob, alg.algs[$(i)], args...; kwargs...) + if alias_u0 + copyto!(u0_aliased, u0) + $(prob_syms[i]) = remake(prob; u0 = u0_aliased) + else + $(prob_syms[i]) = prob + end + $(cur_sol) = SciMLBase.__solve($(prob_syms[i]), alg.algs[$(i)], + args...; alias_u0, verbose, kwargs...) if SciMLBase.successful_retcode($(cur_sol)) + if alias_u0 + copyto!(u0, $(cur_sol).u) + $(u_result_syms[i]) = u0 + else + $(u_result_syms[i]) = $(cur_sol).u + end return SciMLBase.build_solution( - prob, alg, $(cur_sol).u, $(cur_sol).resid; + prob, alg, $(u_result_syms[i]), $(cur_sol).resid; $(cur_sol).retcode, $(cur_sol).stats, original = $(cur_sol), trace = $(cur_sol).trace) + elseif alias_u0 + # For safety we need to maintain a copy of the solution + $(u_result_syms[i]) = copy($(cur_sol).u) end current = $(i + 1) end @@ -236,9 +310,16 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb push!(calls, quote if idx == $i - return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u, - $(sol_syms[i]).resid; $(sol_syms[i]).retcode, - $(sol_syms[i]).stats, $(sol_syms[i]).trace) + if alias_u0 + copyto!(u0, $(u_result_syms[i])) + $(u_result_syms[i]) = u0 + else + $(u_result_syms[i]) = $(sol_syms[i]).u + end + return SciMLBase.build_solution( + prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid; + $(sol_syms[i]).retcode, $(sol_syms[i]).stats, + $(sol_syms[i]).trace, original = $(sol_syms[i])) end end) end diff --git a/src/utils.jl b/src/utils.jl index 259ff945a..d649d334c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -94,15 +94,16 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x) @inline __is_complex(::Type{Complex}) = true @inline __is_complex(::Type{T}) where {T} = false -function __findmin_caches(f, caches) - return __findmin(f ∘ get_fu, caches) -end -function __findmin(f, x) - return findmin(x) do xᵢ +@inline __findmin_caches(f, caches) = __findmin(f ∘ get_fu, caches) +# FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (atleast according to `isnan`) +@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x) +@inline function __findmin(f, x) + fmin = @closure xᵢ -> begin xᵢ === nothing && return Inf fx = f(xᵢ) - return isnan(fx) ? Inf : fx + return ifelse(isnan(fx), Inf, fx) end + return findmin(fmin, x) end @inline __can_setindex(x) = can_setindex(x) diff --git a/test/misc/aliasing_tests.jl b/test/misc/aliasing_tests.jl new file mode 100644 index 000000000..857233029 --- /dev/null +++ b/test/misc/aliasing_tests.jl @@ -0,0 +1,25 @@ +@testitem "PolyAlgorithm Aliasing" begin + using NonlinearProblemLibrary + + # Use a problem that the initial solvers cannot solve and cause the initial value to + # diverge. If we don't alias correctly, all the subsequent algorithms will also fail. + prob = NonlinearProblemLibrary.nlprob_23_testcases["Generalized Rosenbrock function"].prob + u0 = copy(prob.u0) + prob = remake(prob; u0 = copy(u0)) + + # If aliasing is not handled properly this will diverge + sol = solve(prob; abstol = 1e-6, alias_u0 = true, + termination_condition = AbsNormTerminationMode()) + + @test sol.u === prob.u0 + @test SciMLBase.successful_retcode(sol.retcode) + + prob = remake(prob; u0 = copy(u0)) + + cache = init(prob; abstol = 1e-6, alias_u0 = true, + termination_condition = AbsNormTerminationMode()) + sol = solve!(cache) + + @test sol.u === prob.u0 + @test SciMLBase.successful_retcode(sol.retcode) +end