Skip to content

Commit

Permalink
Allow special solver for adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 24, 2024
1 parent 7671369 commit c153903
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 39 deletions.
1 change: 0 additions & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ PrecompileTools.@recompile_invalidations begin
using FastLapackInterface
using DocStringExtensions
using EnumX
using Requires
using Markdown
using ChainRulesCore
import InteractiveUtils
Expand Down
64 changes: 29 additions & 35 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.

@doc doc"""
LinearSolveAdjoint(; linsolve = nothing)
LinearSolveAdjoint(; linsolve = missing)
Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as:
Expand All @@ -18,53 +18,49 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi
## Choice of Linear Solver
Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
forward solve (this is done by keeping the linsolve as `nothing`). For example, if the
forward solve (this is done by keeping the linsolve as `missing`). For example, if the
forward solve was performed via a Factorization, then we can reuse the factorization for the
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
"""
@kwdef struct LinearSolveAdjoint{L} <:
SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
linsolve::L = nothing
linsolve::L = missing
end

function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,
alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A(
alg, prob.A, prob.b), kwargs...)
# sol = solve(prob, alg, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
function ∇init(∂cache)
∂∅ = NoTangent()
∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p)
∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p)
return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
end
return cache, ∇init
end
(; A, sensealg) = cache

function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
kwargs...)
(; A, b, sensealg) = cache
@assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."

# Decide if we need to cache `A` and `b` for the reverse pass
if sensealg.linsolve === nothing
if sensealg.linsolve === missing
# We can reuse the factorization so no copy is needed
# Krylov Methods don't modify `A`, so it's safe to just reuse it
# No Copy is needed even for the default case
if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
alg isa DefaultLinearSolver)
A_ = cache.alias_A ? deepcopy(A) : A
A_ = alias_A ? deepcopy(A) : A
end
else
error("Not Implemented Yet!!!")
if alg isa DefaultLinearSolver
A_ = deepcopy(A)
else
A_ = alias_A ? deepcopy(A) : A
end
end

# Forward Solve
sol = solve!(cache, alg, args...; kwargs...)
sol = solve!(cache)

function ∇linear_solve(∂sol)
∂∅ = NoTangent()

function ∇solve!(∂sol)
@assert !cache.isfresh "`cache.A` has been updated between the forward and the \
reverse pass. This is not supported."
∂u = ∂sol.u
if sensealg.linsolve === nothing
if sensealg.linsolve === missing
λ = if cache.cacheval isa Factorization
cache.cacheval' \ ∂u
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
Expand All @@ -79,25 +75,23 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
end
else
error("Not Implemented Yet!!!")
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
λ = solve(
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
end

∂A = -λ * transpose(sol.u)
∂b = λ
∂∅ = NoTangent()

∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol,
cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg)
∂prob = LinearProblem(∂A, ∂b, ∂∅)

return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
end
return sol, ∇solve!

return sol, ∇linear_solve
end

function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
prob = LinearProblem(A, b, p)
function ∇prob(∂prob)
return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p
end
∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
return prob, ∇prob
end
10 changes: 7 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,15 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
end

function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
solve!(init(prob, nothing, args...; kwargs...))
return solve(prob, nothing, args...; kwargs...)
end

function SciMLBase.solve(prob::LinearProblem,
alg::Union{SciMLLinearSolveAlgorithm, Nothing},
function SciMLBase.solve(prob::LinearProblem, ::Nothing, args...;
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...)
end

function SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...; kwargs...)
solve!(init(prob, alg, args...; kwargs...))
end
Expand Down
105 changes: 105 additions & 0 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using Zygote, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff

n = 4
A = rand(n, n);
b1 = rand(n);

function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

s1 = sol1.u
norm(s1)
end

f(A, b1) # Uses BLAS

dA, db1 = Zygote.gradient(f, A, b1)

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))

@test dA dA2
@test db1 db12

A = rand(n, n);
b1 = rand(n);

_ff = (x, y) -> f(x,
y;
alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization))
_ff(copy(A), copy(b1))

dA, db1 = Zygote.gradient(_ff, copy(A), copy(b1))

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))

@test dA dA2
@test db1 db12

function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg)
prob = LinearProblem(A, b2)
sol2 = solve(prob, alg)
norm(sol1.u .+ sol2.u)
end

dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)

#= Needs ForwardDiff rules
dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
@test dA ≈ dA2 atol=5e-5
@test db1 ≈ db12
@test db2 ≈ db22
=#

A = rand(n, n);
b1 = rand(n);
for alg in (
LUFactorization(),
RFLUFactorization(),
KrylovJL_GMRES()
)
@show alg
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fb(b1)

fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fd_jac

zyg_jac = Zygote.jacobian(fb, b1) |> first |> vec
@show zyg_jac

@test zyg_jacfd_jac rtol=1e-4

function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fA(A)

fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fd_jac

zyg_jac = Zygote.jacobian(fA, A) |> first |> vec
@show zyg_jac

@test zyg_jacfd_jac rtol=1e-4
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core"
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
@time @safetestset "Traits" include("traits.jl")
@time @safetestset "BandedMatrices" include("banded.jl")
@time @safetestset "Static Arrays" include("static_arrays.jl")
Expand Down

0 comments on commit c153903

Please sign in to comment.