From 7e616926905da31c55ca03d6a01fe86616482af0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Dec 2023 09:25:47 -0500 Subject: [PATCH] Setup to handle adjoints --- Project.toml | 4 +++- src/LinearSolve.jl | 7 +++++++ src/adjoint.jl | 47 ++++++++++++++++++++++++++++++++++++++++++++++ src/common.jl | 9 ++++++--- 4 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 src/adjoint.jl diff --git a/Project.toml b/Project.toml index ba9907272..38448a728 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "2.21.1" +version = "2.22.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" @@ -16,6 +17,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 572e310e9..c9ef40d16 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -24,6 +24,8 @@ PrecompileTools.@recompile_invalidations begin using DocStringExtensions using EnumX using Requires + using Markdown + using ChainRulesCore import InteractiveUtils import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix @@ -43,6 +45,8 @@ PrecompileTools.@recompile_invalidations begin import Preferences end +const CRC = ChainRulesCore + if Preferences.@load_preference("LoadMKL_JLL", true) using MKL_jll const usemkl = MKL_jll.is_available() @@ -124,6 +128,7 @@ include("solve_function.jl") include("default.jl") include("init.jl") include("extension_algs.jl") +include("adjoint.jl") include("deprecated.jl") @generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization; @@ -236,4 +241,6 @@ export MetalLUFactorization export OperatorAssumptions, OperatorCondition +export LinearSolveAdjoint + end diff --git a/src/adjoint.jl b/src/adjoint.jl new file mode 100644 index 000000000..f0f73e10d --- /dev/null +++ b/src/adjoint.jl @@ -0,0 +1,47 @@ +# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr. +# TODO: Document the options in LinearSolveAdjoint + +@doc doc""" + LinearSolveAdjoint(; linsolve = nothing) + +Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as: + +```math +\begin{align} +A^T \lambda &= \partial x \\ +\partial A &= -\lambda x^T \\ +\partial b &= \lambda +\end{align} +``` + +For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf). + +## Choice of Linear Solver + +Note that in most cases, it makes sense to use the same linear solver for the adjoint as the +forward solve (this is done by keeping the linsolve as `nothing`). For example, if the +forward solve was performed via a Factorization, then we can reuse the factorization for the +adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a +specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient. +""" +@kwdef struct LinearSolveAdjoint{L} <: + SciMLBase.AbstractSensitivityAlgorithm{0, false, :central} + linsolve::L = nothing +end + +CRC.@non_differentiable SciMLBase.init(::LinearProblem, ::Any...) + +function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache) + sensealg = cache.sensealg + + # Decide if we need to cache the + + sol = solve!(cache) + function ∇solve!(∂sol) + @assert !cache.isfresh "`cache.A` has been updated between the forward and the reverse pass. This is not supported." + + ∂cache = NoTangent() + return NoTangent(), ∂cache + end + return sol, ∇solve! +end diff --git a/src/common.jl b/src/common.jl index b206598d5..a49213521 100644 --- a/src/common.jl +++ b/src/common.jl @@ -65,7 +65,7 @@ end __issquare(assump::OperatorAssumptions) = assump.issq __conditioning(assump::OperatorAssumptions) = assump.condition -mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq} +mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S} A::TA b::Tb u::Tu @@ -80,6 +80,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq} maxiters::Int verbose::Bool assumptions::OperatorAssumptions{issq} + sensealg::S end function Base.setproperty!(cache::LinearCache, name::Symbol, x) @@ -137,6 +138,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, Pl = IdentityOperator(size(prob.A)[1]), Pr = IdentityOperator(size(prob.A)[2]), assumptions = OperatorAssumptions(issquare(prob.A)), + sensealg = LinearSolveAdjoint(), kwargs...) @unpack A, b, u0, p = prob @@ -170,8 +172,9 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, Tc = typeof(cacheval) cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, - typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_, - p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions) + typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), + typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + maxiters, verbose, assumptions, sensealg) return cache end