Skip to content

Commit

Permalink
Merge pull request #1917 from gaurav-arya/ag-sparsediff
Browse files Browse the repository at this point in the history
Support for SparseDiffTools v2
  • Loading branch information
ChrisRackauckas authored May 26, 2023
2 parents 68962b7 + 79d54d8 commit acd09c3
Show file tree
Hide file tree
Showing 32 changed files with 304 additions and 183 deletions.
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>", "Yingbo Ma <maying
version = "6.51.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand Down Expand Up @@ -32,6 +33,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLNLSolve = "e9a6253c-8580-4d32-9898-8661bb511710"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand All @@ -45,7 +47,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Adapt = "1.1, 2.0, 3.0"
ArrayInterface = "6, 7"
DataStructures = "0.18"
DiffEqBase = "6.122.0"
DiffEqBase = "6.125.0"
DocStringExtensions = "0.8, 0.9"
ExponentialUtilities = "1.22"
FastBroadcast = "0.1.9, 0.2"
Expand All @@ -67,11 +69,12 @@ Preferences = "1.3"
RecursiveArrayTools = "2.36"
Reexport = "0.2, 1.0"
SciMLBase = "1.90"
SciMLOperators = "0.2.8"
SciMLNLSolve = "0.1"
SimpleNonlinearSolve = "0.1.4"
SimpleUnPack = "1"
SparseDiffTools = "2.3"
PrecompileTools = "1"
SparseDiffTools = "1.26.2"
StaticArrayInterface = "1.2"
StaticArrays = "0.11, 0.12, 1.0"
TruncatedStacktraces = "1.2"
Expand Down
15 changes: 10 additions & 5 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ import DiffEqBase: solve!, step!, initialize!, isadaptive
import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_ISOUTOFDOMAIN, ODE_DEFAULT_PROG_MESSAGE,
ODE_DEFAULT_UNSTABLE_CHECK

using DiffEqBase: DiffEqArrayOperator, DEFAULT_UPDATE_FUNC
import SciMLOperators: SciMLOperators, AbstractSciMLOperator, AbstractSciMLScalarOperator,
MatrixOperator, FunctionOperator,
update_coefficients, update_coefficients!, DEFAULT_UPDATE_FUNC,
isconstant

using DiffEqBase: TimeGradientWrapper, UJacobianWrapper, TimeDerivativeWrapper,
UDerivativeWrapper
Expand Down Expand Up @@ -76,7 +79,7 @@ using FastBroadcast: @.., True, False

using IfElse

using SciMLBase: NoInit, _unwrap_val, AbstractSciMLOperator
using SciMLBase: NoInit, _unwrap_val

import DiffEqBase: calculate_residuals, calculate_residuals!, unwrap_cache,
@tight_loop_macros,
Expand All @@ -88,10 +91,12 @@ else
struct OrdinaryDiffEqTag end
end

import SparseDiffTools
import SparseDiffTools: matrix_colors, forwarddiff_color_jacobian!,
import SparseDiffTools: SparseDiffTools, matrix_colors, forwarddiff_color_jacobian!,
forwarddiff_color_jacobian, ForwardColorJacCache,
default_chunk_size, getsize
default_chunk_size, getsize, JacVec

import ADTypes: AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff,
AutoTracker, AutoZygote, AutoEnzyme

import Polyester
using MacroTools, Adapt
Expand Down
39 changes: 27 additions & 12 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function SciMLBase.forwarddiffs_model(alg::Union{OrdinaryDiffEqAdaptiveImplicitA
DAEAlgorithm,
OrdinaryDiffEqImplicitAlgorithm,
ExponentialAlgorithm})
alg_autodiff(alg)
alg_autodiff(alg) isa AutoForwardDiff
end
SciMLBase.forwarddiffs_model_time(alg::RosenbrockAlgorithm) = true

Expand Down Expand Up @@ -250,7 +250,7 @@ function DiffEqBase.prepare_alg(alg::Union{
!(typeof(prob.f.jac_prototype) <: AbstractSciMLOperator)))
linsolve = LinearSolve.defaultalg(prob.f.jac_prototype, u0)
else
# If mm is a sparse matrix and A is a DiffEqArrayOperator, then let linear
# If mm is a sparse matrix and A is a MatrixOperator, then let linear
# solver choose things later
linsolve = nothing
end
Expand Down Expand Up @@ -316,20 +316,35 @@ function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
CompositeAlgorithm(algs, alg.choice_function)
end

function alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
error("This algorithm does not have an autodifferentiation option defined.")
end
alg_autodiff(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = AD
alg_autodiff(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = AD
alg_autodiff(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = AD
function alg_autodiff(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where {
CS,
AD
}
AD
_alg_autodiff(::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
_alg_autodiff(::DAEAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
_alg_autodiff(::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
function _alg_autodiff(::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}
}
) where {
CS, AD,
}
Val{AD}()
end

function alg_autodiff(alg)
autodiff = _alg_autodiff(alg)
if autodiff == Val(false)
return AutoFiniteDiff()
elseif autodiff == Val(true)
return AutoForwardDiff()
else
return _unwrap_val(autodiff)
end
end

# end

# alg_autodiff(alg::CompositeAlgorithm) = alg_autodiff(alg.algs[alg.current_alg])
get_current_alg_autodiff(alg, cache) = alg_autodiff(alg)
function get_current_alg_autodiff(alg::CompositeAlgorithm, cache)
Expand Down
13 changes: 13 additions & 0 deletions src/caches/linear_nonlinear_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,19 @@ mutable struct ETD2Fsal{rateType}
lin::rateType
nl::rateType
nlprev::rateType

function ETD2Fsal(lin, nl, nlprev)
if size(lin) == ()
# convert to same type if Number or AbstractSciMLScalarOperator
T = promote_type(eltype.((lin, nl, nlprev))...)

lin = convert(T, lin)
nl = convert(T, nl)
nlprev = convert(T, nlprev)
end

new{typeof(lin)}(lin, nl, nlprev)
end
end
function ETD2Fsal(rate_prototype)
ETD2Fsal(zero(rate_prototype), zero(rate_prototype), zero(rate_prototype))
Expand Down
12 changes: 6 additions & 6 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
grad_config, reltol, alg, algebraic_vars)
end

struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F} <: OrdinaryDiffEqConstantCache
struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache
c₃₂::T
d::T
tf::TF
uf::UF
J::JType
W::WType
linsolve::F
autodiff::Bool
autodiff::AD
end

function Rosenbrock23ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T}
Expand All @@ -181,15 +181,15 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
alg_autodiff(alg))
end

struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F} <: OrdinaryDiffEqConstantCache
struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache
c₃₂::T
d::T
tf::TF
uf::UF
J::JType
W::WType
linsolve::F
autodiff::Bool
autodiff::AD
end

function Rosenbrock32ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T}
Expand Down Expand Up @@ -416,14 +416,14 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W)

### Rodas methods

struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F} <: OrdinaryDiffEqConstantCache
struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::Bool
autodiff::AD
end

@cache mutable struct Rodas4Cache{uType, rateType, uNoUnitsType, JType, WType, TabType,
Expand Down
6 changes: 3 additions & 3 deletions src/dense/stiff_addsteps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
@unpack tf, uf, d = cache
γ = dt * d
tf.u = uprev
if cache.autodiff
if cache.autodiff isa AutoForwardDiff
dT = ForwardDiff.derivative(tf, t)
else
dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt))
Expand Down Expand Up @@ -180,7 +180,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache,

# Time derivative
tf.u = uprev
if cache.autodiff
if cache.autodiff isa AutoForwardDiff
dT = ForwardDiff.derivative(tf, t)
else
dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt))
Expand Down Expand Up @@ -566,7 +566,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rosenbrock5ConstantCach

# Time derivative
tf.u = uprev
# if cache.autodiff
# if cache.autodiff isa AutoForwardDiff
# dT = ForwardDiff.derivative(tf, t)
# else
dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt))
Expand Down
45 changes: 21 additions & 24 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const ROSENBROCK_INV_CUTOFF = 7 # https://github.com/SciML/OrdinaryDiffEq.jl/pull/1539

struct StaticWOperator{isinv, T}
struct StaticWOperator{isinv, T} <: AbstractSciMLOperator{T}
W::T
function StaticWOperator(W::T, callinv = true) where {T}
isinv = size(W, 1) <= ROSENBROCK_INV_CUTOFF
Expand Down Expand Up @@ -202,17 +202,17 @@ mutable struct WOperator{IIP, T,
transform = false) where {IIP}
# TODO: there is definitely a missing interface.
# Tentative interface: `has_concrete` and `concertize(A)`
if J isa Union{Number, DiffEqScalar}
if J isa Union{Number, ScalarOperator}
if transform
_concrete_form = -mass_matrix / gamma + convert(Number, J)
else
_concrete_form = -mass_matrix + gamma * convert(Number, J)
end
_func_cache = nothing
else
AJ = J isa DiffEqArrayOperator ? convert(AbstractMatrix, J) : J
AJ = J isa MatrixOperator ? convert(AbstractMatrix, J) : J
if AJ isa AbstractMatrix
mm = mass_matrix isa DiffEqArrayOperator ?
mm = mass_matrix isa MatrixOperator ?
convert(AbstractMatrix, mass_matrix) : mass_matrix
if AJ isa AbstractSparseMatrix

Expand All @@ -223,7 +223,7 @@ mutable struct WOperator{IIP, T,
#
# Constant operators never refactorize so always use the correct values there
# as well
if gamma == 0 && !(J isa DiffEqArrayOperator && SciMLBase.isconstant(J))
if gamma == 0 && !(J isa MatrixOperator && isconstant(J))
# Workaround https://github.com/JuliaSparse/SparseArrays.jl/issues/190
# Hopefully `rand()` does not match any value in the array (prob ~ 0, with a check)
# Then `one` is required since gamma is zero
Expand Down Expand Up @@ -285,7 +285,7 @@ function WOperator{IIP}(f, u, gamma; transform = false) where {IIP}
J = deepcopy(f.jac_prototype)
if J isa AbstractMatrix
@assert DiffEqBase.has_jac(f) "f needs to have an associated jacobian"
J = DiffEqArrayOperator(J; update_func = f.jac)
J = MatrixOperator(J; update_func! = f.jac)
end
return WOperator{IIP}(mass_matrix, gamma, J, u; transform = transform)
end
Expand All @@ -294,17 +294,16 @@ SciMLBase.isinplace(::WOperator{IIP}, i) where {IIP} = IIP
Base.eltype(W::WOperator) = eltype(W.J)

set_gamma!(W::WOperator, gamma) = (W.gamma = gamma; W)
function DiffEqBase.update_coefficients!(W::WOperator, u, p, t)
function SciMLOperators.update_coefficients!(W::WOperator, u, p, t)
update_coefficients!(W.J, u, p, t)
update_coefficients!(W.mass_matrix, u, p, t)
W.jacvec !== nothing && update_coefficients!(W.jacvec, u, p, t)
!isnothing(W.jacvec) && update_coefficients!(W.jacvec, u, p, t)
W
end

function DiffEqBase.update_coefficients!(J::SparseDiffTools.JacVec, u, p, t)
copyto!(J.x, u)
J.f.t = t
J.f.p = p
function SciMLOperators.update_coefficients!(J::UJacobianWrapper, u, p, t)
J.p = p
J.t = t
end

function Base.convert(::Type{AbstractMatrix}, W::WOperator{IIP}) where {IIP}
Expand Down Expand Up @@ -451,8 +450,7 @@ function do_newJW(integrator, alg, nlsolver, repeat_step)::NTuple{2, Bool}
isfreshJ = isJcurrent(nlsolver, integrator) && !integrator.u_modified
iszero(nlsolver.fast_convergence_cutoff) && return isfs && !isfreshJ, isfs
mm = integrator.f.mass_matrix
is_varying_mm = mm isa DiffEqArrayOperator &&
mm.update_func !== SciMLBase.DEFAULT_UPDATE_FUNC
is_varying_mm = !isconstant(mm)
if isfreshJ
jbad = false
smallstepchange = true
Expand Down Expand Up @@ -675,11 +673,10 @@ function calc_W!(W, integrator, nlsolver::Union{Nothing, AbstractNLSolver}, cach

# calculate W
if W isa WOperator
isnewton(nlsolver) || DiffEqBase.update_coefficients!(W, uprev, p, t) # we will call `update_coefficients!` in NLNewton
isnewton(nlsolver) || update_coefficients!(W, uprev, p, t) # we will call `update_coefficients!` in NLNewton
W.transform = W_transform
set_gamma!(W, dtgamma)
if W.J !== nothing && !(W.J isa SparseDiffTools.JacVec) &&
!(W.J isa AbstractSciMLOperator)
if W.J !== nothing && !(W.J isa AbstractSciMLOperator)
islin, isode = islinearfunction(integrator)
islin ? (J = isode ? f.f : f.f1.f) :
(new_jac && (calc_J!(W.J, integrator, lcache, next_step)))
Expand Down Expand Up @@ -740,7 +737,7 @@ end
else
if !isa(J, AbstractSciMLOperator) && (!isnewton(nlsolver) ||
nlsolver.cache.W.J isa AbstractSciMLOperator)
J = DiffEqArrayOperator(J)
J = MatrixOperator(J)
end
W = WOperator{false}(mass_matrix, dtgamma, J, uprev, cache.W.jacvec;
transform = W_transform)
Expand All @@ -767,7 +764,7 @@ end
end
end
(W isa WOperator && unwrap_alg(integrator, true) isa NewtonAlgorithm) &&
(W = DiffEqBase.update_coefficients!(W, uprev, p, t)) # we will call `update_coefficients!` in NLNewton
(W = update_coefficients!(W, uprev, p, t)) # we will call `update_coefficients!` in NLNewton
is_compos && (integrator.eigen_est = isarray ? constvalue(opnorm(J, Inf)) :
integrator.opts.internalnorm(J, t))
return W
Expand Down Expand Up @@ -853,8 +850,8 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
# be overridden with concrete_jac.

_f = islin ? (isode ? f.f : f.f1.f) : f
jacvec = SparseDiffTools.JacVec(UJacobianWrapper(_f, t, p), copy(u),
OrdinaryDiffEqTag(), autodiff = alg_autodiff(alg))
jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u), p, t;
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
J = jacvec
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)

Expand All @@ -869,14 +866,14 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
else
deepcopy(f.jac_prototype)
end
jacvec = SparseDiffTools.JacVec(UJacobianWrapper(_f, t, p), copy(u),
OrdinaryDiffEqTag(), autodiff = alg_autodiff(alg))
jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u), p, t;
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)

elseif islin || (!IIP && DiffEqBase.has_jac(f))
J = islin ? (isode ? f.f : f.f1.f) : f.jac(uprev, p, t) # unwrap the Jacobian accordingly
if !isa(J, AbstractSciMLOperator)
J = DiffEqArrayOperator(J)
J = MatrixOperator(J)
end
W = WOperator{IIP}(f.mass_matrix, dt, J, u)
else
Expand Down
Loading

0 comments on commit acd09c3

Please sign in to comment.