Adjoints for Linear Solve #449
Merged
Changes from 2 commits
Commits (8), all by avik-pal:
7e61692 Setup to handle adjoints
06c09a3 Finish part of the implementation
7671369 Merge branch 'main' of github.com:SciML/LinearSolve.jl into ap/adjoint
c153903 Allow special solver for adjoint
34995f6 Add compat entries
7c1f1b2 Fix HYPRE
6432716 More tests and some safety
e937e67 Up min SciMLBase compat
@@ -0,0 +1,103 @@
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.

@doc doc"""
    LinearSolveAdjoint(; linsolve = nothing)

Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as:

```math
\begin{align}
A^T \lambda &= \partial x \\
\partial A &= -\lambda x^T \\
\partial b &= \lambda
\end{align}
```

For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf).

## Choice of Linear Solver

Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
forward solve (this is done by keeping the linsolve as `nothing`). For example, if the
forward solve was performed via a Factorization, then we can reuse the factorization for the
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
"""
@kwdef struct LinearSolveAdjoint{L} <:
              SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
    linsolve::L = nothing
end

function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
    cache = init(prob, alg, args...; kwargs...)
    function ∇init(∂cache)
        ∂∅ = NoTangent()
        ∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p)
        ∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p)
        return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
    end
    return cache, ∇init
end

function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
        kwargs...)
    (; A, b, sensealg) = cache

    # Decide if we need to cache `A` and `b` for the reverse pass
    if sensealg.linsolve === nothing
        # We can reuse the factorization so no copy is needed
        # Krylov Methods don't modify `A`, so it's safe to just reuse it
        # No Copy is needed even for the default case
        if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
             alg isa DefaultLinearSolver)
            A_ = cache.alias_A ? deepcopy(A) : A
        end
    else
        error("Not Implemented Yet!!!")
    end

    # Forward Solve
    sol = solve!(cache, alg, args...; kwargs...)

    function ∇solve!(∂sol)
        @assert !cache.isfresh "`cache.A` has been updated between the forward and the \
                                reverse pass. This is not supported."
        ∂u = ∂sol.u
        if sensealg.linsolve === nothing
            λ = if cache.cacheval isa Factorization
                cache.cacheval' \ ∂u
            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
                first(cache.cacheval)' \ ∂u
            elseif alg isa AbstractKrylovSubspaceMethod
                invprob = LinearProblem(transpose(cache.A), ∂u)
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            elseif alg isa DefaultLinearSolver
                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
            else
                invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            end
        else
            error("Not Implemented Yet!!!")
        end

        ∂A = -λ * transpose(sol.u)
        ∂b = λ
        ∂∅ = NoTangent()

        ∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol,
            cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg)

        return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
    end
    return sol, ∇solve!
end

function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
    prob = LinearProblem(A, b, p)
    function ∇prob(∂prob)
        return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p
    end
    return prob, ∇prob
end
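The adjoint formulas in the docstring can be checked numerically without any of the package machinery. The sketch below is not part of the PR: it uses only the `LinearAlgebra` standard library, and the `loss` function, the finite-difference helpers `fd_grad_b` and `fd_grad_A`, and the test matrix are all illustrative assumptions. It solves `Aᵀ λ = ∂x` by reusing the forward LU factorization (mirroring the `cache.cacheval' \ ∂u` branch) and compares `∂A = -λ xᵀ` and `∂b = λ` against central finite differences.

```julia
using LinearAlgebra

# Hypothetical scalar loss of the solution x = A \ b, chosen only for illustration.
loss(A, b) = sum(abs2, A \ b) / 2

# Central finite-difference gradients, used as an independent reference.
function fd_grad_b(A, b; h = 1e-6)
    g = zero(b)
    for i in eachindex(b)
        e = zero(b)
        e[i] = h
        g[i] = (loss(A, b + e) - loss(A, b - e)) / (2h)
    end
    return g
end

function fd_grad_A(A, b; h = 1e-6)
    G = zero(A)
    for idx in CartesianIndices(A)
        E = zero(A)
        E[idx] = h
        G[idx] = (loss(A + E, b) - loss(A - E, b)) / (2h)
    end
    return G
end

A = [4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 2.0]
b = [1.0, 2.0, 3.0]

F = lu(A)                  # forward factorization
x = F \ b                  # forward solve, A x = b
∂x = x                     # gradient of this particular loss with respect to x
λ = F' \ ∂x                # adjoint solve Aᵀ λ = ∂x, reusing the factorization
∂A = -λ * transpose(x)     # sensitivity with respect to A
∂b = λ                     # sensitivity with respect to b

# Largest absolute deviation from the finite-difference reference
println(maximum(abs, ∂A - fd_grad_A(A, b)))
println(maximum(abs, ∂b - fd_grad_b(A, b)))
```

Both printed deviations should be small relative to the gradient entries. The key point is that the reverse pass costs one extra triangular solve with the already computed factors instead of a second factorization.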
Don't we technically have to deepcopy in here?
I guess so, it can be problematic if there are 2 subsequent solve calls on the cache.
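The concern raised in this thread can be illustrated with a toy model. The sketch below is an assumption-laden stand-in, not LinearSolve.jl code: `ToyCache`, `solve_with_pullback_ref`, and `solve_with_pullback_copy` are hypothetical names, and the exact line the reviewers refer to is not recoverable from this page. It shows how a pullback that only holds a reference to a mutable cache sees whatever matrix is stored at the time the reverse pass runs, while a pullback that copied `A` during the forward pass does not.

```julia
using LinearAlgebra

# Hypothetical stand-in for a solver cache; this is NOT LinearSolve.jl's LinearCache.
mutable struct ToyCache
    A::Matrix{Float64}
    b::Vector{Float64}
end

# Pullback that keeps only a reference to the live cache.
function solve_with_pullback_ref(cache::ToyCache)
    x = cache.A \ cache.b
    pullback(∂x) = transpose(cache.A) \ ∂x   # reads cache.A when the pullback is called
    return x, pullback
end

# Pullback that snapshots A at the time of the forward solve.
function solve_with_pullback_copy(cache::ToyCache)
    A_saved = copy(cache.A)
    x = cache.A \ cache.b
    pullback(∂x) = transpose(A_saved) \ ∂x   # unaffected by later updates to cache.A
    return x, pullback
end

cache = ToyCache([2.0 0.0; 0.0 4.0], [1.0, 1.0])
x1, pb_ref = solve_with_pullback_ref(cache)
_, pb_copy = solve_with_pullback_copy(cache)

# A second solve on the same cache overwrites A in place before the first pullback runs.
cache.A .= [1.0 0.0; 0.0 1.0]
x2 = cache.A \ cache.b

λ_ref = pb_ref([1.0, 1.0])     # silently uses the updated A
λ_copy = pb_copy([1.0, 1.0])   # uses the A from the first forward solve
println(λ_ref, " vs ", λ_copy)
```

In this toy setup the two pullbacks disagree after the in-place update. The `@assert !cache.isfresh` check in the diff appears to guard against a similar situation by erroring out, whereas a copy trades extra memory for correctness across repeated solves.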