From 64f84a127a6109e276cf0793e56502c9b6ad994a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Dec 2023 11:26:39 -0500 Subject: [PATCH] Default to using SimpleGMRES for the backward pass --- Project.toml | 6 ++- docs/src/api.md | 38 +++++++++++++++++++ ...mNetworksLinearSolveSciMLSensitivityExt.jl | 18 +++++++++ ...pEquilibriumNetworksSciMLSensitivityExt.jl | 10 ----- 4 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl delete mode 100644 ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl diff --git a/Project.toml b/Project.toml index 531a0285..141813cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DeepEquilibriumNetworks" uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" authors = ["Avik Pal "] -version = "2.0.0" +version = "2.0.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -18,11 +18,12 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [weakdeps] +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -DeepEquilibriumNetworksSciMLSensitivityExt = "SciMLSensitivity" +DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"] DeepEquilibriumNetworksZygoteExt = "Zygote" [compat] @@ -32,6 +33,7 @@ ConcreteStructs = "0.2" ConstructionBase = "1" DiffEqBase = "6.119" LinearAlgebra = "1" +LinearSolve = "2.21.2" Lux = "0.5.11" Random = "1" SciMLBase = "2" diff --git a/docs/src/api.md b/docs/src/api.md index 3e852c87..0a2384db 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -9,6 +9,44 @@ To construct a continuous DEQ, any ODE solver compatible with `DifferentialEquat can be passed as the solver. To construct a discrete DEQ, any root finding algorithm compatible with `NonlinearSolve.jl` API can be passed as the solver. 
+## Choosing a Solver + +### Root Finding Algorithms + +Root finding algorithms give fast convergence when possible, but these methods also +tend to be unstable. If you must use a root finding algorithm, we recommend using: + + 1. `NewtonRaphson` or `TrustRegion` for small models + 2. `LimitedMemoryBroyden` for large Deep Learning applications (with well-conditioned + Jacobians) + 3. `NewtonRaphson(; linsolve = KrylovJL_GMRES())` for cases when Broyden methods fail + +Note that Krylov Methods rely on efficient VJPs which are not available for all Lux models. +If you think this is causing a performance regression, please open an issue in +[Lux.jl](https://github.com/LuxDL/Lux.jl). + +### ODE Solvers + +ODE solvers give slower convergence, but are more stable. We generally recommend these +methods over root finding algorithms. If you use implicit ODE solvers, remember to use +Krylov linear solvers; see the OrdinaryDiffEq.jl documentation for these. For most cases, we +recommend: + + 1. `VCAB3()` for high tolerance problems + 2. `Tsit5()` for high tolerance problems where `VCAB3()` fails + 3. In all other cases, follow the recommendation given in the [OrdinaryDiffEq.jl](https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/#ode_solve) documentation. + +### Sensitivity Analysis + + 1. For `MultiScaleNeuralODE`, we default to `GaussAdjoint(; autojacvec = ZygoteVJP())`. A + faster alternative would be `BacksolveAdjoint(; autojacvec = ZygoteVJP())` but there are + stability concerns for using that. Follow the recommendation given in [SciMLSensitivity.jl](https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#Choosing-a-Sensitivity-Algorithm) documentation. + 2. For Steady State Problems, we default to + `SteadyStateAdjoint(; linsolve = SimpleGMRES(; blocksize), linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3))`. + This default will perform poorly on small models. 
It is recommended to pass + `sensealg = SteadyStateAdjoint()` or + `sensealg = SteadyStateAdjoint(; linsolve = LUFactorization())` for small models. + ## Standard Models ```@docs diff --git a/ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl b/ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl new file mode 100644 index 00000000..21cc34ca --- /dev/null +++ b/ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl @@ -0,0 +1,18 @@ +module DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt + +# Linear Solve is a dependency of SciMLSensitivity, so we only need to load SciMLSensitivity +# to load this extension +using LinearSolve, SciMLBase, SciMLSensitivity +import DeepEquilibriumNetworks: __default_sensealg + +@inline function __default_sensealg(prob::SteadyStateProblem) + # We want to avoid the cost for cache construction for linsolve = nothing + # For small problems we should use concrete jacobian but we assume users want to solve + # large problems with this package so we default to GMRES and avoid runtime dispatches + linsolve = SimpleGMRES{true}(; blocksize=prod(size(prob.u0)[1:(end - 1)])) + linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3) + return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP()) +end +@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP()) + +end diff --git a/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl b/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl deleted file mode 100644 index be62b995..00000000 --- a/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl +++ /dev/null @@ -1,10 +0,0 @@ -module DeepEquilibriumNetworksSciMLSensitivityExt - -using SciMLBase, SciMLSensitivity -import DeepEquilibriumNetworks: __default_sensealg - -@inline __default_sensealg(::SteadyStateProblem) = SteadyStateAdjoint(; - autojacvec=ZygoteVJP(), linsolve_kwargs=(; maxiters=10, abstol=1e-3, reltol=1e-3)) -@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; 
autojacvec=ZygoteVJP()) - -end