Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split up the 2Point BVP (Break things again...) #520

Merged
merged 2 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.1.0"
version = "2.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
37 changes: 21 additions & 16 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,18 @@
Note that all features of the `ODESolution` are present in this form.
In both cases, the size of the residual matches the size of the initial condition.

If the bvp is a TwoPointBVProblem it must define either of the following functions
If the bvp is a TwoPointBVProblem then `bc` must be a Tuple `(bca, bcb)` and each of them
must define either of the following functions:

```julia
bc!((resid_a, resid_b), (u_a, u_b), p)
resid_a, resid_b = bc((u_a, u_b), p)
begin
bca!(resid_a, u_a, p)
bcb!(resid_b, u_b, p)
end
begin
resid_a = bca(u_a, p)
resid_b = bcb(u_b, p)
end
```

where `resid_a` and `resid_b` are the residuals at the two endpoints, `u_a` and `u_b` are
Expand All @@ -98,17 +105,16 @@
* `p`: The parameters for the problem. Defaults to `NullParameters`
* `kwargs`: The keyword arguments passed onto the solves.
"""
struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace}
f::F
bc::BF
u0::uType
tspan::tType
p::P
problem_type::PT
kwargs::K

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan,
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
Expand All @@ -119,25 +125,24 @@
else
@assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end
return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, problem_type,
kwargs)
return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f),
typeof(problem_type), typeof(kwargs)}(f, u0, _tspan, p, problem_type, kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
BVProblem(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)

Check warning on line 133 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L133

Added line #L133 was not covered by tests
end
end

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
return BVProblem{iip}(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
return BVProblem{iip}(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
end

function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
return BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
return BVProblem{isinplace(f)}(f, u0, tspan, p; kwargs...)

Check warning on line 145 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L145

Added line #L145 was not covered by tests
end

# This is mostly a fake stuct and isn't used anywhere
Expand All @@ -163,13 +168,13 @@
end
function TwoPointBVProblem{iip}(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, u0, tspan, p; kwargs...)

Check warning on line 172 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L171-L172

Added lines #L171 - L172 were not covered by tests
end
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, u0, tspan, p; kwargs...)

Check warning on line 177 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L176-L177

Added lines #L176 - L177 were not covered by tests
end

# Allow previous timeseries solution
Expand Down
2 changes: 1 addition & 1 deletion src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@
twopoint = problem_type isa TwoPointBVProblem

if bc === missing
bc = prob.bc
bc = prob.f.bc

Check warning on line 155 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L155

Added line #L155 was not covered by tests
end

if f === missing
Expand Down
33 changes: 30 additions & 3 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4004,9 +4004,31 @@
_bccolorvec = bccolorvec
end

bciip = !twopoint ? isinplace(bc, 4, "bc", iip) : isinplace(bc, 3, "bc", iip)
bciip = if !twopoint
isinplace(bc, 4, "bc", iip)
else
@assert length(bc) == 2
bc = Tuple(bc)
if isinplace(first(bc), 3, "bc", iip) != isinplace(last(bc), 3, "bc", iip)
throw(NonconformingFunctionsError(["bc[1]", "bc[2]"]))

Check warning on line 4013 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4010-L4013

Added lines #L4010 - L4013 were not covered by tests
end
isinplace(first(bc), 3, "bc", iip)

Check warning on line 4015 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4015

Added line #L4015 was not covered by tests
end
jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip
bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip
bcjaciip = if bcjac !== nothing
if !twopoint
isinplace(bcjac, 4, "bcjac", bciip)
else
@assert length(bcjac) == 2
bcjac = Tuple(bcjac)
if isinplace(first(bcjac), 3, "bcjac", bciip) != isinplace(last(bcjac), 3, "bcjac", bciip)
throw(NonconformingFunctionsError(["bcjac[1]", "bcjac[2]"]))

Check warning on line 4025 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4022-L4025

Added lines #L4022 - L4025 were not covered by tests
end
isinplace(bcjac, 3, "bcjac", iip)

Check warning on line 4027 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4027

Added line #L4027 was not covered by tests
end
else
bciip
end
tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip
jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip
vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip
Expand All @@ -4029,8 +4051,13 @@
error("bcresid_prototype must be a tuple / indexable collection of length 2 for a inplace TwoPointBVPFunction")
end
if bcresid_prototype !== nothing && length(bcresid_prototype) == 2
bcresid_prototype = ArrayPartition(bcresid_prototype[1], bcresid_prototype[2])
bcresid_prototype = ArrayPartition(first(bcresid_prototype),

Check warning on line 4054 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4054

Added line #L4054 was not covered by tests
last(bcresid_prototype))
end

bccolorvec !== nothing && length(bccolorvec) == 2 && (bccolorvec = Tuple(bccolorvec))

Check warning on line 4058 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4058

Added line #L4058 was not covered by tests

bcjac_prototype !== nothing && length(bcjac_prototype) == 2 && (bcjac_prototype = Tuple(bcjac_prototype))

Check warning on line 4060 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4060

Added line #L4060 was not covered by tests
end

if any(bc_nonconforming)
Expand Down
Loading