diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl
index 1d507457e..ce27a5abf 100644
--- a/src/LinearSolve.jl
+++ b/src/LinearSolve.jl
@@ -23,7 +23,6 @@ PrecompileTools.@recompile_invalidations begin
     using FastLapackInterface
     using DocStringExtensions
     using EnumX
-    using Requires
     using Markdown
     using ChainRulesCore
     import InteractiveUtils
diff --git a/src/adjoint.jl b/src/adjoint.jl
index de0c2642d..3d46d8048 100644
--- a/src/adjoint.jl
+++ b/src/adjoint.jl
@@ -1,7 +1,7 @@
 # TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.
 
 @doc doc"""
-    LinearSolveAdjoint(; linsolve = nothing)
+    LinearSolveAdjoint(; linsolve = missing)
 
 Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as:
 
@@ -18,53 +18,49 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf)
 
 ## Choice of Linear Solver
 
 Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
-forward solve (this is done by keeping the linsolve as `nothing`). For example, if the
+forward solve (this is done by keeping the linsolve as `missing`). For example, if the
 forward solve was performed via a Factorization, then we can reuse the factorization for the
 adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
 specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
 """
 @kwdef struct LinearSolveAdjoint{L} <:
               SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
-    linsolve::L = nothing
+    linsolve::L = missing
 end
 
-function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem,
-        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
+function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
+        alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A(
+            alg, prob.A, prob.b), kwargs...)
+    # sol = solve(prob, alg, args...; kwargs...)
     cache = init(prob, alg, args...; kwargs...)
-    function ∇init(∂cache)
-        ∂∅ = NoTangent()
-        ∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p)
-        ∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p)
-        return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
-    end
-    return cache, ∇init
-end
+    (; A, sensealg) = cache
 
-function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
-        kwargs...)
-    (; A, b, sensealg) = cache
+    @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
 
     # Decide if we need to cache `A` and `b` for the reverse pass
-    if sensealg.linsolve === nothing
+    if sensealg.linsolve === missing
         # We can reuse the factorization so no copy is needed
         # Krylov Methods don't modify `A`, so it's safe to just reuse it
         # No Copy is needed even for the default case
         if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
              alg isa DefaultLinearSolver)
-            A_ = cache.alias_A ? deepcopy(A) : A
+            A_ = alias_A ? deepcopy(A) : A
         end
     else
-        error("Not Implemented Yet!!!")
+        if alg isa DefaultLinearSolver
+            A_ = deepcopy(A)
+        else
+            A_ = alias_A ? deepcopy(A) : A
+        end
     end
 
-    # Forward Solve
-    sol = solve!(cache, alg, args...; kwargs...)
+    sol = solve!(cache)
+
+    function ∇linear_solve(∂sol)
+        ∂∅ = NoTangent()
 
-    function ∇solve!(∂sol)
-        @assert !cache.isfresh "`cache.A` has been updated between the forward and the \
-                                reverse pass. This is not supported."
         ∂u = ∂sol.u
-        if sensealg.linsolve === nothing
+        if sensealg.linsolve === missing
             λ = if cache.cacheval isa Factorization
                 cache.cacheval' \ ∂u
             elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
@@ -79,25 +75,23 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
                 solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
             end
         else
-            error("Not Implemented Yet!!!")
+            invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
+            λ = solve(
+                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
         end
 
         ∂A = -λ * transpose(sol.u)
         ∂b = λ
-        ∂∅ = NoTangent()
-
-        ∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol,
-            cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg)
+        ∂prob = LinearProblem(∂A, ∂b, ∂∅)
 
-        return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
+        return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
     end
-    return sol, ∇solve!
+
+    return sol, ∇linear_solve
 end
 
 function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
     prob = LinearProblem(A, b, p)
-    function ∇prob(∂prob)
-        return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p
-    end
+    ∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
     return prob, ∇prob
 end
diff --git a/src/common.jl b/src/common.jl
index 523781ad6..0d1e419ff 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -180,11 +180,15 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
 end
 
 function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
-    solve!(init(prob, nothing, args...; kwargs...))
+    return solve(prob, nothing, args...; kwargs...)
 end
 
-function SciMLBase.solve(prob::LinearProblem,
-        alg::Union{SciMLLinearSolveAlgorithm, Nothing},
+function SciMLBase.solve(prob::LinearProblem, ::Nothing, args...;
+        assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
+    return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...)
+end
+
+function SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
         args...; kwargs...)
     solve!(init(prob, alg, args...; kwargs...))
 end
diff --git a/test/adjoint.jl b/test/adjoint.jl
new file mode 100644
index 000000000..ecc9714eb
--- /dev/null
+++ b/test/adjoint.jl
@@ -0,0 +1,105 @@
+using Zygote, ForwardDiff
+using LinearSolve, LinearAlgebra, Test
+using FiniteDiff
+
+n = 4
+A = rand(n, n);
+b1 = rand(n);
+
+function f(A, b1; alg = LUFactorization())
+    prob = LinearProblem(A, b1)
+
+    sol1 = solve(prob, alg)
+
+    s1 = sol1.u
+    norm(s1)
+end
+
+f(A, b1) # Uses BLAS
+
+dA, db1 = Zygote.gradient(f, A, b1)
+
+dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
+db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
+
+@test dA ≈ dA2
+@test db1 ≈ db12
+
+A = rand(n, n);
+b1 = rand(n);
+
+_ff = (x, y) -> f(x,
+    y;
+    alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization))
+_ff(copy(A), copy(b1))
+
+dA, db1 = Zygote.gradient(_ff, copy(A), copy(b1))
+
+dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
+db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
+
+@test dA ≈ dA2
+@test db1 ≈ db12
+
+function f3(A, b1, b2; alg = KrylovJL_GMRES())
+    prob = LinearProblem(A, b1)
+    sol1 = solve(prob, alg)
+    prob = LinearProblem(A, b2)
+    sol2 = solve(prob, alg)
+    norm(sol1.u .+ sol2.u)
+end
+
+dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)
+
+#= Needs ForwardDiff rules
+dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
+db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
+
+@test dA ≈ dA2 atol=5e-5
+@test db1 ≈ db12
+@test db2 ≈ db22
+=#
+
+A = rand(n, n);
+b1 = rand(n);
+
+for alg in (
+    LUFactorization(),
+    RFLUFactorization(),
+    KrylovJL_GMRES()
+)
+    @show alg
+
+    function fb(b)
+        prob = LinearProblem(A, b)
+
+        sol1 = solve(prob, alg)
+
+        sum(sol1.u)
+    end
+    fb(b1)
+
+    fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
+    @show fd_jac
+
+    zyg_jac = Zygote.jacobian(fb, b1) |> first |> vec
+    @show zyg_jac
+
+    @test zyg_jac≈fd_jac rtol=1e-4
+
+    function fA(A)
+        prob = LinearProblem(A, b1)
+
+        sol1 = solve(prob, alg)
+
+        sum(sol1.u)
+    end
+    fA(A)
+
+    fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
+    @show fd_jac
+
+    zyg_jac = Zygote.jacobian(fA, A) |> first |> vec
+    @show zyg_jac
+
+    @test zyg_jac≈fd_jac rtol=1e-4
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index ae12f0544..4994eba23 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core"
     @time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
    @time @safetestset "Default Alg Tests" include("default_algs.jl")
    @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
+    @time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
    @time @safetestset "Traits" include("traits.jl")
    @time @safetestset "BandedMatrices" include("banded.jl")
    @time @safetestset "Static Arrays" include("static_arrays.jl")
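
Note for readers: the pullback `∇linear_solve` in the patch implements the standard linear-solve adjoint. For the forward solve x = A \ b and a scalar loss with cotangent ∂u, it solves transpose(A) λ = ∂u and returns ∂A = -λ * transpose(x) and ∂b = λ, exactly the quantities formed in the rrule above. A minimal sketch checking these identities against the new rule; the loss `g` and the 4×4 size are illustrative choices mirroring the tests, not part of the patch:

using LinearSolve, LinearAlgebra, Zygote, Test

n = 4
A = rand(n, n)
b = rand(n)

# Scalar loss through the linear solve; for sum(x) the cotangent is ∂u = ones(n).
g(A, b) = sum(solve(LinearProblem(A, b), LUFactorization()).u)

dA, db = Zygote.gradient(g, A, b)

# Hand-computed adjoint: solve the transposed system for λ, then form ∂A and ∂b.
x = A \ b
λ = transpose(A) \ ones(n)

@test dA ≈ -λ * transpose(x)
@test db ≈ λ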
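
A second sketch exercises the newly implemented non-`missing` branch, where the transposed system is solved with a user-chosen method instead of reusing the forward solver. This assumes the cache's `sensealg` field (destructured in the new rrule) is settable via a `sensealg` keyword to `solve`/`init`, which is not shown in this excerpt, and it qualifies `LinearSolveAdjoint` in case the name is not exported at this commit:

using LinearSolve, Zygote

# Forward solve via LU; adjoint solve via GMRES on transpose(A), following the
# `sensealg.linsolve !== missing` path of the rrule. Hypothetical usage.
h(A, b) = sum(solve(LinearProblem(A, b), LUFactorization();
    sensealg = LinearSolve.LinearSolveAdjoint(; linsolve = KrylovJL_GMRES())).u)

dA, db = Zygote.gradient(h, rand(4, 4), rand(4))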