From 7e616926905da31c55ca03d6a01fe86616482af0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Dec 2023 09:25:47 -0500 Subject: [PATCH 1/7] Setup to handle adjoints --- Project.toml | 4 +++- src/LinearSolve.jl | 7 +++++++ src/adjoint.jl | 47 ++++++++++++++++++++++++++++++++++++++++++++++ src/common.jl | 9 ++++++--- 4 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 src/adjoint.jl diff --git a/Project.toml b/Project.toml index ba9907272..38448a728 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "2.21.1" +version = "2.22.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" @@ -16,6 +17,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 572e310e9..c9ef40d16 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -24,6 +24,8 @@ PrecompileTools.@recompile_invalidations begin using DocStringExtensions using EnumX using Requires + using Markdown + using ChainRulesCore import InteractiveUtils import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix @@ -43,6 +45,8 @@ PrecompileTools.@recompile_invalidations begin import Preferences end +const CRC = ChainRulesCore + if Preferences.@load_preference("LoadMKL_JLL", true) using MKL_jll const usemkl = MKL_jll.is_available() @@ -124,6 +128,7 @@ include("solve_function.jl") include("default.jl") include("init.jl") include("extension_algs.jl") +include("adjoint.jl") include("deprecated.jl") @generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization; @@ -236,4 +241,6 @@ export MetalLUFactorization export OperatorAssumptions, OperatorCondition +export LinearSolveAdjoint + end diff --git a/src/adjoint.jl b/src/adjoint.jl new file mode 100644 index 000000000..f0f73e10d --- /dev/null +++ b/src/adjoint.jl @@ -0,0 +1,47 @@ +# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr. +# TODO: Document the options in LinearSolveAdjoint + +@doc doc""" + LinearSolveAdjoint(; linsolve = nothing) + +Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as: + +```math +\begin{align} +A^T \lambda &= \partial x \\ +\partial A &= -\lambda x^T \\ +\partial b &= \lambda +\end{align} +``` + +For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf). + +## Choice of Linear Solver + +Note that in most cases, it makes sense to use the same linear solver for the adjoint as the +forward solve (this is done by keeping the linsolve as `nothing`). For example, if the +forward solve was performed via a Factorization, then we can reuse the factorization for the +adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a +specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient. +""" +@kwdef struct LinearSolveAdjoint{L} <: + SciMLBase.AbstractSensitivityAlgorithm{0, false, :central} + linsolve::L = nothing +end + +CRC.@non_differentiable SciMLBase.init(::LinearProblem, ::Any...) + +function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache) + sensealg = cache.sensealg + + # Decide if we need to cache the + + sol = solve!(cache) + function ∇solve!(∂sol) + @assert !cache.isfresh "`cache.A` has been updated between the forward and the reverse pass. This is not supported." + + ∂cache = NoTangent() + return NoTangent(), ∂cache + end + return sol, ∇solve! +end diff --git a/src/common.jl b/src/common.jl index b206598d5..a49213521 100644 --- a/src/common.jl +++ b/src/common.jl @@ -65,7 +65,7 @@ end __issquare(assump::OperatorAssumptions) = assump.issq __conditioning(assump::OperatorAssumptions) = assump.condition -mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq} +mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S} A::TA b::Tb u::Tu @@ -80,6 +80,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq} maxiters::Int verbose::Bool assumptions::OperatorAssumptions{issq} + sensealg::S end function Base.setproperty!(cache::LinearCache, name::Symbol, x) @@ -137,6 +138,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, Pl = IdentityOperator(size(prob.A)[1]), Pr = IdentityOperator(size(prob.A)[2]), assumptions = OperatorAssumptions(issquare(prob.A)), + sensealg = LinearSolveAdjoint(), kwargs...) @unpack A, b, u0, p = prob @@ -170,8 +172,9 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, Tc = typeof(cacheval) cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, - typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_, - p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions) + typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), + typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + maxiters, verbose, assumptions, sensealg) return cache end From 06c09a3683ab6f92d69ee6f67a52fca93d91004d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Dec 2023 15:05:57 -0500 Subject: [PATCH 2/7] Finish part of the implementation --- src/adjoint.jl | 74 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/src/adjoint.jl b/src/adjoint.jl index f0f73e10d..de0c2642d 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -1,5 +1,4 @@ # TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr. -# TODO: Document the options in LinearSolveAdjoint @doc doc""" LinearSolveAdjoint(; linsolve = nothing) @@ -29,19 +28,76 @@ specific structure distinct from ``A`` then passing in a `linsolve` will be more linsolve::L = nothing end -CRC.@non_differentiable SciMLBase.init(::LinearProblem, ::Any...) +function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem, + alg::SciMLLinearSolveAlgorithm, args...; kwargs...) + cache = init(prob, alg, args...; kwargs...) + function ∇init(∂cache) + ∂∅ = NoTangent() + ∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p) + ∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p) + return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...) + end + return cache, ∇init +end -function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache) - sensealg = cache.sensealg +function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...; + kwargs...) + (; A, b, sensealg) = cache - # Decide if we need to cache the + # Decide if we need to cache `A` and `b` for the reverse pass + if sensealg.linsolve === nothing + # We can reuse the factorization so no copy is needed + # Krylov Methods don't modify `A`, so it's safe to just reuse it + # No Copy is needed even for the default case + if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod || + alg isa DefaultLinearSolver) + A_ = cache.alias_A ? deepcopy(A) : A + end + else + error("Not Implemented Yet!!!") + end + + # Forward Solve + sol = solve!(cache, alg, args...; kwargs...) - sol = solve!(cache) function ∇solve!(∂sol) - @assert !cache.isfresh "`cache.A` has been updated between the forward and the reverse pass. This is not supported." + @assert !cache.isfresh "`cache.A` has been updated between the forward and the \ + reverse pass. This is not supported." + ∂u = ∂sol.u + if sensealg.linsolve === nothing + λ = if cache.cacheval isa Factorization + cache.cacheval' \ ∂u + elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization + first(cache.cacheval)' \ ∂u + elseif alg isa AbstractKrylovSubspaceMethod + invprob = LinearProblem(transpose(cache.A), ∂u) + solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u + elseif alg isa DefaultLinearSolver + LinearSolve.defaultalg_adjoint_eval(cache, ∂u) + else + invprob = LinearProblem(transpose(A_), ∂u) # We cached `A` + solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u + end + else + error("Not Implemented Yet!!!") + end + + ∂A = -λ * transpose(sol.u) + ∂b = λ + ∂∅ = NoTangent() - ∂cache = NoTangent() - return NoTangent(), ∂cache + ∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol, + cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg) + + return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...) end return sol, ∇solve! end + +function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...) + prob = LinearProblem(A, b, p) + function ∇prob(∂prob) + return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p + end + return prob, ∇prob +end From c153903de2c8f4c1fdc5e1d29b9f2f065fb57a02 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 15:48:24 -0500 Subject: [PATCH 3/7] Allow special solver for adjoint --- src/LinearSolve.jl | 1 - src/adjoint.jl | 64 +++++++++++++-------------- src/common.jl | 10 +++-- test/adjoint.jl | 105 +++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 142 insertions(+), 39 deletions(-) create mode 100644 test/adjoint.jl diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 1d507457e..ce27a5abf 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -23,7 +23,6 @@ PrecompileTools.@recompile_invalidations begin using FastLapackInterface using DocStringExtensions using EnumX - using Requires using Markdown using ChainRulesCore import InteractiveUtils diff --git a/src/adjoint.jl b/src/adjoint.jl index de0c2642d..3d46d8048 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -1,7 +1,7 @@ # TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr. @doc doc""" - LinearSolveAdjoint(; linsolve = nothing) + LinearSolveAdjoint(; linsolve = missing) Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as: @@ -18,53 +18,49 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi ## Choice of Linear Solver Note that in most cases, it makes sense to use the same linear solver for the adjoint as the -forward solve (this is done by keeping the linsolve as `nothing`). For example, if the +forward solve (this is done by keeping the linsolve as `missing`). For example, if the forward solve was performed via a Factorization, then we can reuse the factorization for the adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient. """ @kwdef struct LinearSolveAdjoint{L} <: SciMLBase.AbstractSensitivityAlgorithm{0, false, :central} - linsolve::L = nothing + linsolve::L = missing end -function CRC.rrule(::typeof(SciMLBase.init), prob::LinearProblem, - alg::SciMLLinearSolveAlgorithm, args...; kwargs...) +function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem, + alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A( + alg, prob.A, prob.b), kwargs...) + # sol = solve(prob, alg, args...; kwargs...) cache = init(prob, alg, args...; kwargs...) - function ∇init(∂cache) - ∂∅ = NoTangent() - ∂p = prob.p isa SciMLBase.NullParameters ? prob.p : ProjectTo(prob.p)(∂cache.p) - ∂prob = LinearProblem(∂cache.A, ∂cache.b, ∂p) - return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...) - end - return cache, ∇init -end + (; A, sensealg) = cache -function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...; - kwargs...) - (; A, b, sensealg) = cache + @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis." # Decide if we need to cache `A` and `b` for the reverse pass - if sensealg.linsolve === nothing + if sensealg.linsolve === missing # We can reuse the factorization so no copy is needed # Krylov Methods don't modify `A`, so it's safe to just reuse it # No Copy is needed even for the default case if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod || alg isa DefaultLinearSolver) - A_ = cache.alias_A ? deepcopy(A) : A + A_ = alias_A ? deepcopy(A) : A end else - error("Not Implemented Yet!!!") + if alg isa DefaultLinearSolver + A_ = deepcopy(A) + else + A_ = alias_A ? deepcopy(A) : A + end end - # Forward Solve - sol = solve!(cache, alg, args...; kwargs...) + sol = solve!(cache) + + function ∇linear_solve(∂sol) + ∂∅ = NoTangent() - function ∇solve!(∂sol) - @assert !cache.isfresh "`cache.A` has been updated between the forward and the \ - reverse pass. This is not supported." ∂u = ∂sol.u - if sensealg.linsolve === nothing + if sensealg.linsolve === missing λ = if cache.cacheval isa Factorization cache.cacheval' \ ∂u elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization @@ -79,25 +75,23 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...; solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u end else - error("Not Implemented Yet!!!") + invprob = LinearProblem(transpose(A_), ∂u) # We cached `A` + λ = solve( + invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u end ∂A = -λ * transpose(sol.u) ∂b = λ - ∂∅ = NoTangent() - - ∂cache = LinearCache(∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache.isfresh, ∂∅, ∂∅, cache.abstol, - cache.reltol, cache.maxiters, cache.verbose, cache.assumptions, cache.sensealg) + ∂prob = LinearProblem(∂A, ∂b, ∂∅) - return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...) + return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...) end - return sol, ∇solve! + + return sol, ∇linear_solve end function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...) prob = LinearProblem(A, b, p) - function ∇prob(∂prob) - return NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p - end + ∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p) return prob, ∇prob end diff --git a/src/common.jl b/src/common.jl index 523781ad6..0d1e419ff 100644 --- a/src/common.jl +++ b/src/common.jl @@ -180,11 +180,15 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, end function SciMLBase.solve(prob::LinearProblem, args...; kwargs...) - solve!(init(prob, nothing, args...; kwargs...)) + return solve(prob, nothing, args...; kwargs...) end -function SciMLBase.solve(prob::LinearProblem, - alg::Union{SciMLLinearSolveAlgorithm, Nothing}, +function SciMLBase.solve(prob::LinearProblem, ::Nothing, args...; + assump = OperatorAssumptions(issquare(prob.A)), kwargs...) + return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...) +end + +function SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) solve!(init(prob, alg, args...; kwargs...)) end diff --git a/test/adjoint.jl b/test/adjoint.jl new file mode 100644 index 000000000..ecc9714eb --- /dev/null +++ b/test/adjoint.jl @@ -0,0 +1,105 @@ +using Zygote, ForwardDiff +using LinearSolve, LinearAlgebra, Test +using FiniteDiff + +n = 4 +A = rand(n, n); +b1 = rand(n); + +function f(A, b1; alg = LUFactorization()) + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + s1 = sol1.u + norm(s1) +end + +f(A, b1) # Uses BLAS + +dA, db1 = Zygote.gradient(f, A, b1) + +dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1)) + +@test dA ≈ dA2 +@test db1 ≈ db12 + +A = rand(n, n); +b1 = rand(n); + +_ff = (x, y) -> f(x, + y; + alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)) +_ff(copy(A), copy(b1)) + +dA, db1 = Zygote.gradient(_ff, copy(A), copy(b1)) + +dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1)) + +@test dA ≈ dA2 +@test db1 ≈ db12 + +function f3(A, b1, b2; alg = KrylovJL_GMRES()) + prob = LinearProblem(A, b1) + sol1 = solve(prob, alg) + prob = LinearProblem(A, b2) + sol2 = solve(prob, alg) + norm(sol1.u .+ sol2.u) +end + +dA, db1, db2 = Zygote.gradient(f3, A, b1, b1) + +#= Needs ForwardDiff rules +dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) +db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1)) + +@test dA ≈ dA2 atol=5e-5 +@test db1 ≈ db12 +@test db2 ≈ db22 +=# + +A = rand(n, n); +b1 = rand(n); +for alg in ( + LUFactorization(), + RFLUFactorization(), + KrylovJL_GMRES() +) + @show alg + function fb(b) + prob = LinearProblem(A, b) + + sol1 = solve(prob, alg) + + sum(sol1.u) + end + fb(b1) + + fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec + @show fd_jac + + zyg_jac = Zygote.jacobian(fb, b1) |> first |> vec + @show zyg_jac + + @test zyg_jac≈fd_jac rtol=1e-4 + + function fA(A) + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + sum(sol1.u) + end + fA(A) + + fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec + @show fd_jac + + zyg_jac = Zygote.jacobian(fA, A) |> first |> vec + @show zyg_jac + + @test zyg_jac≈fd_jac rtol=1e-4 +end diff --git a/test/runtests.jl b/test/runtests.jl index ae12f0544..4994eba23 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core" @time @safetestset "SparseVector b Tests" include("sparse_vector.jl") @time @safetestset "Default Alg Tests" include("default_algs.jl") @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") + @time @safetestset "Adjoint Sensitivity" include("adjoint.jl") @time @safetestset "Traits" include("traits.jl") @time @safetestset "BandedMatrices" include("banded.jl") @time @safetestset "Static Arrays" include("static_arrays.jl") From 34995f60f74a433482e233924bf81277f87b361b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 16:02:22 -0500 Subject: [PATCH 4/7] Add compat entries --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 251a9acae..1a1ae1833 100644 --- a/Project.toml +++ b/Project.toml @@ -66,6 +66,7 @@ ArrayInterface = "7.7" BandedMatrices = "1.5" BlockDiagonals = "0.1.42" CUDA = "5" +ChainRulesCore = "1.22" ConcreteStructs = "0.2.3" DocStringExtensions = "0.9.3" EnumX = "1.0.4" @@ -87,6 +88,7 @@ KrylovKit = "0.6" Libdl = "1.10" LinearAlgebra = "1.10" MPI = "0.20" +Markdown = "1.10" Metal = "0.5" MultiFloats = "1" Pardiso = "0.5" From 7c1f1b208bddc4cc481693b4b2e721714d3c0751 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 16:16:07 -0500 Subject: [PATCH 5/7] Fix HYPRE --- ext/LinearSolveHYPREExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/LinearSolveHYPREExt.jl b/ext/LinearSolveHYPREExt.jl index cef49b7a8..0a3dcb1cc 100644 --- a/ext/LinearSolveHYPREExt.jl +++ b/ext/LinearSolveHYPREExt.jl @@ -5,7 +5,7 @@ using HYPRE.LibHYPRE: HYPRE_Complex using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve, OperatorAssumptions, default_tol, init_cacheval, __issquare, - __conditioning + __conditioning, LinearSolveAdjoint using SciMLBase: LinearProblem, SciMLBase using UnPack: @unpack using Setfield: @set! @@ -68,6 +68,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm, Pl = LinearAlgebra.I, Pr = LinearAlgebra.I, assumptions = OperatorAssumptions(), + sensealg = LinearSolveAdjoint(), kwargs...) @unpack A, b, u0, p = prob @@ -89,10 +90,9 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm, cache = LinearCache{ typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), - typeof(__issquare(assumptions)) + typeof(__issquare(assumptions), typeof(sensealg)) }(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, - maxiters, - verbose, assumptions) + maxiters, verbose, assumptions, sensealg) return cache end From 643271615838f6deb84ac0390bfe50f35ea68e9d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 16:28:14 -0500 Subject: [PATCH 6/7] More tests and some safety --- Project.toml | 4 +++- ext/LinearSolveHYPREExt.jl | 2 +- src/adjoint.jl | 6 +----- test/adjoint.jl | 31 +++++++++++++++++++++++++------ 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 1a1ae1833..9d6f212fc 100644 --- a/Project.toml +++ b/Project.toml @@ -110,6 +110,7 @@ StaticArrays = "1.5" StaticArraysCore = "1.4.2" Test = "1" UnPack = "1" +Zygote = "0.6.69" julia = "1.10" [extras] @@ -137,6 +138,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs"] +test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"] diff --git a/ext/LinearSolveHYPREExt.jl b/ext/LinearSolveHYPREExt.jl index 0a3dcb1cc..279aba75d 100644 --- a/ext/LinearSolveHYPREExt.jl +++ b/ext/LinearSolveHYPREExt.jl @@ -90,7 +90,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm, cache = LinearCache{ typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), - typeof(__issquare(assumptions), typeof(sensealg)) + typeof(__issquare(assumptions)), typeof(sensealg) }(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) return cache diff --git a/src/adjoint.jl b/src/adjoint.jl index 3d46d8048..550bb2bd6 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -47,11 +47,7 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem, A_ = alias_A ? deepcopy(A) : A end else - if alg isa DefaultLinearSolver - A_ = deepcopy(A) - else - A_ = alias_A ? deepcopy(A) : A - end + A_ = deepcopy(A) end sol = solve!(cache) diff --git a/test/adjoint.jl b/test/adjoint.jl index ecc9714eb..26a72016f 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -51,15 +51,34 @@ end dA, db1, db2 = Zygote.gradient(f3, A, b1, b1) -#= Needs ForwardDiff rules -dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) -db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) -db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1)) +dA2 = FiniteDiff.finite_difference_gradient( + x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) +db12 = FiniteDiff.finite_difference_gradient( + x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) +db22 = FiniteDiff.finite_difference_gradient( + x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1)) + +@test dA≈dA2 atol=5e-5 +@test db1 ≈ db12 +@test db2 ≈ db22 + +function f4(A, b1, b2; alg = LUFactorization()) + prob = LinearProblem(A, b1) + sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR())) + prob = LinearProblem(A, b2) + sol2 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_GMRES())) + norm(sol1.u .+ sol2.u) +end + +dA, db1, db2 = Zygote.gradient(f4, A, b1, b1) + +dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) +db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1)) -@test dA ≈ dA2 atol=5e-5 +@test dA≈dA2 atol=5e-5 @test db1 ≈ db12 @test db2 ≈ db22 -=# A = rand(n, n); b1 = rand(n); From e937e675a82dbca2e45d5d05583d3d5557e22635 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 10:02:30 -0500 Subject: [PATCH 7/7] Up min SciMLBase compat --- Project.toml | 2 +- test/adjoint.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 9d6f212fc..51168952a 100644 --- a/Project.toml +++ b/Project.toml @@ -100,7 +100,7 @@ RecursiveArrayTools = "3.8" RecursiveFactorization = "0.2.14" Reexport = "1" SafeTestsets = "0.1" -SciMLBase = "2.23.0" +SciMLBase = "2.26.3" SciMLOperators = "0.3.7" Setfield = "1" SparseArrays = "1.10" diff --git a/test/adjoint.jl b/test/adjoint.jl index 26a72016f..4478daf98 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -52,11 +52,11 @@ end dA, db1, db2 = Zygote.gradient(f3, A, b1, b1) dA2 = FiniteDiff.finite_difference_gradient( - x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) + x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) db12 = FiniteDiff.finite_difference_gradient( - x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) + x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) db22 = FiniteDiff.finite_difference_gradient( - x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1)) + x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1)) @test dA≈dA2 atol=5e-5 @test db1 ≈ db12