Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SparseDiffTools v2 for steadystateadjoint #808

Merged
merged 8 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <
version = "7.37.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
Expand Down Expand Up @@ -71,7 +72,7 @@ ReverseDiff = "1.9"
SciMLBase = "1.66.0"
SciMLOperators = "0.1, 0.2, 0.3"
SimpleNonlinearSolve = "0.1.8"
SparseDiffTools = "2"
SparseDiffTools = "2.4"
StaticArraysCore = "1.4"
StochasticDiffEq = "6.20"
Tracker = "0.2"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import ArrayInterface
import Enzyme
import GPUArraysCore
using StaticArraysCore
using ADTypes
using SparseDiffTools
using SciMLOperators
import TruncatedStacktraces
Expand Down
5 changes: 5 additions & 0 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1313,3 +1313,8 @@ struct ForwardDiffOverAdjoint{A} <:
AbstractSecondOrderSensitivityAlgorithm{nothing, true, nothing}
adjalg::A
end

get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where{compile} = AutoReverseDiff(; compile = compile)
get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote()
get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme()
get_autodiff_from_vjp(::TrackerVJP) = AutoTracker()
29 changes: 5 additions & 24 deletions src/steadystate_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ end
dgdu === nothing && dgdp === nothing && g === nothing &&
error("Either `dgdu`, `dgdp`, or `g` must be specified.")

needs_jac = if has_adjoint(f)
false
# TODO: What is the correct heuristic? Can we afford to compute Jacobian for
# cases where the length(u0) > 50 and if yes till what threshold
needs_jac = if sensealg.linsolve === nothing
elseif sensealg.linsolve === nothing
length(u0) <= 50
else
LinearSolve.needs_concrete_A(sensealg.linsolve)
Expand Down Expand Up @@ -98,8 +100,8 @@ end
end

if !needs_jac
# TODO: FixedVecJacOperator should respect the `autojacvec` of the algorithm
operator = FixedVecJacOperator(f, y, p, Val(DiffEqBase.isinplace(sol.prob)))
# operator = VecJac(f, y, p; Val(DiffEqBase.isinplace(sol.prob)))
operator = VecJac(f, y, p; autodiff = get_autodiff_from_vjp(vjp))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not really equivalent. IIRC VecJac recomputes the pullback everytime a call to mul! is made. In this case, we have a fixed input, only the seeding changes so we compute the pullback once and just reevaluate it multiple times.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually fixed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ))
else
linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ))
Expand Down Expand Up @@ -136,24 +138,3 @@ end
return vjp
end
end

function FixedVecJacOperator(f_in, y, p, ::Val{false})
# NOTE: Zygote doesn't support inplace
input, f = Zygote.pullback(x -> f_in(x, p, nothing), y)
output = f(input)[1]
function f_operator!(du, u, p, t)
du .= vec(f(reshape(u, size(input)))[1])
return du
end
op = FunctionOperator(f_operator!, vec(input), vec(output))
return op
end

function FixedVecJacOperator(f, y, p, ::Val{true})
function f_operator!(du, u, p, t)
num_vecjac!(du, (_du, _u) -> f(reshape(_du, size(y)), _u, p, t),
y, reshape(u, size(y)))
return du
end
return FunctionOperator(f_operator!, vec(y); p)
end