Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjoints for Linear Solve #449

Merged
merged 8 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.21.1"
version = "2.22.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expand All @@ -16,6 +17,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Expand Down
7 changes: 7 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ PrecompileTools.@recompile_invalidations begin
using DocStringExtensions
using EnumX
using Requires
using Markdown
using ChainRulesCore
import InteractiveUtils

import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
Expand All @@ -43,6 +45,8 @@ PrecompileTools.@recompile_invalidations begin
import Preferences
end

const CRC = ChainRulesCore

if Preferences.@load_preference("LoadMKL_JLL", true)
using MKL_jll
const usemkl = MKL_jll.is_available()
Expand Down Expand Up @@ -124,6 +128,7 @@ include("solve_function.jl")
include("default.jl")
include("init.jl")
include("extension_algs.jl")
include("adjoint.jl")
include("deprecated.jl")

@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
Expand Down Expand Up @@ -236,4 +241,6 @@ export MetalLUFactorization

export OperatorAssumptions, OperatorCondition

export LinearSolveAdjoint

end
103 changes: 103 additions & 0 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.

@doc doc"""
    LinearSolveAdjoint(; linsolve = nothing)

Reverse-mode (adjoint) sensitivity algorithm for the linear problem ``A x = b``.
The sensitivities with respect to ``A`` and ``b`` are obtained from a single
adjoint solve:

```math
\begin{align}
A^T \lambda &= \partial x \\
\partial A &= -\lambda x^T \\
\partial b &= \lambda
\end{align}
```

For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf).

## Choice of Linear Solver

`linsolve` selects the algorithm used for the adjoint (transposed) solve. Leaving it
as `nothing` (the default) reuses the forward solver, which is usually best — e.g. a
factorization computed in the forward pass can be reused directly for the adjoint
solve. Pass an explicit `linsolve` only when ``A^T`` is known to have structure
distinct from ``A`` that a different algorithm can exploit.
"""
@kwdef struct LinearSolveAdjoint{L} <:
              SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
    linsolve::L = nothing
end

# rrule for `init`: the primal is untouched; the pullback maps a cotangent of the
# returned `LinearCache` back onto the `LinearProblem` inputs.
function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
    # Forward pass: plain `init` with everything forwarded unchanged.
    cache = init(prob, alg, args...; kwargs...)
    function ∇init(∂cache)
        ∂∅ = NoTangent()
        # `NullParameters` has no tangent space; otherwise project the cotangent
        # onto the tangent space of `p`.
        ∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p)
        # Repackage the cache cotangent as a `LinearProblem` cotangent.
        ∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p)
        ∂args = ntuple(_ -> ∂∅, length(args))
        return (∂∅, ∂prob, ∂∅, ∂args...)
    end
    return cache, ∇init
end

# rrule for `solve!`: runs the forward solve, then solves the adjoint system
# `Aᵀ λ = ∂u` in the pullback to produce `∂A = -λ uᵀ` and `∂b = λ`.
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
        kwargs...)
    (; A, b, sensealg) = cache

    # Decide whether `A` must be preserved for the reverse pass.
    if sensealg.linsolve === nothing
        # Factorizations can be reused directly for the transposed solve, Krylov
        # methods don't modify `A`, and the default solver handles its own adjoint
        # dispatch — so only the remaining algorithms need a private copy of `A`
        # (and only when the cache aliases the user's array).
        if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
             alg isa DefaultLinearSolver)
            A_ = cache.alias_A ? deepcopy(A) : A
        end
    else
        error("Not Implemented Yet!!!")
    end

    # Forward Solve
    sol = solve!(cache, alg, args...; kwargs...)

    function ∇solve!(∂sol)
        # The pullback relies on the forward-pass `A` (or its factorization) still
        # being the one stored in the cache.
        @assert !cache.isfresh "`cache.A` has been updated between the forward and the \
                                reverse pass. This is not supported."
        ∂u = ∂sol.u
        if sensealg.linsolve === nothing
            # Solve the adjoint system Aᵀ λ = ∂u, reusing forward-pass work when
            # possible.
            λ = if cache.cacheval isa Factorization
                cache.cacheval' \ ∂u
            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
                first(cache.cacheval)' \ ∂u
            elseif alg isa AbstractKrylovSubspaceMethod
                invprob = LinearProblem(transpose(cache.A), ∂u)
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            elseif alg isa DefaultLinearSolver
                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
            else
                invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            end
        else
            error("Not Implemented Yet!!!")
        end

        ∂A = -λ * transpose(sol.u)
        ∂b = λ
        ∂∅ = NoTangent()

        # Cotangent for the cache: only `A` and `b` carry gradients; every other
        # slot is a zero tangent, with scalar metadata passed through unchanged.
        ∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol,
            cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg)

        return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
    end
    return sol, ∇solve!
end

# rrule for the `LinearProblem` constructor: the constructor has no parameters of
# its own, so field cotangents pass straight through to `A`, `b`, and `p`.
function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
    # Forward `kwargs` so the primal matches a direct constructor call.
    # (Previously they were accepted by the rrule signature but silently dropped,
    # so e.g. a `u0` keyword would vanish under AD.)
    prob = LinearProblem(A, b, p; kwargs...)
    function ∇prob(∂prob)
        # No tangent for the type itself; the field cotangents are the input
        # cotangents.
        return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p
    end
    return prob, ∇prob
end
9 changes: 6 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ end
__issquare(assump::OperatorAssumptions) = assump.issq
__conditioning(assump::OperatorAssumptions) = assump.condition

mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
A::TA
b::Tb
u::Tu
Expand All @@ -80,6 +80,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issq}
sensealg::S
end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
Expand Down Expand Up @@ -137,6 +138,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Pl = IdentityOperator(size(prob.A)[1]),
Pr = IdentityOperator(size(prob.A)[2]),
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)
@unpack A, b, u0, p = prob

Expand Down Expand Up @@ -170,8 +172,9 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Tc = typeof(cacheval)

cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
end

Expand Down
Loading