Skip to content

Commit

Permalink
Merge branch 'master' into lh/PythonCall-extension
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored Oct 7, 2023
2 parents 44e8f43 + 5d0d7e0 commit bfd023d
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.1.0"
version = "2.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
53 changes: 35 additions & 18 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct StandardBVProblem end
"""
$(TYPEDEF)
"""
struct TwoPointBVProblem end
struct TwoPointBVProblem{iip} end # The iip is needed to make type stable construction easier

@doc doc"""
Expand Down Expand Up @@ -71,11 +71,18 @@ time points, and for shooting type methods `u=sol` the ODE solution.
Note that all features of the `ODESolution` are present in this form.
In both cases, the size of the residual matches the size of the initial condition.
If the bvp is a TwoPointBVProblem it must define either of the following functions
If the bvp is a TwoPointBVProblem then `bc` must be a Tuple `(bca, bcb)` and each of them
must define either of the following functions:
```julia
bc!((resid_a, resid_b), (u_a, u_b), p)
resid_a, resid_b = bc((u_a, u_b), p)
begin
bca!(resid_a, u_a, p)
bcb!(resid_b, u_b, p)
end
begin
resid_a = bca(u_a, p)
resid_b = bcb(u_b, p)
end
```
where `resid_a` and `resid_b` are the residuals at the two endpoints, `u_a` and `u_b` are
Expand All @@ -98,67 +105,77 @@ every solve call.
* `p`: The parameters for the problem. Defaults to `NullParameters`
* `kwargs`: The keyword arguments passed onto the solves.
"""
struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace}
f::F
bc::BF
u0::uType
tspan::tType
p::P
problem_type::PT
kwargs::K

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan,
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
prob_type = TP ? TwoPointBVProblem() : StandardBVProblem()
prob_type = TP ? TwoPointBVProblem{iip}() : StandardBVProblem()
# Needed to ensure that `problem_type` doesn't get passed in kwargs
if problem_type === nothing
problem_type = prob_type
else
@assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end
return new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, _u0, _tspan, p, problem_type,
kwargs)
return new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f),
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
BVProblem(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
end
end

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
return BVProblem{iip}(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
return BVProblem{iip}(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
end

function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
return BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
return BVProblem{isinplace(f)}(f, u0, tspan, p; kwargs...)
end

# This is mostly a fake stuct and isn't used anywhere
# But we need it for function calls like TwoPointBVProblem{iip}(...) = ...
struct TwoPointBVPFunction{iip} end

@inline TwoPointBVPFunction(args...; kwargs...) = BVPFunction(args...; kwargs..., twopoint=true)
@inline function TwoPointBVPFunction(args...; kwargs...)
return BVPFunction(args...; kwargs..., twopoint = Val(true))
end
@inline function TwoPointBVPFunction{iip}(args...; kwargs...) where {iip}
return BVPFunction{iip}(args...; kwargs..., twopoint=true)
return BVPFunction{iip}(args...; kwargs..., twopoint = Val(true))
end

function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters();
bcresid_prototype=nothing, kwargs...) where {iip}
return TwoPointBVProblem(TwoPointBVPFunction{iip}(f, bc; bcresid_prototype), u0, tspan,
p; kwargs...)
end
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters();
bcresid_prototype=nothing, kwargs...)
return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p;
kwargs...)
end
function TwoPointBVProblem{iip}(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, u0, tspan, p; kwargs...)
end
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, u0, tspan, p; kwargs...)
end

# Allow previous timeseries solution
Expand Down
61 changes: 61 additions & 0 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,67 @@ function remake(prob::ODEProblem; f = missing,
end
end

"""
remake(prob::BVProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
Remake the given `BVProblem`.
"""
function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
if tspan === missing
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
end

iip = isinplace(prob)

if problem_type === missing
problem_type = prob.problem_type
end

twopoint = problem_type isa TwoPointBVProblem

if bc === missing
bc = prob.f.bc
end

if f === missing
_f = prob.f
elseif f isa BVPFunction
_f = f
bc = f.bc
elseif specialization(prob.f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
_f = BVPFunction{iip, FunctionWrapperSpecialize, twopoint}(wrapfun_iip(f,
(u0, u0, p, ptspan[1])), bc; prob.f.bcresid_prototype)
else
_f = BVPFunction{iip, FunctionWrapperSpecialize, twopoint}(wrapfun_oop(f,
(u0, p, ptspan[1])), bc; prob.f.bcresid_prototype)
end
else
_f = BVPFunction{isinplace(prob), specialization(prob.f), twopoint}(f, bc;
prob.f.bcresid_prototype)
end

if kwargs === missing
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, prob.kwargs..., _kwargs...)
else
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, kwargs...)
end
end

"""
remake(prob::SDEProblem; f = missing, u0 = missing, tspan = missing,
p = missing, noise = missing, noise_rate_prototype = missing,
Expand Down
42 changes: 35 additions & 7 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4024,9 +4024,31 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
_bccolorvec = bccolorvec
end

bciip = !twopoint ? isinplace(bc, 4, "bc", iip) : isinplace(bc, 3, "bc", iip)
bciip = if !twopoint
isinplace(bc, 4, "bc", iip)
else
@assert length(bc) == 2
bc = Tuple(bc)
if isinplace(first(bc), 3, "bc", iip) != isinplace(last(bc), 3, "bc", iip)
throw(NonconformingFunctionsError(["bc[1]", "bc[2]"]))
end
isinplace(first(bc), 3, "bc", iip)
end
jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip
bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip
bcjaciip = if bcjac !== nothing
if !twopoint
isinplace(bcjac, 4, "bcjac", bciip)
else
@assert length(bcjac) == 2
bcjac = Tuple(bcjac)
if isinplace(first(bcjac), 3, "bcjac", bciip) != isinplace(last(bcjac), 3, "bcjac", bciip)
throw(NonconformingFunctionsError(["bcjac[1]", "bcjac[2]"]))
end
isinplace(bcjac, 3, "bcjac", iip)
end
else
bciip
end
tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip
jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip
vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip
Expand All @@ -4049,8 +4071,13 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
error("bcresid_prototype must be a tuple / indexable collection of length 2 for a inplace TwoPointBVPFunction")
end
if bcresid_prototype !== nothing && length(bcresid_prototype) == 2
bcresid_prototype = ArrayPartition(bcresid_prototype[1], bcresid_prototype[2])
bcresid_prototype = ArrayPartition(first(bcresid_prototype),
last(bcresid_prototype))
end

bccolorvec !== nothing && length(bccolorvec) == 2 && (bccolorvec = Tuple(bccolorvec))

bcjac_prototype !== nothing && length(bcjac_prototype) == 2 && (bcjac_prototype = Tuple(bcjac_prototype))
end

if any(bc_nonconforming)
Expand Down Expand Up @@ -4086,12 +4113,13 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
end
end

function BVPFunction{iip}(f, bc; twopoint::Bool=false, kwargs...) where {iip}
BVPFunction{iip, FullSpecialize, twopoint}(f, bc; kwargs...)
function BVPFunction{iip}(f, bc; twopoint::Union{Val, Bool}=Val(false),
kwargs...) where {iip}
BVPFunction{iip, FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...)
end
BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f
function BVPFunction(f, bc; twopoint::Bool=false, kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize, twopoint}(f, bc; kwargs...)
function BVPFunction(f, bc; twopoint::Union{Val, Bool}=Val(false), kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...)
end
BVPFunction(f::BVPFunction; kwargs...) = f

Expand Down

0 comments on commit bfd023d

Please sign in to comment.