Skip to content

Commit

Permalink
Finish part of the implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 20, 2023
1 parent 7e61692 commit 06c09a3
Showing 1 changed file with 65 additions and 9 deletions.
74 changes: 65 additions & 9 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.
# TODO: Document the options in LinearSolveAdjoint

@doc doc"""
LinearSolveAdjoint(; linsolve = nothing)
Expand Down Expand Up @@ -29,19 +28,76 @@ specific structure distinct from ``A`` then passing in a `linsolve` will be more
linsolve::L = nothing
end

CRC.@non_differentiable SciMLBase.init(::LinearProblem, ::Any...)
function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,

Check warning on line 31 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L31

Added line #L31 was not covered by tests
alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
function ∇init(∂cache)
∂∅ = NoTangent()
∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p)
∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p)
return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)

Check warning on line 38 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L33-L38

Added lines #L33 - L38 were not covered by tests
end
return cache, ∇init

Check warning on line 40 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L40

Added line #L40 was not covered by tests
end

function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache)
sensealg = cache.sensealg
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;

Check warning on line 43 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L43

Added line #L43 was not covered by tests
kwargs...)
(; A, b, sensealg) = cache

Check warning on line 45 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L45

Added line #L45 was not covered by tests

# Decide if we need to cache the
# Decide if we need to cache `A` and `b` for the reverse pass
if sensealg.linsolve === nothing

Check warning on line 48 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L48

Added line #L48 was not covered by tests
# We can reuse the factorization so no copy is needed
# Krylov Methods don't modify `A`, so it's safe to just reuse it
# No Copy is needed even for the default case
if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||

Check warning on line 52 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L52

Added line #L52 was not covered by tests
alg isa DefaultLinearSolver)
A_ = cache.alias_A ? deepcopy(A) : A

Check warning on line 54 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L54

Added line #L54 was not covered by tests
end
else
error("Not Implemented Yet!!!")

Check warning on line 57 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L57

Added line #L57 was not covered by tests
end

# Forward Solve
sol = solve!(cache, alg, args...; kwargs...)

Check warning on line 61 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L61

Added line #L61 was not covered by tests

sol = solve!(cache)
function ∇solve!(∂sol)
@assert !cache.isfresh "`cache.A` has been updated between the forward and the reverse pass. This is not supported."
@assert !cache.isfresh "`cache.A` has been updated between the forward and the \

Check warning on line 64 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
reverse pass. This is not supported."
∂u = ∂sol.u
if sensealg.linsolve === nothing
λ = if cache.cacheval isa Factorization
cache.cacheval' \ ∂u
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
first(cache.cacheval)' \ ∂u
elseif alg isa AbstractKrylovSubspaceMethod
invprob = LinearProblem(transpose(cache.A), ∂u)
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
elseif alg isa DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)

Check warning on line 76 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L66-L76

Added lines #L66 - L76 were not covered by tests
else
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u

Check warning on line 79 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L78-L79

Added lines #L78 - L79 were not covered by tests
end
else
error("Not Implemented Yet!!!")

Check warning on line 82 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L82

Added line #L82 was not covered by tests
end

∂A = -λ * transpose(sol.u)
∂b = λ
∂∅ = NoTangent()

Check warning on line 87 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L85-L87

Added lines #L85 - L87 were not covered by tests

∂cache = NoTangent()
return NoTangent(), ∂cache
∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol,

Check warning on line 89 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L89

Added line #L89 was not covered by tests
cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg)

return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)

Check warning on line 92 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L92

Added line #L92 was not covered by tests
end
return sol, ∇solve!

Check warning on line 94 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L94

Added line #L94 was not covered by tests
end

function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
prob = LinearProblem(A, b, p)
function ∇prob(∂prob)
return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p

Check warning on line 100 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L97-L100

Added lines #L97 - L100 were not covered by tests
end
return prob, ∇prob

Check warning on line 102 in src/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/adjoint.jl#L102

Added line #L102 was not covered by tests
end

0 comments on commit 06c09a3

Please sign in to comment.