Skip to content

Commit

Permalink
Fix Shooting Methods
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 27, 2023
1 parent db88fb6 commit a19fcc5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
3 changes: 1 addition & 2 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
kwargs..., verbose, odesolve_kwargs...)
else
u_at_nodes = __multiple_shooting_initialize!(nodes, u_at_nodes, prob, alg,
cur_nshoot, all_nshoots[i - 1], ig;
kwargs..., verbose, odesolve_kwargs...)
cur_nshoot, all_nshoots[i - 1], ig; kwargs..., verbose, odesolve_kwargs...)
end

if __any_sparse_ad(alg.jac_alg)
Expand Down
2 changes: 1 addition & 1 deletion src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ end
function __single_shooting_loss!(resid_, u0_, p, f, bc, u0_size, tspan,
pt::TwoPointBVProblem, (resida_size, residb_size), alg::Shooting, kwargs)
resida = @view resid_[1:prod(resida_size)]
residb = @view resid_[(prod(residb_size) + 1):end]
residb = @view resid_[(prod(resida_size) + 1):end]
resid = (reshape(resida, resida_size), reshape(residb, residb_size))

odeprob = ODEProblem{true}(f, reshape(u0_, u0_size), tspan, p)
Expand Down
8 changes: 6 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,16 @@ function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, prob_type,
return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, diffmode)
end

struct BoundaryValueDiffEqTag end

@inline function __default_sparse_ad(x::AbstractArray{T}) where {T}
return isbitstype(T) ? __default_sparse_ad(T) : __default_sparse_ad(first(x))
end
@inline __default_sparse_ad(x::T) where {T} = __default_sparse_ad(T)
@inline __default_sparse_ad(::Type{<:Complex}) = AutoSparseFiniteDiff()
@inline function __default_sparse_ad(T::Type)
return ForwardDiff.can_dual(T) ? AutoSparseForwardDiff() : AutoSparseFiniteDiff()
return ForwardDiff.can_dual(T) ?
AutoSparseForwardDiff(; tag = BoundaryValueDiffEqTag()) : AutoSparseFiniteDiff()
end

@inline function __default_nonsparse_ad(x::AbstractArray{T}) where {T}
Expand All @@ -102,7 +105,8 @@ end
@inline __default_nonsparse_ad(x::T) where {T} = __default_nonsparse_ad(T)
@inline __default_nonsparse_ad(::Type{<:Complex}) = AutoFiniteDiff()
@inline function __default_nonsparse_ad(T::Type)
return ForwardDiff.can_dual(T) ? AutoForwardDiff() : AutoFiniteDiff()
return ForwardDiff.can_dual(T) ? AutoForwardDiff(; tag = BoundaryValueDiffEqTag()) :
AutoFiniteDiff()
end

# This can cause Type Instability
Expand Down

0 comments on commit a19fcc5

Please sign in to comment.