Skip to content

Commit

Permalink
Setup to handle adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 20, 2023
1 parent 3b4f4ed commit 7e61692
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 4 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.21.1"
version = "2.22.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expand All @@ -16,6 +17,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Expand Down
7 changes: 7 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ PrecompileTools.@recompile_invalidations begin
using DocStringExtensions
using EnumX
using Requires
using Markdown
using ChainRulesCore
import InteractiveUtils

import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
Expand All @@ -43,6 +45,8 @@ PrecompileTools.@recompile_invalidations begin
import Preferences
end

const CRC = ChainRulesCore

if Preferences.@load_preference("LoadMKL_JLL", true)
using MKL_jll
const usemkl = MKL_jll.is_available()
Expand Down Expand Up @@ -124,6 +128,7 @@ include("solve_function.jl")
include("default.jl")
include("init.jl")
include("extension_algs.jl")
include("adjoint.jl")
include("deprecated.jl")

@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
Expand Down Expand Up @@ -236,4 +241,6 @@ export MetalLUFactorization

export OperatorAssumptions, OperatorCondition

export LinearSolveAdjoint

end
47 changes: 47 additions & 0 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# TODO: Preconditioners? For the adjoint solve, should Pl be transposed and used as Pr
#       (and, similarly, Pr transposed and used as Pl)?
# TODO: Document the options in LinearSolveAdjoint

@doc doc"""
    LinearSolveAdjoint(; linsolve = nothing)

Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as:

```math
\begin{align}
A^T \lambda &= \partial x \\
\partial A &= -\lambda x^T \\
\partial b &= \lambda
\end{align}
```

For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf).

## Keyword Arguments

  - `linsolve`: linear solver to use for the adjoint system ``A^T \lambda = \partial x``.
    Defaults to `nothing`, in which case the same linear solver as the forward solve is
    used.

## Choice of Linear Solver

Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
forward solve (this is done by keeping the linsolve as `nothing`). For example, if the
forward solve was performed via a Factorization, then we can reuse the factorization for the
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
"""
@kwdef struct LinearSolveAdjoint{L} <:
SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
linsolve::L = nothing
end

# Declare `SciMLBase.init` on a `LinearProblem` (with any further positional arguments) as
# non-differentiable for ChainRules-based AD, so no tangents are propagated through cache
# construction.
# NOTE(review): this also means gradients w.r.t. `prob.A` / `prob.b` cannot flow through
# `init` — presumably intentional while the adjoint is being set up; confirm.
CRC.@non_differentiable SciMLBase.init(::LinearProblem, ::Any...)

# Reverse-mode rule for `solve!(cache::LinearCache)`.
#
# Performs the forward solve and returns `(sol, pullback)`.
#
# NOTE: the pullback is currently a placeholder. It only verifies that the cache has not
# been marked fresh again after the forward solve and then returns `NoTangent()` for both
# `solve!` and `cache`, i.e. no sensitivities are propagated yet.
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache)
    # TODO: decide, based on `cache.sensealg`, what (if anything) has to be cached from
    #       the forward solve for use in the reverse pass.
    sol = solve!(cache)

    function ∇solve!(∂sol)
        # Explicit check instead of `@assert`: assertions may be compiled out at some
        # optimization levels, and this guards against silently wrong gradients. The
        # exception type and message are the same as `@assert` would produce.
        if cache.isfresh
            throw(AssertionError("`cache.A` has been updated between the forward and the reverse pass. This is not supported."))
        end

        # Tangents for (`solve!`, `cache`).
        return NoTangent(), NoTangent()
    end

    return sol, ∇solve!
end
9 changes: 6 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ end
__issquare(assump::OperatorAssumptions) = assump.issq
__conditioning(assump::OperatorAssumptions) = assump.condition

mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
A::TA
b::Tb
u::Tu
Expand All @@ -80,6 +80,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issq}
sensealg::S
end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
Expand Down Expand Up @@ -137,6 +138,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Pl = IdentityOperator(size(prob.A)[1]),
Pr = IdentityOperator(size(prob.A)[2]),
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)
@unpack A, b, u0, p = prob

Expand Down Expand Up @@ -170,8 +172,9 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Tc = typeof(cacheval)

cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
end

Expand Down

0 comments on commit 7e61692

Please sign in to comment.