diff --git a/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl b/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl index 5f7f0e29..31226a29 100644 --- a/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl +++ b/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl @@ -12,7 +12,7 @@ using ConcreteStructs: @concrete using DiffEqBase: DiffEqBase using DifferentiationInterface: DifferentiationInterface, Constant using FastClosures: @closure -using ForwardDiff: ForwardDiff +using ForwardDiff: ForwardDiff, Dual using LinearAlgebra using PreallocationTools: PreallocationTools, DiffCache using RecursiveArrayTools: VectorOfArray, recursivecopy diff --git a/lib/BoundaryValueDiffEqAscher/src/ascher.jl b/lib/BoundaryValueDiffEqAscher/src/ascher.jl index 25e685da..94d45315 100644 --- a/lib/BoundaryValueDiffEqAscher/src/ascher.jl +++ b/lib/BoundaryValueDiffEqAscher/src/ascher.jl @@ -136,7 +136,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractAscher; dt = 0.0, end if prob.f.bcjac === nothing - bcjac = construct_bc_jac(prob, bcresid_prototype, prob.problem_type) + bcjac = construct_bcjac(prob, bcresid_prototype) else bcjac = prob.f.bcjac end diff --git a/lib/BoundaryValueDiffEqAscher/src/collocation.jl b/lib/BoundaryValueDiffEqAscher/src/collocation.jl index 819af4f7..592606dd 100644 --- a/lib/BoundaryValueDiffEqAscher/src/collocation.jl +++ b/lib/BoundaryValueDiffEqAscher/src/collocation.jl @@ -313,9 +313,10 @@ function Φ!(cache::AscherCache{iip, T}, z, res, pt::TwoPointBVProblem) where {i copyto!(cache.dmz, dmz) end -@inline __get_value(z::Vector{<:AbstractArray}) = eltype(first(z)) <: ForwardDiff.Dual ? - [map(x -> x.value, a) for a in z] : z -@inline __get_value(z) = isa(z, ForwardDiff.Dual) ? z.value : z +@inline __get_value(x) = x +@inline __get_value(x::Dual) = ForwardDiff.value(x) +@inline __get_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) +@inline __get_value(x::AbstractArray{<:AbstractArray{<:Dual}}) = map(__get_value, x) function Φ(cache::AscherCache{iip, T}, z, pt::StandardBVProblem) where {iip, T} (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, residual, zval, yval, gval, delz, dmz, deldmz, g, w, v, dmzo, ipvtg, ipvtw, TU) = cache @@ -739,7 +740,6 @@ end function vwblok(cache::AscherCache, xcol, hrho, jj, wi, vi, ipvtw, zyval, df, acol, dmzo) (; jac, k, p, ncomp, ny) = cache ncy = ncomp + ny - kdy = k * ncy # initialize wi i0 = (jj - 1) * ncy for id in (i0 + 1):(i0 + ncomp) @@ -848,7 +848,6 @@ function dmzsol!(cache::AscherCache, v, z, dmz) for i in 1:n for j in 1:ncomp fact = __get_value(z[i][j]) - println("fact: ", fact) for l in 1:kdy kk, jj = __locate_stage(l, ncy) dmz[i][kk][jj] = dmz[i][kk][jj] + fact * v[i][l, j] diff --git a/lib/BoundaryValueDiffEqAscher/src/utils.jl b/lib/BoundaryValueDiffEqAscher/src/utils.jl index 01142a8b..bf9e6b59 100644 --- a/lib/BoundaryValueDiffEqAscher/src/utils.jl +++ b/lib/BoundaryValueDiffEqAscher/src/utils.jl @@ -174,7 +174,9 @@ end return nothing end -@inline function construct_bc_jac(prob::BVProblem, _, pt::StandardBVProblem) +@inline construct_bcjac(prob, bcresid_prototype) = construct_bcjac( + prob, bcresid_prototype, prob.problem_type) +@inline function construct_bcjac(prob::BVProblem, _, pt::StandardBVProblem) if isinplace(prob) bcjac = (df, u, p, t) -> begin _du = similar(u) @@ -184,17 +186,15 @@ end return end else - bcjac = (df, u, p, t) -> begin - _du = prob.f.bc(u, p, t) + bcjac = (u, p, t) -> begin _f = @closure (du, u) -> (du .= prob.f.bc(u, p, t)) - ForwardDiff.jacobian!(df, _f, _du, u) - return + return ForwardDiff.jacobian(_f, u) end end return bcjac end -@inline function construct_bc_jac(prob::BVProblem, bcresid_prototype, pt::TwoPointBVProblem) +@inline function construct_bcjac(prob::BVProblem, bcresid_prototype, pt::TwoPointBVProblem) if isinplace(prob) bcjac = (df, u, p) -> begin _du = similar(u) diff --git a/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl b/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl index df292f9a..06f8cc37 100644 --- a/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl @@ -185,7 +185,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_ bc_diffmode = if alg.jac_alg.bc_diffmode isa AutoSparse AutoSparse(get_dense_ad(alg.jac_alg.bc_diffmode), sparsity_detector = SparseConnectivityTracer.TracerSparsityDetector(), - coloring_algorithm = GreedyColoringAlgorithm(LargestFirst())) + coloring_algorithm = GreedyColoringAlgorithm()) else alg.jac_alg.bc_diffmode end diff --git a/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl b/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl index d7155c91..55264156 100644 --- a/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl @@ -39,7 +39,7 @@ function SciMLBase.__solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (; diffmode = if alg.jac_alg.diffmode isa AutoSparse AutoSparse(get_dense_ad(alg.jac_alg.diffmode), sparsity_detector = SparseConnectivityTracer.TracerSparsityDetector(), - coloring_algorithm = GreedyColoringAlgorithm(LargestFirst())) + coloring_algorithm = GreedyColoringAlgorithm()) else alg.jac_alg.diffmode end