Skip to content

Commit

Permalink
Merge pull request #28 from utkarsh530/u/buildsolution
Browse files Browse the repository at this point in the history
Add SciML.build_solution
  • Loading branch information
ChrisRackauckas authored Feb 3, 2021
2 parents f232c8f + f081339 commit b8bed15
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 71 deletions.
40 changes: 22 additions & 18 deletions src/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fx = float(prob.u0)
T = typeof(x)
atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5)
rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5)
Expand All @@ -13,15 +14,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, a
fx = f(x)
dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx)
end
iszero(fx) && return NewtonSolution(x, DEFAULT)
iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT))
Δx = dfx \ fx
x -= Δx
if isapprox(x, xo, atol=atol, rtol=rtol)
return NewtonSolution(x, DEFAULT)
return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT))
end
xo = x
end
return NewtonSolution(x, MAXITERS_EXCEED)
return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(MAXITERS_EXCEED))
end

function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
Expand All @@ -32,7 +33,7 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
sol = solve(newprob, alg, args...; kwargs...)

uu = getsolution(sol)
uu = sol.u
if p isa Number
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
else
Expand All @@ -50,39 +51,42 @@ end

function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)

end
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
end

# avoid ambiguities
for Alg in [Bisection]
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials))
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials))
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
end

function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.u0
fl, fr = f(left), f(right)

if iszero(fl)
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right)
end

i = 1
if !iszero(fr)
while i < maxiters
mid = (left + right) / 2
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -101,7 +105,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters

while i < maxiters
mid = (left + right) / 2
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -113,23 +117,23 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
i += 1
end

return BracketingSolution(left, right, MAXITERS_EXCEED)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right)
end

function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.u0
fl, fr = f(left), f(right)

if iszero(fl)
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right)
end

i = 1
if !iszero(fr)
while i < maxiters
if nextfloat_tdir(left, prob.u0...) == right
return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
end
mid = (fr * left - fl * right) / (fr - fl)
for i in 1:10
Expand All @@ -156,7 +160,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10

while i < maxiters
mid = (left + right) / 2
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -171,5 +175,5 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10
i += 1
end

return BracketingSolution(left, right, MAXITERS_EXCEED)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right)
end
26 changes: 7 additions & 19 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ function SciMLBase.solve(prob::NonlinearProblem,
kwargs...)
solver = init(prob, alg, args...; kwargs...)
sol = solve!(solver)
return sol
end

function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
Expand All @@ -30,7 +29,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracket
fl = f(left, p)
fr = f(right, p)
cache = alg_cache(alg, left, right,p, Val(iip))
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip)
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip,prob)
end

function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
Expand All @@ -55,7 +54,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonA
fu = f(u, p)
end
cache = alg_cache(alg, f, u, p, Val(iip))
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip)
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip, prob)
end

function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
Expand All @@ -67,8 +66,11 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
if solver.iter == solver.maxiters
@set! solver.retcode = MAXITERS_EXCEED
end
sol = get_solution(solver)
return sol
if typeof(solver) <: NewtonImmutableSolver
SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu;retcode=Symbol(solver.retcode))
else
SciMLBase.build_solution(solver.prob, solver.alg, solver.left,solver.fl;retcode=Symbol(solver.retcode),left = solver.left,right = solver.right)
end
end

"""
Expand Down Expand Up @@ -96,20 +98,6 @@ function mic_check(solver::NewtonImmutableSolver)
solver
end

"""
get_solution(solver::Union{BracketingImmutableSolver, BracketingSolver})
get_solution(solver::Union{NewtonImmutableSolver, NewtonSolver})
Form solution object from solver types
"""
function get_solution(solver::BracketingImmutableSolver)
return BracketingSolution(solver.left, solver.right, solver.retcode)
end

function get_solution(solver::NewtonImmutableSolver)
return NewtonSolution(solver.u, solver.retcode)
end

"""
reinit!(solver, prob)
Expand Down
21 changes: 5 additions & 16 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
FLOATING_POINT_LIMIT
end

struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType} <: AbstractImmutableNonlinearSolver
struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType} <: AbstractImmutableNonlinearSolver
iter::Int
f::fType
alg::algType
Expand All @@ -20,14 +20,15 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp
retcode::Retcode
cache::cacheType
iip::Bool
prob::probType
end

# function BracketingImmutableSolver(iip, iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
# BracketingImmutableSolver{iip, typeof(f), typeof(alg),
# typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
# end

struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType} <: AbstractImmutableNonlinearSolver
struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType, probType} <: AbstractImmutableNonlinearSolver
iter::Int
f::fType
alg::algType
Expand All @@ -41,29 +42,17 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT
tol::tolType
cache::cacheType
iip::Bool
prob::probType
end

# function NewtonImmutableSolver{iip}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) where iip
# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u),
# typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache)
# end

struct BracketingSolution{uType}
left::uType
right::uType
retcode::Retcode
end

struct NewtonSolution{uType}
u::uType
retcode::Retcode
end

function sync_residuals!(solver::BracketingImmutableSolver)
@set! solver.fl = solver.f(solver.left, solver.p)
@set! solver.fr = solver.f(solver.right, solver.p)
solver
end

getsolution(sol::NewtonSolution) = sol.u
getsolution(sol::BracketingSolution) = sol.left
end
15 changes: 0 additions & 15 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,6 @@ function num_types_in_tuple(sig::UnionAll)
length(Base.unwrap_unionall(sig).parameters)
end

function numargs(f)
typ = Tuple{Any, Val{:analytic}, Vararg}
typ2 = Tuple{Any, Type{Val{:analytic}}, Vararg} # This one is required for overloaded types
typ3 = Tuple{Any, Val{:jac}, Vararg}
typ4 = Tuple{Any, Type{Val{:jac}}, Vararg} # This one is required for overloaded types
typ5 = Tuple{Any, Val{:tgrad}, Vararg}
typ6 = Tuple{Any, Type{Val{:tgrad}}, Vararg} # This one is required for overloaded types
numparam = maximum([(m.sig<:typ || m.sig<:typ2 || m.sig<:typ3 || m.sig<:typ4 || m.sig<:typ5 || m.sig<:typ6) ? 0 : num_types_in_tuple(m.sig) for m in methods(f)])
return (numparam-1) #-1 in v0.5 since it adds f as the first parameter
end

function isinplace(f,inplace_param_number)
numargs(f)>=inplace_param_number
end

### Default Linsolve

# Try to be as smart as possible
Expand Down
7 changes: 4 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ end
f, u0 = (u,p) -> u .* u .- 2, @SVector[1.0, 1.0]
sf, su0 = (u,p) -> u * u - 2, 1.0
sol = benchmark_immutable(f, u0)
@test sol.retcode === NonlinearSolve.DEFAULT
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_mutable(f, u0)
@test sol.retcode === NonlinearSolve.DEFAULT
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_scalar(sf, su0)
@test sol.retcode === NonlinearSolve.DEFAULT
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test sol.u * sol.u - 2 < 1e-9

@test (@ballocated benchmark_immutable($f, $u0)) == 0
Expand Down Expand Up @@ -117,6 +117,7 @@ probN = NonlinearProblem(f, u0)
@test solve(probN, NewtonRaphson(;autodiff=false); immutable = false).u[end] sqrt(2.0)

for u0 in [1.0, [1, 1.0]]
local f, probN, sol
f = (u, p) -> u .* u .- 2.0
probN = NonlinearProblem(f, u0)
sol = sqrt(2) * u0
Expand Down

0 comments on commit b8bed15

Please sign in to comment.