Support for SparseDiffTools v2 #1917

Merged: 58 commits, May 26, 2023

Changes from 7 commits

Commits (58)
2ad1812
Support SparseDiffTools v2
gaurav-arya Mar 26, 2023
ea9ef65
Tweak version format
gaurav-arya Mar 26, 2023
2ef6999
Drop support for SparseDiffTools v1
gaurav-arya Mar 26, 2023
9c8c24b
Update autodiff selection
gaurav-arya Mar 28, 2023
febb9e3
Fix JacVec calls
gaurav-arya Mar 28, 2023
6466845
Update usage of alg_autodiff
gaurav-arya Apr 10, 2023
c87405d
Merge remote-tracking branch 'origin/master' into ag-sparsediff
gaurav-arya Apr 10, 2023
6025e37
added scimloperators v0.2.2
vpuri3 May 8, 2023
c2e13f3
diffeqarrayop -> matrixop
vpuri3 May 8, 2023
f6ce742
comments and fixes
vpuri3 May 8, 2023
ef9d6fd
Merge branch 'master' into ag-sparsediff
vpuri3 May 8, 2023
f125b59
diffeqarrayop -> matrixop in tests
vpuri3 May 8, 2023
060e4be
diffeqlinearop -> scimlop in newton.jl
vpuri3 May 8, 2023
67af2b2
Rosenbrock accepts ADTypes
vpuri3 May 8, 2023
2e47fe5
autodiff fix in newton.jl
vpuri3 May 8, 2023
a17a77e
fix AD error in nlsolve/utils.jl
vpuri3 May 8, 2023
4a3a4de
bump scimlops to 0.2.3
vpuri3 May 8, 2023
f50d52c
AD bug in initialize_dae
vpuri3 May 8, 2023
c86a051
using/import fixes in ODE.jl
vpuri3 May 9, 2023
15c92ac
import/using fixes in ODE.jl
vpuri3 May 9, 2023
ea3e9b8
fix autodiff error in stiff_addsteps.jl
vpuri3 May 9, 2023
98b01e8
tests unexpectedly passing
vpuri3 May 9, 2023
caa7cd3
TODO - fix CayleyEuler cache to use SciMLOps
vpuri3 May 9, 2023
70166a6
fix CayleyEuler
vpuri3 May 9, 2023
67e66f0
isconstant fix
vpuri3 May 9, 2023
309693e
all splitode/ splitfunction tests are passing
vpuri3 May 9, 2023
9132722
update scimlops compat
vpuri3 May 9, 2023
5ac7e62
reverting WOperator split function decision
vpuri3 May 9, 2023
959d4b7
fix ETD2Fsal constructor for scimlops
vpuri3 May 11, 2023
e4244fa
fixed ETDFsal constructor
vpuri3 May 12, 2023
049f5b2
bump scimlops compat
vpuri3 May 12, 2023
fd1c6c0
pass autodiff tag to jacvec https://github.com/JuliaDiff/SparseDiffTo…
vpuri3 May 12, 2023
6393ec1
unnecessary dots mess with ScalarOperators
vpuri3 May 12, 2023
1146441
rm more unnecessary dots
vpuri3 May 12, 2023
cbfe2ce
fix typo
vpuri3 May 12, 2023
4f9aa27
fix update func in linear_method_tests
vpuri3 May 12, 2023
d0ffa6c
split/sdirk tests remain broken
vpuri3 May 12, 2023
ddc9791
update scimlop compat
vpuri3 May 15, 2023
9c0fa84
sparsedifftools compat
vpuri3 May 16, 2023
0b01905
rm comments from CayleyEuler in linear_perform_step
vpuri3 May 16, 2023
1ef9d90
rm KenCarp3, CFNLIRK3 split tests; WOperator doesn't support split pro…
vpuri3 May 16, 2023
0b6594e
fix update behaviour in mass matrix tests
vpuri3 May 16, 2023
5ccc275
flattening to Nx1 solves the iterativesolvers indexing error
vpuri3 May 16, 2023
8fd1f1f
flattening PDE solve
vpuri3 May 16, 2023
3e7cc18
diffeqbase compat
vpuri3 May 16, 2023
a5ae19e
precond, nojac tests passing
vpuri3 May 16, 2023
a96c88f
Merge branch 'master' into sparsediff
vpuri3 May 17, 2023
5ff1759
switch out iterativesolvers --> krylovjl to fix inf/nan errors
vpuri3 May 17, 2023
908cd69
add more checks to norecompile
vpuri3 May 17, 2023
c7e778e
Fix UJacobianWrapper update coeffs
gaurav-arya May 18, 2023
83d38d4
Test update coefficients of jacobian operator
gaurav-arya May 21, 2023
477a58b
Fix typo
gaurav-arya May 21, 2023
af44803
scalar test fixes
vpuri3 May 22, 2023
e37f0ef
rm comment in derv_wrappers
vpuri3 May 22, 2023
f6122eb
retriggering CI with dummy commit
vpuri3 May 23, 2023
1da5423
Fix forwarddiffs_model
gaurav-arya May 24, 2023
6778668
update downstream compat
vpuri3 May 24, 2023
79d54d8
Merge branch 'ag-sparsediff' of github.com:gaurav-arya/OrdinaryDiffEq…
vpuri3 May 24, 2023
3 changes: 2 additions & 1 deletion Project.toml
@@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>", "Yingbo Ma <maying
version = "6.49.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -69,7 +70,7 @@ SciMLNLSolve = "0.1"
SimpleNonlinearSolve = "0.1.4"
SimpleUnPack = "1"
SnoopPrecompile = "1"
SparseDiffTools = "1.26.2"
SparseDiffTools = "2"
StaticArrayInterface = "1.2"
StaticArrays = "0.11, 0.12, 1.0"
TruncatedStacktraces = "1.2"
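Note: with ADTypes added to [deps] and the SparseDiffTools compat bumped to v2, the AD backend for implicit solvers is expressed with ADTypes structs rather than a boolean flag. A minimal, hypothetical usage sketch (the toy problem and calls below are illustrative, not from this diff; the ADTypes form follows the "Rosenbrock accepts ADTypes" commit):

using OrdinaryDiffEq, ADTypes

f(u, p, t) = -u
prob = ODEProblem(f, 1.0, (0.0, 1.0))

sol_bool = solve(prob, Rosenbrock23(autodiff = true))               # legacy boolean flag
sol_adt  = solve(prob, Rosenbrock23(autodiff = AutoForwardDiff()))  # ADTypes backend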
4 changes: 4 additions & 0 deletions src/OrdinaryDiffEq.jl
@@ -31,13 +31,16 @@ import DiffEqBase: solve!, step!, initialize!, isadaptive
import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_ISOUTOFDOMAIN, ODE_DEFAULT_PROG_MESSAGE,
ODE_DEFAULT_UNSTABLE_CHECK

# TODO: adjust all uses of the below two
using DiffEqBase: DiffEqArrayOperator, DEFAULT_UPDATE_FUNC

using DiffEqBase: TimeGradientWrapper, UJacobianWrapper, TimeDerivativeWrapper,
UDerivativeWrapper

using DiffEqBase: DEIntegrator

import SciMLBase: update_coefficients!

import RecursiveArrayTools: chain, recursivecopy!

using SimpleUnPack, ForwardDiff, RecursiveArrayTools,
@@ -90,6 +93,7 @@ import SparseDiffTools
import SparseDiffTools: matrix_colors, forwarddiff_color_jacobian!,
forwarddiff_color_jacobian, ForwardColorJacCache,
default_chunk_size, getsize
using ADTypes
gaurav-arya marked this conversation as resolved.

import Polyester
using MacroTools, Adapt
26 changes: 20 additions & 6 deletions src/alg_utils.jl
@@ -316,20 +316,34 @@ function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
CompositeAlgorithm(algs, alg.choice_function)
end

function alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
error("This algorithm does not have an autodifferentiation option defined.")
end
alg_autodiff(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = AD
alg_autodiff(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = AD
alg_autodiff(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = AD
function alg_autodiff(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
_alg_autodiff(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
_alg_autodiff(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
_alg_autodiff(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
function _alg_autodiff(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where {
CS,
AD
}
AD
Val{AD}()
end

function alg_autodiff(alg)
autodiff = _alg_autodiff(alg)
if autodiff == Val(false)
return AutoFiniteDiff()
elseif autodiff == Val(true)
return AutoForwardDiff()
else
return _unwrap_val(autodiff)
end
end

# end

# alg_autodiff(alg::CompositeAlgorithm) = alg_autodiff(alg.algs[alg.current_alg])
get_current_alg_autodiff(alg, cache) = alg_autodiff(alg)
function get_current_alg_autodiff(alg::CompositeAlgorithm, cache)
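Note: the change above splits AD selection into two layers: `_alg_autodiff` returns the algorithm's AD type parameter wrapped in `Val`, and `alg_autodiff` translates the legacy booleans into ADTypes backends. A standalone sketch of that translation logic (`_unwrap_val` is assumed to peel `Val(x)` back to `x`, as elsewhere in SciML; `translate_autodiff` is a hypothetical name for illustration):

using ADTypes: AutoForwardDiff, AutoFiniteDiff

_unwrap_val(::Val{T}) where {T} = T
_unwrap_val(x) = x

function translate_autodiff(autodiff)
    if autodiff == Val(false)
        AutoFiniteDiff()        # legacy autodiff = false
    elseif autodiff == Val(true)
        AutoForwardDiff()       # legacy autodiff = true
    else
        _unwrap_val(autodiff)   # already an ADTypes backend, pass through
    end
end

translate_autodiff(Val(true))   # AutoForwardDiff()
translate_autodiff(Val(false))  # AutoFiniteDiff()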
12 changes: 6 additions & 6 deletions src/derivative_utils.jl
@@ -294,14 +294,14 @@ SciMLBase.isinplace(::WOperator{IIP}, i) where {IIP} = IIP
Base.eltype(W::WOperator) = eltype(W.J)

set_gamma!(W::WOperator, gamma) = (W.gamma = gamma; W)
function DiffEqBase.update_coefficients!(W::WOperator, u, p, t)
function update_coefficients!(W::WOperator, u, p, t)
update_coefficients!(W.J, u, p, t)
update_coefficients!(W.mass_matrix, u, p, t)
W.jacvec !== nothing && update_coefficients!(W.jacvec, u, p, t)
W
end

function DiffEqBase.update_coefficients!(J::SparseDiffTools.JacVec, u, p, t)
function DiffEqBase.update_coefficients!(J::FunctionOperator{UJacobianWrapper}, u, p, t)
copyto!(J.x, u)
J.f.t = t
J.f.p = p
@@ -451,6 +451,7 @@ function do_newJW(integrator, alg, nlsolver, repeat_step)::NTuple{2, Bool}
isfreshJ = isJcurrent(nlsolver, integrator) && !integrator.u_modified
iszero(nlsolver.fast_convergence_cutoff) && return isfs && !isfreshJ, isfs
mm = integrator.f.mass_matrix
# TODO: adjust
is_varying_mm = mm isa DiffEqArrayOperator &&
mm.update_func !== SciMLBase.DEFAULT_UPDATE_FUNC
if isfreshJ
@@ -678,8 +679,7 @@ function calc_W!(W, integrator, nlsolver::Union{Nothing, AbstractNLSolver}, cach
isnewton(nlsolver) || DiffEqBase.update_coefficients!(W, uprev, p, t) # we will call `update_coefficients!` in NLNewton
W.transform = W_transform
set_gamma!(W, dtgamma)
if W.J !== nothing && !(W.J isa SparseDiffTools.JacVec) &&
!(W.J isa AbstractSciMLOperator)
if W.J !== nothing && !(W.J isa AbstractSciMLOperator)
islin, isode = islinearfunction(integrator)
islin ? (J = isode ? f.f : f.f1.f) :
(new_jac && (calc_J!(W.J, integrator, lcache, next_step)))
@@ -854,7 +854,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},

_f = islin ? (isode ? f.f : f.f1.f) : f
jacvec = SparseDiffTools.JacVec(UJacobianWrapper(_f, t, p), copy(u),
OrdinaryDiffEqTag(), autodiff = alg_autodiff(alg))
p, t; autodiff = alg_autodiff(alg))
J = jacvec
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)

@@ -870,7 +870,7 @@
deepcopy(f.jac_prototype)
end
jacvec = SparseDiffTools.JacVec(UJacobianWrapper(_f, t, p), copy(u),
OrdinaryDiffEqTag(), autodiff = alg_autodiff(alg))
p, t; autodiff = alg_autodiff(alg))
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)

elseif islin || (!IIP && DiffEqBase.has_jac(f))
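Note: in SparseDiffTools v2, `JacVec` returns a SciMLOperators `FunctionOperator` and takes `(f, u, p, t)` plus an ADTypes backend, which is why the `update_coefficients!` overload above now dispatches on `FunctionOperator` and the constructor calls pass `p, t; autodiff = alg_autodiff(alg)`. A minimal sketch of the lazy Jacobian-vector product, assuming that signature; the toy function `g` is hypothetical, not from this PR:

using SparseDiffTools, SciMLOperators, ADTypes

g(u) = u .^ 2                    # Jacobian of g at u is Diagonal(2u)
u0 = rand(4)
v = rand(4)

J = SparseDiffTools.JacVec(g, copy(u0); autodiff = AutoForwardDiff())
update_coefficients!(J, u0, nothing, 0.0)   # refresh the linearization point
J * v ≈ 2 .* u0 .* v                        # true: J(u0) * v computed lazily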
39 changes: 26 additions & 13 deletions src/derivative_wrappers.jl
@@ -79,8 +79,8 @@ function derivative!(df::AbstractArray{<:Number}, f,
x::Union{Number, AbstractArray{<:Number}}, fx::AbstractArray{<:Number},
integrator, grad_config)
alg = unwrap_alg(integrator, true)
tmp = length(x) # We calculate derivative for all elements in gradient
if alg_autodiff(alg)
tmp = length(x) # We calculate derivative for all elements in gradient
if alg_autodiff(alg) isa AutoForwardDiff
T = if standardtag(alg)
typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(df)))
else
@@ -102,7 +102,7 @@

df .= first.(ForwardDiff.partials.(grad_config))
integrator.stats.nf += 1
else
elseif alg_autodiff(alg) isa AutoFiniteDiff
FiniteDiff.finite_difference_gradient!(df, f, x, grad_config,
dir = diffdir(integrator))
fdtype = alg_difftype(alg)
@@ -113,6 +113,8 @@
end
end
integrator.stats.nf += tmp
else
error("$alg_autodiff not yet supported in derivative! function")
end
nothing
end
@@ -122,7 +124,7 @@ function derivative(f, x::Union{Number, AbstractArray{<:Number}},
local d
tmp = length(x) # We calculate derivative for all elements in gradient
alg = unwrap_alg(integrator, true)
if alg_autodiff(alg)
if alg_autodiff(alg) isa AutoForwardDiff
integrator.stats.nf += 1
if integrator.iter == 1
try
@@ -133,14 +135,16 @@
else
d = ForwardDiff.derivative(f, x)
end
else
elseif alg_autodiff(alg) isa AutoFiniteDiff
d = FiniteDiff.finite_difference_derivative(f, x, alg_difftype(alg),
dir = diffdir(integrator))
if alg_difftype(alg) === Val{:central} || alg_difftype(alg) === Val{:forward}
tmp *= 2
end
integrator.stats.nf += tmp
d
else
error("$alg_autodiff not yet supported in derivative function")
end
end

@@ -186,7 +190,7 @@
function jacobian(f, x, integrator)
alg = unwrap_alg(integrator, true)
local tmp
if alg_autodiff(alg)
if alg_autodiff(alg) isa AutoForwardDiff
if integrator.iter == 1
try
J, tmp = jacobian_autodiff(f, x, integrator.f, alg)
@@ -196,12 +200,14 @@
else
J, tmp = jacobian_autodiff(f, x, integrator.f, alg)
end
else
elseif alg_autodiff(alg) isa AutoFiniteDiff
jac_prototype = integrator.f.jac_prototype
sparsity, colorvec = sparsity_colorvec(integrator.f, x)
dir = diffdir(integrator)
J, tmp = jacobian_finitediff(f, x, alg_difftype(alg), dir, colorvec, sparsity,
jac_prototype)
else
error("$alg_autodiff not yet supported in jacobian function")
end
integrator.stats.nf += tmp
J
@@ -222,7 +228,7 @@
fx::AbstractArray{<:Number}, integrator::DiffEqBase.DEIntegrator,
jac_config)
alg = unwrap_alg(integrator, true)
if alg_autodiff(alg)
if alg_autodiff(alg) isa AutoForwardDiff
if integrator.iter == 1
try
forwarddiff_color_jacobian!(J, f, x, jac_config)
@@ -233,7 +239,7 @@
forwarddiff_color_jacobian!(J, f, x, jac_config)
end
integrator.stats.nf += 1
else
elseif alg_autodiff(alg) isa AutoFiniteDiff
isforward = alg_difftype(alg) === Val{:forward}
if isforward
forwardcache = get_tmp_cache(integrator, alg, unwrap_cache(integrator, true))[2]
@@ -245,6 +251,8 @@
tmp = jacobian_finitediff!(J, f, x, jac_config, integrator)
end
integrator.stats.nf += tmp
else
error("$alg_autodiff not yet supported in jacobian! function")
end
nothing
end
@@ -272,7 +280,8 @@ function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2,
end

sparsity, colorvec = sparsity_colorvec(f, u)
if alg_autodiff(alg)
# TODO: more generic, do we need this?
Review comment (Member): ? I don't understand.

if alg_autodiff(alg) isa AutoForwardDiff
_chunksize = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg) # SparseDiffEq uses different convention...

T = if standardtag(alg)
@@ -282,7 +291,7 @@ end
end
jac_config = ForwardColorJacCache(uf, uprev, _chunksize; colorvec = colorvec,
sparsity = sparsity, tag = T)
else
elseif alg_autodiff(alg) isa AutoFiniteDiff
if alg_difftype(alg) !== Val{:complex}
jac_config = FiniteDiff.JacobianCache(tmp, du1, du2, alg_difftype(alg),
colorvec = colorvec,
@@ -294,6 +303,8 @@
colorvec = colorvec,
sparsity = sparsity)
end
else
error("$alg_autodiff not yet supported in build_jac_config function")
end
else
jac_config = nothing
@@ -343,7 +354,7 @@ end

function build_grad_config(alg, f::F1, tf::F2, du1, t) where {F1, F2}
if !DiffEqBase.has_tgrad(f)
if alg_autodiff(alg)
if alg_autodiff(alg) isa AutoForwardDiff
T = if standardtag(alg)
typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(du1)))
else
@@ -362,8 +373,10 @@ function build_grad_config(alg, f::F1, tf::F2, du1, t) where {F1, F2}
(ForwardDiff.Partials((one(eltype(du1)),)),)) .*
false)
end
else
elseif alg_autodiff(alg) isa AutoFiniteDiff
grad_config = FiniteDiff.GradientCache(du1, t, alg_difftype(alg))
else
error("$alg_autodiff not yet supported in build_grad_config function")
end
else
grad_config = nothing
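Note: every branch in this file now dispatches on the ADTypes backend instead of a boolean, with an explicit error for unsupported backends. The pattern in isolation, as a sketch; `my_derivative` is a hypothetical stand-in for the wrappers above:

using ADTypes, ForwardDiff, FiniteDiff

function my_derivative(f, x, ad)
    if ad isa AutoForwardDiff
        ForwardDiff.derivative(f, x)
    elseif ad isa AutoFiniteDiff
        FiniteDiff.finite_difference_derivative(f, x)
    else
        error("$ad not yet supported")
    end
end

my_derivative(sin, 1.0, AutoForwardDiff())  # ≈ cos(1.0), exact dual-number derivative
my_derivative(sin, 1.0, AutoFiniteDiff())   # ≈ cos(1.0), via finite differencing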
13 changes: 4 additions & 9 deletions src/integrators/integrator_interface.jl
@@ -234,20 +234,15 @@ function resize_J_W!(cache, integrator, i)
nf = nlsolve_f(f, integrator.alg)
islin = f isa Union{ODEFunction, SplitFunction} && islinear(nf.f)
if !islin
if isa(cache.J, AbstractSciMLOperator)
if cache.J isa AbstractSciMLOperator
resize!(cache.J, i)
elseif f.jac_prototype !== nothing
J = similar(f.jac_prototype, i, i)
J = DiffEqArrayOperator(J; update_func = f.jac)
elseif cache.J isa SparseDiffTools.JacVec
resize!(cache.J.cache1, i)
resize!(cache.J.cache2, i)
resize!(cache.J.x, i)
end
if cache.W.jacvec !== nothing
resize!(cache.W.jacvec.cache1, i)
resize!(cache.W.jacvec.cache2, i)
resize!(cache.W.jacvec.x, i)
if cache.W.jacvec isa AbstractSciMLOperator
# TODO: resize! will need to be implemented upstream to handle previously handled cases
resize!(cache.W.jacvec, i)
end
gaurav-arya marked this conversation as resolved.
cache.W = WOperator{DiffEqBase.isinplace(integrator.sol.prob)}(f.mass_matrix,
integrator.dt,
8 changes: 5 additions & 3 deletions src/misc_utils.jl
@@ -108,11 +108,13 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi
# TODO: this ignores the add of the `f` count for add_steps!
if integrator isa SciMLBase.DEIntegrator && _alg.linsolve !== nothing &&
!LinearSolve.needs_concrete_A(_alg.linsolve) &&
linsolve.A isa WOperator && linsolve.A.J isa SparseDiffTools.JacVec
if alg_autodiff(_alg)
linsolve.A isa WOperator && linsolve.A.J isa AbstractSciMLOperator
if alg_autodiff(_alg) isa AutoForwardDiff
integrator.stats.nf += linres.iters
else
elseif alg_autodiff(_alg) isa AutoFiniteDiff
integrator.stats.nf += 2 * linres.iters
else
error("$alg_autodiff not yet supported in dolinsolve function")
end
end

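Note: the branch above counts f evaluations per Krylov iteration: a forward-mode JVP costs one dual-valued call to f, while a forward-difference JVP costs two, hence the `2 * linres.iters`. A sketch of the standard forward-difference JVP that motivates the factor of two (illustrative, not code from this PR):

fd_jvp(f, u, v; ϵ = sqrt(eps())) = (f(u .+ ϵ .* v) .- f(u)) ./ ϵ   # two f calls per product

f(u) = u .^ 2
u = rand(3); v = rand(3)
fd_jvp(f, u, v) ≈ 2 .* u .* v   # agrees up to O(ϵ) truncation error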
1 change: 1 addition & 0 deletions src/perform_step/linear_perform_step.jl
@@ -819,6 +819,7 @@ function perform_step!(integrator, cache::CayleyEulerConstantCache, repeat_step
A = f.f
end

# TODO: this is not in place, think about this
L = update_coefficients(A, uprev, p, t)
gaurav-arya marked this conversation as resolved.
V = cay(L * dt)
u = V * uprev * transpose(V)
Expand Down