From 19fb77321e9945e44bdb674cfda59ff7aa71c9cd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 1 Oct 2023 21:40:12 -0400 Subject: [PATCH 1/5] make construction of type stable TP BVProblem easier --- Project.toml | 2 +- src/problems/bvp_problems.jl | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 364102ac5..332168cd1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.0.7" +version = "2.1.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 2e9f1feab..b15b5d6a9 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -6,7 +6,7 @@ struct StandardBVProblem end """ $(TYPEDEF) """ -struct TwoPointBVProblem end +struct TwoPointBVProblem{iip} end # The iip is needed to make type stable construction easier @doc doc""" @@ -112,7 +112,7 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <: p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP} _tspan = promote_tspan(tspan) warn_paramtype(p) - prob_type = TP ? TwoPointBVProblem() : StandardBVProblem() + prob_type = TP ? TwoPointBVProblem{iip}() : StandardBVProblem() # Needed to ensure that `problem_type` doesn't get passed in kwargs if problem_type === nothing problem_type = prob_type @@ -149,11 +149,21 @@ struct TwoPointBVPFunction{iip} end return BVPFunction{iip}(args...; kwargs..., twopoint=true) end +function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); + bcresid_prototype=nothing, kwargs...) where {iip} + return TwoPointBVProblem(TwoPointBVPFunction{iip}(f, bc; bcresid_prototype), u0, tspan, p; + kwargs...) +end function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters(); bcresid_prototype=nothing, kwargs...) return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p; kwargs...) end +function TwoPointBVProblem{iip}(f::AbstractBVPFunction{iip, twopoint}, u0, tspan, + p = NullParameters(); kwargs...) where {iip, twopoint} + @assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`." + return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...) +end function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan, p = NullParameters(); kwargs...) where {iip, twopoint} @assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`." From d08f3f7bfd571cf394e011397e8354b63fceac35 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 1 Oct 2023 21:59:27 -0400 Subject: [PATCH 2/5] Add remake for BVProblem --- src/remake.jl | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/src/remake.jl b/src/remake.jl index 7bbc6de6a..62fb15af2 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -120,6 +120,67 @@ function remake(prob::ODEProblem; f = missing, end end +""" + remake(prob::BVProblem; f = missing, u0 = missing, tspan = missing, + p = missing, kwargs = missing, problem_type = missing, _kwargs...) + +Remake the given `BVProblem`. +""" +function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing, + p = missing, kwargs = missing, problem_type = missing, _kwargs...) + if tspan === missing + tspan = prob.tspan + end + + if p === missing && u0 === missing + p, u0 = prob.p, prob.u0 + else # at least one of them has a value + if p === missing + p = prob.p + end + if u0 === missing + u0 = prob.u0 + end + end + + iip = isinplace(prob) + + if problem_type === missing + problem_type = prob.problem_type + end + + twopoint = problem_type isa TwoPointBVProblem + + if bc === missing + bc = prob.bc + end + + if f === missing + _f = prob.f + elseif f isa BVPFunction + _f = f + bc = f.bc + elseif specialization(prob.f) === FunctionWrapperSpecialize + ptspan = promote_tspan(tspan) + if iip + _f = BVPFunction{iip, FunctionWrapperSpecialize, twopoint}(wrapfun_iip(f, + (u0, u0, p, ptspan[1])), bc; prob.f.bcresid_prototype) + else + _f = BVPFunction{iip, FunctionWrapperSpecialize, twopoint}(wrapfun_oop(f, + (u0, p, ptspan[1])), bc; prob.f.bcresid_prototype) + end + else + _f = BVPFunction{isinplace(prob), specialization(prob.f), twopoint}(f, bc; + prob.f.bcresid_prototype) + end + + if kwargs === missing + BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, prob.kwargs..., _kwargs...) + else + BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, kwargs...) + end +end + """ remake(prob::SDEProblem; f = missing, u0 = missing, tspan = missing, p = missing, noise = missing, noise_rate_prototype = missing, From a12a4fc0896fccb93a9fd77e91aeb18dd5958bd1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 2 Oct 2023 10:57:31 -0400 Subject: [PATCH 3/5] Allow twopoint to be a Val --- src/problems/bvp_problems.jl | 10 ++++++---- src/scimlfunctions.jl | 9 +++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index b15b5d6a9..fd36c1908 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -144,15 +144,17 @@ end # But we need it for function calls like TwoPointBVProblem{iip}(...) = ... struct TwoPointBVPFunction{iip} end -@inline TwoPointBVPFunction(args...; kwargs...) = BVPFunction(args...; kwargs..., twopoint=true) +@inline function TwoPointBVPFunction(args...; kwargs...) + return BVPFunction(args...; kwargs..., twopoint = Val(true)) +end @inline function TwoPointBVPFunction{iip}(args...; kwargs...) where {iip} - return BVPFunction{iip}(args...; kwargs..., twopoint=true) + return BVPFunction{iip}(args...; kwargs..., twopoint = Val(true)) end function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); bcresid_prototype=nothing, kwargs...) where {iip} - return TwoPointBVProblem(TwoPointBVPFunction{iip}(f, bc; bcresid_prototype), u0, tspan, p; - kwargs...) + return TwoPointBVProblem(TwoPointBVPFunction{iip}(f, bc; bcresid_prototype), u0, tspan, + p; kwargs...) end function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters(); bcresid_prototype=nothing, kwargs...) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index c74808df2..5cb5d6bb2 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4064,12 +4064,13 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; end end -function BVPFunction{iip}(f, bc; twopoint::Bool=false, kwargs...) where {iip} - BVPFunction{iip, FullSpecialize, twopoint}(f, bc; kwargs...) +function BVPFunction{iip}(f, bc; twopoint::Union{Val, Bool}=Val(false), + kwargs...) where {iip} + BVPFunction{iip, FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...) end BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f -function BVPFunction(f, bc; twopoint::Bool=false, kwargs...) - BVPFunction{isinplace(f, 4), FullSpecialize, twopoint}(f, bc; kwargs...) +function BVPFunction(f, bc; twopoint::Union{Val, Bool}=Val(false), kwargs...) + BVPFunction{isinplace(f, 4), FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...) end BVPFunction(f::BVPFunction; kwargs...) = f From 90bcabcde640204847f0a884c33de8a1153efb5e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Oct 2023 15:48:21 -0400 Subject: [PATCH 4/5] Split up the 2Point BVP --- Project.toml | 2 +- src/problems/bvp_problems.jl | 37 ++++++++++++++++++++---------------- src/remake.jl | 2 +- src/scimlfunctions.jl | 33 +++++++++++++++++++++++++++++--- 4 files changed, 53 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index 5a5d89a8e..c8926b398 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.1.0" +version = "2.2.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index fd36c1908..0867e99d2 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -71,11 +71,18 @@ time points, and for shooting type methods `u=sol` the ODE solution. Note that all features of the `ODESolution` are present in this form. In both cases, the size of the residual matches the size of the initial condition. -If the bvp is a TwoPointBVProblem it must define either of the following functions +If the bvp is a TwoPointBVProblem then `bc` must be a Tuple `(bca, bcb)` and each of them +must define either of the following functions: ```julia -bc!((resid_a, resid_b), (u_a, u_b), p) -resid_a, resid_b = bc((u_a, u_b), p) +begin + bca!(resid_a, u_a, p) + bcb!(resid_b, u_b, p) +end +begin + resid_a = bca(u_a, p) + resid_b = bcb(u_b, p) +end ``` where `resid_a` and `resid_b` are the residuals at the two endpoints, `u_a` and `u_b` are @@ -98,17 +105,16 @@ every solve call. * `p`: The parameters for the problem. Defaults to `NullParameters` * `kwargs`: The keyword arguments passed onto the solves. """ -struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <: +struct BVProblem{uType, tType, isinplace, P, F, PT, K} <: AbstractBVProblem{uType, tType, isinplace} f::F - bc::BF u0::uType tspan::tType p::P problem_type::PT kwargs::K - @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan, + @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan, p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP} _tspan = promote_tspan(tspan) warn_paramtype(p) @@ -119,13 +125,12 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <: else @assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use." end - return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc), - typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, problem_type, - kwargs) + return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), + typeof(problem_type), typeof(kwargs)}(f, u0, _tspan, p, problem_type, kwargs) end function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip} - BVProblem(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...) + BVProblem(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) end end @@ -133,11 +138,11 @@ TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2 function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 4) - return BVProblem{iip}(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...) + return BVProblem{iip}(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) end function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...) - return BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...) + return BVProblem{isinplace(f)}(f, u0, tspan, p; kwargs...) end # This is mostly a fake stuct and isn't used anywhere @@ -163,13 +168,13 @@ function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters(); end function TwoPointBVProblem{iip}(f::AbstractBVPFunction{iip, twopoint}, u0, tspan, p = NullParameters(); kwargs...) where {iip, twopoint} - @assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`." - return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...) + @assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`." + return BVProblem{iip}(f, u0, tspan, p; kwargs...) end function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan, p = NullParameters(); kwargs...) where {iip, twopoint} - @assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`." - return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...) + @assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`." + return BVProblem{iip}(f, u0, tspan, p; kwargs...) end # Allow previous timeseries solution diff --git a/src/remake.jl b/src/remake.jl index 62fb15af2..8fe0feb1f 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -152,7 +152,7 @@ function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan twopoint = problem_type isa TwoPointBVProblem if bc === missing - bc = prob.bc + bc = prob.f.bc end if f === missing diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 5cb5d6bb2..65bcfbaa5 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4004,9 +4004,31 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; _bccolorvec = bccolorvec end - bciip = !twopoint ? isinplace(bc, 4, "bc", iip) : isinplace(bc, 3, "bc", iip) + bciip = if !twopoint + isinplace(bc, 4, "bc", iip) + else + @assert length(bc) == 2 + bc = Tuple(bc) + if isinplace(first(bc), 3, "bc", iip) != isinplace(last(bc), 3, "bc", iip) + throw(NonconformingFunctionsError(["bc[1]", "bc[2]"])) + end + isinplace(first(bc), 3, "bc", iip) + end jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip - bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip + bcjaciip = if bcjac !== nothing + if twopoint + isinplace(bcjac, 4, "bcjac", bciip) + else + @assert length(bcjac) == 2 + bcjac = Tuple(bcjac) + if isinplace(first(bcjac), 3, "bcjac", bciip) != isinplace(last(bcjac), 3, "bcjac", bciip) + throw(NonconformingFunctionsError(["bcjac[1]", "bcjac[2]"])) + end + isinplace(bcjac, 3, "bcjac", iip) + end + else + bciip + end tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip @@ -4029,8 +4051,13 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; error("bcresid_prototype must be a tuple / indexable collection of length 2 for a inplace TwoPointBVPFunction") end if bcresid_prototype !== nothing && length(bcresid_prototype) == 2 - bcresid_prototype = ArrayPartition(bcresid_prototype[1], bcresid_prototype[2]) + bcresid_prototype = ArrayPartition(first(bcresid_prototype), + last(bcresid_prototype)) end + + bccolorvec !== nothing && length(bccolorvec) == 2 && (bccolorvec = Tuple(bccolorvec)) + + bcjac_prototype !== nothing && length(bcjac_prototype) == 2 && (bcjac_prototype = Tuple(bcjac_prototype)) end if any(bc_nonconforming) From b404b4762e0ce1d662e1c33551bbe9417dfc8ef6 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 6 Oct 2023 23:30:30 +0200 Subject: [PATCH 5/5] fix typo --- src/scimlfunctions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 65bcfbaa5..c41f7753e 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4016,7 +4016,7 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; end jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip bcjaciip = if bcjac !== nothing - if twopoint + if !twopoint isinplace(bcjac, 4, "bcjac", bciip) else @assert length(bcjac) == 2