Skip to content

Commit

Permalink
Merge pull request #518 from SciML/ap/tstable_tpbvp
Browse files Browse the repository at this point in the history
Make construction of type stable TP BVProblem easier
  • Loading branch information
ChrisRackauckas authored Oct 6, 2023
2 parents b974ce5 + a12a4fc commit ea702db
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 8 deletions.
20 changes: 16 additions & 4 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct StandardBVProblem end
"""
$(TYPEDEF)
"""
struct TwoPointBVProblem end
struct TwoPointBVProblem{iip} end # The iip is needed to make type stable construction easier

@doc doc"""
Expand Down Expand Up @@ -112,7 +112,7 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
prob_type = TP ? TwoPointBVProblem() : StandardBVProblem()
prob_type = TP ? TwoPointBVProblem{iip}() : StandardBVProblem()
# Needed to ensure that `problem_type` doesn't get passed in kwargs
if problem_type === nothing
problem_type = prob_type
Expand Down Expand Up @@ -144,16 +144,28 @@ end
# But we need it for function calls like TwoPointBVProblem{iip}(...) = ...
struct TwoPointBVPFunction{iip} end

@inline TwoPointBVPFunction(args...; kwargs...) = BVPFunction(args...; kwargs..., twopoint=true)
@inline function TwoPointBVPFunction(args...; kwargs...)
return BVPFunction(args...; kwargs..., twopoint = Val(true))
end
@inline function TwoPointBVPFunction{iip}(args...; kwargs...) where {iip}
return BVPFunction{iip}(args...; kwargs..., twopoint=true)
return BVPFunction{iip}(args...; kwargs..., twopoint = Val(true))
end

function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters();
bcresid_prototype=nothing, kwargs...) where {iip}
return TwoPointBVProblem(TwoPointBVPFunction{iip}(f, bc; bcresid_prototype), u0, tspan,
p; kwargs...)
end
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters();
bcresid_prototype=nothing, kwargs...)
return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p;
kwargs...)
end
function TwoPointBVProblem{iip}(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
end
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
Expand Down
61 changes: 61 additions & 0 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,67 @@ function remake(prob::ODEProblem; f = missing,
end
end

"""
remake(prob::BVProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
Remake the given `BVProblem`.
"""
function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
if tspan === missing
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
end

iip = isinplace(prob)

if problem_type === missing
problem_type = prob.problem_type
end

twopoint = problem_type isa TwoPointBVProblem

if bc === missing
bc = prob.bc
end

if f === missing
_f = prob.f
elseif f isa BVPFunction
_f = f
bc = f.bc
elseif specialization(prob.f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
_f = BVPFunction{iip, FunctionWrapperSpecialize, twopoint}(wrapfun_iip(f,
(u0, u0, p, ptspan[1])), bc; prob.f.bcresid_prototype)
else
_f = BVPFunction{iip, FunctionWrapperSpecialize, twopoint}(wrapfun_oop(f,
(u0, p, ptspan[1])), bc; prob.f.bcresid_prototype)
end
else
_f = BVPFunction{isinplace(prob), specialization(prob.f), twopoint}(f, bc;
prob.f.bcresid_prototype)
end

if kwargs === missing
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, prob.kwargs..., _kwargs...)
else
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, kwargs...)
end
end

"""
remake(prob::SDEProblem; f = missing, u0 = missing, tspan = missing,
p = missing, noise = missing, noise_rate_prototype = missing,
Expand Down
9 changes: 5 additions & 4 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4064,12 +4064,13 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
end
end

function BVPFunction{iip}(f, bc; twopoint::Bool=false, kwargs...) where {iip}
BVPFunction{iip, FullSpecialize, twopoint}(f, bc; kwargs...)
function BVPFunction{iip}(f, bc; twopoint::Union{Val, Bool}=Val(false),
kwargs...) where {iip}
BVPFunction{iip, FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...)
end
BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f
function BVPFunction(f, bc; twopoint::Bool=false, kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize, twopoint}(f, bc; kwargs...)
function BVPFunction(f, bc; twopoint::Union{Val, Bool}=Val(false), kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...)
end
BVPFunction(f::BVPFunction; kwargs...) = f

Expand Down

0 comments on commit ea702db

Please sign in to comment.