From d31d4cfe839330a6639efd8cf700331111f9c051 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Wed, 22 Mar 2023 19:00:33 -0400 Subject: [PATCH 1/6] bump sparsedifftools --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c6b4c1a20..45eb94211 100644 --- a/Project.toml +++ b/Project.toml @@ -71,7 +71,7 @@ ReverseDiff = "1.9" SciMLBase = "1.66.0" SciMLOperators = "0.1, 0.2" SimpleNonlinearSolve = "0.1.8" -SparseDiffTools = "1, 2" +SparseDiffTools = "2" StaticArraysCore = "1.4" StochasticDiffEq = "6.20" Tracker = "0.2" From c872e14d1417d14d3d9bd5d7af3eee00597e2055 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Wed, 22 Mar 2023 19:15:13 -0400 Subject: [PATCH 2/6] get_autodiff_from_vjp --- src/SciMLSensitivity.jl | 1 + src/sensitivity_algorithms.jl | 5 +++++ src/steadystate_adjoint.jl | 1 + 3 files changed, 7 insertions(+) diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 83ec4f3be..7f8e6ed99 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -14,6 +14,7 @@ import ArrayInterface import Enzyme import GPUArraysCore using StaticArraysCore +using ADTypes using SparseDiffTools using SciMLOperators import TruncatedStacktraces diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index d691264ee..40f9f1732 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1193,3 +1193,8 @@ struct ForwardDiffOverAdjoint{A} <: AbstractSecondOrderSensitivityAlgorithm{nothing, true, nothing} adjalg::A end + +get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where{compile} = AutoReverseDiff(; compile = compile) +get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote() +get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme() +get_autodiff_from_vjp(::TrackerVJP) = AutoTracker() diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 1d6b0f589..1bd58e5e0 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -99,6 +99,7 @@ end if !needs_jac # TODO: FixedVecJacOperator should respect the `autojacvec` of the algorithm + #operator = VecJac(f, y, p; autodiff = ) operator = FixedVecJacOperator(f, y, p, Val(DiffEqBase.isinplace(sol.prob))) linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) else From bca0805ee24e90293fe44d9ceb5e190cb6e05a3a Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Wed, 22 Mar 2023 19:15:52 -0400 Subject: [PATCH 3/6] (1) branch on has_adjoint, (2) rm fixedjacvecop, (3) replace with VecJac --- src/steadystate_adjoint.jl | 29 ++++------------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 1bd58e5e0..bdc3be189 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -49,9 +49,11 @@ end dgdu === nothing && dgdp === nothing && g === nothing && error("Either `dgdu`, `dgdp`, or `g` must be specified.") + needs_jac = if has_adjoint(f) + false # TODO: What is the correct heuristic? Can we afford to compute Jacobian for # cases where the length(u0) > 50 and if yes till what threshold - needs_jac = if sensealg.linsolve === nothing + elseif sensealg.linsolve === nothing length(u0) <= 50 else LinearSolve.needs_concrete_A(sensealg.linsolve) @@ -98,9 +100,7 @@ end end if !needs_jac - # TODO: FixedVecJacOperator should respect the `autojacvec` of the algorithm - #operator = VecJac(f, y, p; autodiff = ) - operator = FixedVecJacOperator(f, y, p, Val(DiffEqBase.isinplace(sol.prob))) + operator = VecJac(f, y, p; autodiff = get_autodiff_from_vjp(vjp)) linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) else linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ)) @@ -137,24 +137,3 @@ end return vjp end end - -function FixedVecJacOperator(f_in, y, p, ::Val{false}) - # NOTE: Zygote doesn't support inplace - input, f = Zygote.pullback(x -> f_in(x, p, nothing), y) - output = f(input)[1] - function f_operator!(du, u, p, t) - du .= vec(f(reshape(u, size(input)))[1]) - return du - end - op = FunctionOperator(f_operator!, vec(input), vec(output)) - return op -end - -function FixedVecJacOperator(f, y, p, ::Val{true}) - function f_operator!(du, u, p, t) - num_vecjac!(du, (_du, _u) -> f(reshape(_du, size(y)), _u, p, t), - y, reshape(u, size(y))) - return du - end - return FunctionOperator(f_operator!, vec(y); p) -end From cb30f44b966b7b225839d464be2cc813c4339a73 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Tue, 30 May 2023 14:03:50 -0400 Subject: [PATCH 4/6] adtypes dep --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 7d76fbdc5..7e1e73c9e 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Christopher Rackauckas ", "Yingbo Ma < version = "7.31.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" Cassette = "7057c7e9-c182-5462-911a-8362d720325c" From a0f217b261368c72aa51225ba858d4f59347495a Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Fri, 2 Jun 2023 10:41:38 -0400 Subject: [PATCH 5/6] bump sparsedifftools compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d39802307..9bb5b228e 100644 --- a/Project.toml +++ b/Project.toml @@ -72,7 +72,7 @@ ReverseDiff = "1.9" SciMLBase = "1.66.0" SciMLOperators = "0.1, 0.2, 0.3" SimpleNonlinearSolve = "0.1.8" -SparseDiffTools = "2" +SparseDiffTools = "2.4" StaticArraysCore = "1.4" StochasticDiffEq = "6.20" Tracker = "0.2" From 9f065153b6d1940b6f3687b0c0f9b06d5a9b1753 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Tue, 22 Aug 2023 09:33:18 -0400 Subject: [PATCH 6/6] fix typo --- src/steadystate_adjoint.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 6f8eef9b5..af73558a2 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -100,8 +100,8 @@ end end if !needs_jac - operator = - (f, y, p; autodiff = get_autodiff_from_vjp(vjp)) + # operator = VecJac(f, y, p; Val(DiffEqBase.isinplace(sol.prob))) + operator = VecJac(f, y, p; autodiff = get_autodiff_from_vjp(vjp)) linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) else linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ))