# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.

# NOTE(review): the `LinearSolveAdjoint` docstring and struct definition belong here.
# Only fragments of them are visible in the reviewed excerpt, so they are intentionally
# not reproduced in this block.

"""
    CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)

Reverse rule for constructing a linear-solve cache.

The forward pass simply calls `init(prob, alg, args...; kwargs...)`. The pullback
repackages the cotangent of the cache (`∂cache.A`, `∂cache.b`, `∂cache.p`) as a
`LinearProblem`, which is returned as the cotangent of `prob`. The function itself, `alg`
and every extra positional argument receive `NoTangent()`.
"""
function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
    cache = init(prob, alg, args...; kwargs...)
    function ∇init(∂cache)
        ∂∅ = NoTangent()
        # `NullParameters` is passed through unchanged; any other parameter object has
        # its cotangent projected back onto the structure of `prob.p`.
        ∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p)
        ∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p)
        return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
    end
    return cache, ∇init
end

"""
    CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...; kwargs...)

Reverse rule for `solve!` on a `LinearCache`.

The forward pass runs `solve!(cache, alg, args...; kwargs...)`. The pullback solves the
adjoint system `A' λ = ∂u` and returns the cotangents `∂A = -λ * u'` and `∂b = λ`, packed
into a `LinearCache` used as the cotangent of `cache`. The function itself, `alg` and
every extra positional argument receive `NoTangent()`.

The adjoint (conjugate transpose) is used in every branch, so the Krylov and fallback
paths agree with the factorization paths. For real-valued problems this is identical to
using the plain transpose.

Only `sensealg.linsolve === nothing` is supported; any other value raises an error.
"""
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
        kwargs...)
    (; A, sensealg) = cache

    # Decide if we need to keep our own handle on `A` for the reverse pass
    if sensealg.linsolve === nothing
        # We can reuse the factorization so no copy is needed
        # Krylov Methods don't modify `A`, so it's safe to just reuse it
        # No Copy is needed even for the default case
        # NOTE(review): `A_` is only assigned in this branch, yet the pullback's final
        # `else` branch reads it. A factorization algorithm whose `cache.cacheval` is
        # not a `Factorization` would appear to reach that branch with `A_` undefined.
        # Confirm against the set of supported algorithms.
        if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
             alg isa DefaultLinearSolver)
            A_ = cache.alias_A ? deepcopy(A) : A
        end
    else
        error("Not Implemented Yet!!!")
    end

    # Forward Solve
    sol = solve!(cache, alg, args...; kwargs...)

    function ∇solve!(∂sol)
        @assert !cache.isfresh "`cache.A` has been updated between the forward and the \
                                reverse pass. This is not supported."
        ∂u = ∂sol.u
        if sensealg.linsolve === nothing
            λ = if cache.cacheval isa Factorization
                # Reuse the factorization computed in the forward pass
                cache.cacheval' \ ∂u
            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
                first(cache.cacheval)' \ ∂u
            elseif alg isa AbstractKrylovSubspaceMethod
                invprob = LinearProblem(adjoint(cache.A), ∂u)
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            elseif alg isa DefaultLinearSolver
                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
            else
                invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            end
        else
            error("Not Implemented Yet!!!")
        end

        ∂A = -λ * adjoint(sol.u)
        ∂b = λ
        ∂∅ = NoTangent()

        ∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol,
            cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg)

        return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
    end
    return sol, ∇solve!
end

"""
    CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)

Reverse rule for the `LinearProblem` constructor.

The keyword arguments are forwarded to the constructor so that the primal built under
differentiation matches the one built by a plain call. The pullback unpacks the `A`, `b`
and `p` fields of the problem cotangent; the type itself receives `NoTangent()`.
"""
function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
    prob = LinearProblem(A, b, p; kwargs...)
    function ∇prob(∂prob)
        return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p
    end
    return prob, ∇prob
end