Skip to content

Commit

Permalink
Unify GreedyColoringAlgorithm usage
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Jan 2, 2025
1 parent 65a68e8 commit e7670cc
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
using DifferentiationInterface: DifferentiationInterface, Constant
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra
using PreallocationTools: PreallocationTools, DiffCache
using RecursiveArrayTools: VectorOfArray, recursivecopy
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqAscher/src/ascher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractAscher; dt = 0.0,
end

if prob.f.bcjac === nothing
bcjac = construct_bc_jac(prob, bcresid_prototype, prob.problem_type)
bcjac = construct_bcjac(prob, bcresid_prototype)
else
bcjac = prob.f.bcjac
end
Expand Down
9 changes: 4 additions & 5 deletions lib/BoundaryValueDiffEqAscher/src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,10 @@ function Φ!(cache::AscherCache{iip, T}, z, res, pt::TwoPointBVProblem) where {i
copyto!(cache.dmz, dmz)
end

@inline __get_value(z::Vector{<:AbstractArray}) = eltype(first(z)) <: ForwardDiff.Dual ?
[map(x -> x.value, a) for a in z] : z
@inline __get_value(z) = isa(z, ForwardDiff.Dual) ? z.value : z
@inline __get_value(x) = x
@inline __get_value(x::Dual) = ForwardDiff.value(x)
@inline __get_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
@inline __get_value(x::AbstractArray{<:AbstractArray{<:Dual}}) = map(__get_value, x)

function Φ(cache::AscherCache{iip, T}, z, pt::StandardBVProblem) where {iip, T}
(; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, residual, zval, yval, gval, delz, dmz, deldmz, g, w, v, dmzo, ipvtg, ipvtw, TU) = cache
Expand Down Expand Up @@ -739,7 +740,6 @@ end
function vwblok(cache::AscherCache, xcol, hrho, jj, wi, vi, ipvtw, zyval, df, acol, dmzo)
(; jac, k, p, ncomp, ny) = cache
ncy = ncomp + ny
kdy = k * ncy
# initialize wi
i0 = (jj - 1) * ncy
for id in (i0 + 1):(i0 + ncomp)
Expand Down Expand Up @@ -848,7 +848,6 @@ function dmzsol!(cache::AscherCache, v, z, dmz)
for i in 1:n
for j in 1:ncomp
fact = __get_value(z[i][j])
println("fact: ", fact)
for l in 1:kdy
kk, jj = __locate_stage(l, ncy)
dmz[i][kk][jj] = dmz[i][kk][jj] + fact * v[i][l, j]
Expand Down
12 changes: 6 additions & 6 deletions lib/BoundaryValueDiffEqAscher/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ end
return nothing
end

@inline function construct_bc_jac(prob::BVProblem, _, pt::StandardBVProblem)
@inline construct_bcjac(prob, bcresid_prototype) = construct_bcjac(
prob, bcresid_prototype, prob.problem_type)
@inline function construct_bcjac(prob::BVProblem, _, pt::StandardBVProblem)
if isinplace(prob)
bcjac = (df, u, p, t) -> begin
_du = similar(u)
Expand All @@ -184,17 +186,15 @@ end
return
end
else
bcjac = (df, u, p, t) -> begin
_du = prob.f.bc(u, p, t)
bcjac = (u, p, t) -> begin
_f = @closure (du, u) -> (du .= prob.f.bc(u, p, t))
ForwardDiff.jacobian!(df, _f, _du, u)
return
return ForwardDiff.jacobian(_f, u)
end
end
return bcjac
end

@inline function construct_bc_jac(prob::BVProblem, bcresid_prototype, pt::TwoPointBVProblem)
@inline function construct_bcjac(prob::BVProblem, bcresid_prototype, pt::TwoPointBVProblem)
if isinplace(prob)
bcjac = (df, u, p) -> begin
_du = similar(u)
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
bc_diffmode = if alg.jac_alg.bc_diffmode isa AutoSparse
AutoSparse(get_dense_ad(alg.jac_alg.bc_diffmode),
sparsity_detector = SparseConnectivityTracer.TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm(LargestFirst()))
coloring_algorithm = GreedyColoringAlgorithm())
else
alg.jac_alg.bc_diffmode
end
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqShooting/src/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function SciMLBase.__solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;
diffmode = if alg.jac_alg.diffmode isa AutoSparse
AutoSparse(get_dense_ad(alg.jac_alg.diffmode),
sparsity_detector = SparseConnectivityTracer.TracerSparsityDetector(),
coloring_algorithm = GreedyColoringAlgorithm(LargestFirst()))
coloring_algorithm = GreedyColoringAlgorithm())
else
alg.jac_alg.diffmode
end
Expand Down

0 comments on commit e7670cc

Please sign in to comment.