Adjoints for Linear Solve #449

Merged · 8 commits · Feb 25, 2024
Changes from all commits
12 changes: 9 additions & 3 deletions Project.toml
@@ -1,10 +1,11 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.24.0"
version = "2.25.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
@@ -16,6 +17,7 @@
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
@@ -64,6 +66,7 @@
BandedMatrices = "1.5"
BlockDiagonals = "0.1.42"
CUDA = "5"
ChainRulesCore = "1.22"
ConcreteStructs = "0.2.3"
DocStringExtensions = "0.9.3"
EnumX = "1.0.4"
@@ -85,6 +88,7 @@
Libdl = "1.10"
LinearAlgebra = "1.10"
MPI = "0.20"
Markdown = "1.10"
Metal = "0.5"
MultiFloats = "1"
Pardiso = "0.5"
@@ -96,7 +100,7 @@
RecursiveFactorization = "0.2.14"
Reexport = "1"
SafeTestsets = "0.1"
SciMLBase = "2.23.0"
SciMLBase = "2.26.3"
SciMLOperators = "0.3.7"
Setfield = "1"
SparseArrays = "1.10"
@@ -106,6 +110,7 @@
StaticArraysCore = "1.4.2"
Test = "1"
UnPack = "1"
Zygote = "0.6.69"
julia = "1.10"

[extras]
@@ -133,6 +138,7 @@
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
8 changes: 4 additions & 4 deletions ext/LinearSolveHYPREExt.jl
@@ -5,7 +5,7 @@
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, __issquare,
- __conditioning
+ __conditioning, LinearSolveAdjoint
using SciMLBase: LinearProblem, SciMLBase
using UnPack: @unpack
using Setfield: @set!
@@ -68,6 +68,7 @@
Pl = LinearAlgebra.I,
Pr = LinearAlgebra.I,
assumptions = OperatorAssumptions(),
sensealg = LinearSolveAdjoint(),
kwargs...)
@unpack A, b, u0, p = prob

@@ -89,10 +90,9 @@
cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
- typeof(__issquare(assumptions))
+ typeof(__issquare(assumptions)), typeof(sensealg)
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
- maxiters,
- verbose, assumptions)
+ maxiters, verbose, assumptions, sensealg)
return cache
end

7 changes: 7 additions & 0 deletions src/LinearSolve.jl
@@ -23,6 +23,8 @@
using FastLapackInterface
using DocStringExtensions
using EnumX
using Markdown
using ChainRulesCore
import InteractiveUtils

import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
@@ -42,6 +44,8 @@
import Preferences
end

const CRC = ChainRulesCore

if Preferences.@load_preference("LoadMKL_JLL", true)
using MKL_jll
const usemkl = MKL_jll.is_available()
@@ -125,6 +129,7 @@
include("default.jl")
include("init.jl")
include("extension_algs.jl")
include("adjoint.jl")
include("deprecated.jl")

@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
@@ -240,4 +245,6 @@

export OperatorAssumptions, OperatorCondition

export LinearSolveAdjoint

end
93 changes: 93 additions & 0 deletions src/adjoint.jl
@@ -0,0 +1,93 @@
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.

@doc doc"""
LinearSolveAdjoint(; linsolve = missing)

Given a linear problem ``A x = b``, this computes the sensitivities for ``A`` and ``b`` as:

```math
\begin{align}
A^T \lambda &= \partial x \\
\partial A &= -\lambda x^T \\
\partial b &= \lambda
\end{align}
```

For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf).

## Choice of Linear Solver

Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
forward solve (this is done by keeping the linsolve as `missing`). For example, if the
forward solve was performed via a Factorization, then we can reuse the factorization for the
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
"""
@kwdef struct LinearSolveAdjoint{L} <:
SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
linsolve::L = missing
end

function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A(
alg, prob.A, prob.b), kwargs...)
# sol = solve(prob, alg, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
(; A, sensealg) = cache

@assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."

# Decide if we need to cache `A` and `b` for the reverse pass
if sensealg.linsolve === missing
# We can reuse the factorization so no copy is needed
# Krylov Methods don't modify `A`, so it's safe to just reuse it
# No Copy is needed even for the default case
if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
alg isa DefaultLinearSolver)
A_ = alias_A ? deepcopy(A) : A
end
else
A_ = deepcopy(A)
end

sol = solve!(cache)

function ∇linear_solve(∂sol)
∂∅ = NoTangent()

∂u = ∂sol.u
if sensealg.linsolve === missing
λ = if cache.cacheval isa Factorization
cache.cacheval' \ ∂u
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
first(cache.cacheval)' \ ∂u
elseif alg isa AbstractKrylovSubspaceMethod
invprob = LinearProblem(transpose(cache.A), ∂u)
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
elseif alg isa DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
else
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
end
else
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
λ = solve(
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
end

∂A = -λ * transpose(sol.u)
∂b = λ
∂prob = LinearProblem(∂A, ∂b, ∂∅)

return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
end

return sol, ∇linear_solve
end

function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
prob = LinearProblem(A, b, p)
∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
return prob, ∇prob
end
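
The rule above is what enables reverse-mode AD through `solve`. As a rough usage sketch (not part of this diff; the matrix, right-hand side, and loss below are illustrative values only):

```julia
# Differentiate a scalar loss of the solution with respect to A and b.
# Under the hood, the rrule solves Aᵀλ = ∂u, then forms ∂A = -λ uᵀ and ∂b = λ.
using LinearSolve, Zygote

A = [4.0 1.0; 1.0 3.0]
b = [1.0, 2.0]

loss(A, b) = sum(solve(LinearProblem(A, b), LUFactorization()).u)

∂A, ∂b = Zygote.gradient(loss, A, b)
```

Because the forward solve used a factorization and `linsolve` was left as `missing`, the adjoint solve reuses the cached factorization (adjointed) rather than refactorizing.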
19 changes: 13 additions & 6 deletions src/common.jl
@@ -65,7 +65,7 @@
__issquare(assump::OperatorAssumptions) = assump.issq
__conditioning(assump::OperatorAssumptions) = assump.condition

- mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
+ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
A::TA
b::Tb
u::Tu
@@ -80,6 +80,7 @@
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issq}
sensealg::S
end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
@@ -138,6 +139,7 @@
Pl = IdentityOperator(size(prob.A)[1]),
Pr = IdentityOperator(size(prob.A)[2]),
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)
@unpack A, b, u0, p = prob

Expand Down Expand Up @@ -171,17 +173,22 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Tc = typeof(cacheval)

cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
- typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
- p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
+ typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
+ typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
+ maxiters, verbose, assumptions, sensealg)
return cache
end

function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
- solve!(init(prob, nothing, args...; kwargs...))
+ return solve(prob, nothing, args...; kwargs...)
end

- function SciMLBase.solve(prob::LinearProblem,
- alg::Union{SciMLLinearSolveAlgorithm, Nothing},
+ function SciMLBase.solve(prob::LinearProblem, ::Nothing, args...;
+ assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
+ return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...)
+ end
+
+ function SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...; kwargs...)
solve!(init(prob, alg, args...; kwargs...))
end
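
A minimal sketch of how the reworked entry points compose (values and the choice of `LUFactorization` are illustrative): `solve(prob)` now forwards to the `::Nothing` method, which picks an algorithm via `defaultalg` before dispatching to the algorithm-specific `solve`, and `init` threads the new `sensealg` option into the `LinearCache`.

```julia
using LinearSolve, LinearAlgebra

A = rand(4, 4) + 4I   # well-conditioned test matrix
b = rand(4)
prob = LinearProblem(A, b)

# No algorithm given: routed through solve(prob, nothing, ...), which calls
# defaultalg(prob.A, prob.b, assump) and then the algorithm-specific method.
sol_default = solve(prob)

# Explicit algorithm plus the new sensealg keyword (LinearSolveAdjoint() is already the default).
cache = init(prob, LUFactorization(); sensealg = LinearSolveAdjoint())
sol = solve!(cache)
```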
16 changes: 10 additions & 6 deletions src/factorization.jl
@@ -779,26 +779,30 @@
cacheval.colptr &&
SparseArrays.decrement(SparseArrays.getrowval(A)) ==
cacheval.rowval)
- fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
- nonzeros(A)), check=false)
+ fact = lu(
+ SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
+ nonzeros(A)),
+ check = false)
else
fact = lu!(cacheval,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
- nonzeros(A)), check=false)
+ nonzeros(A)), check = false)
end
else
- fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), check=false)
+ fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
+ check = false)
end
cache.cacheval = fact
cache.isfresh = false
end

F = @get_cacheval(cache, :UMFPACKFactorization)
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
y = ldiv!(cache.u, F, cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
else
- SciMLBase.build_linear_solution(alg, cache.u, nothing, cache; retcode=ReturnCode.Infeasible)
+ SciMLBase.build_linear_solution(
+ alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
end
end
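
A hypothetical check of the new non-throwing failure path (assuming UMFPACK reports a non-OK status for a singular factorization, as it should for the structurally singular matrix below): the solve now returns a solution carrying `ReturnCode.Infeasible` instead of the factorization call throwing.

```julia
using LinearSolve, SparseArrays, SciMLBase

A = sparse([1.0 0.0; 0.0 0.0])   # singular, so the LU factorization cannot succeed
b = [1.0, 1.0]

sol = solve(LinearProblem(A, b), UMFPACKFactorization())
sol.retcode == ReturnCode.Infeasible   # expected to hold here rather than an exception
```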
