From aabc85afa489366e3dda61365ebd656a1010b024 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 28 Oct 2023 14:08:24 +0100 Subject: [PATCH 1/2] handle the default algorithm with enzyme adjoints Fixes https://github.com/SciML/LinearSolve.jl/issues/403 --- ext/LinearSolveEnzymeExt.jl | 4 ++- src/default.jl | 58 +++++++++++++++++++++++++++++++++++++ test/enzyme.jl | 17 +++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index a25fe14ba..7cc613bb4 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -147,10 +147,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod # Doesn't modify `A`, so it's safe to just reuse it invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy) - solve(invprob; + solve(invprob, _linearsolve.alg; abstol = _linsolve.val.abstol, reltol = _linsolve.val.reltol, verbose = _linsolve.val.verbose) + elseif _linsolve.alg isa LinearSolve.DefaultLinearSolver + LinearSolve.defaultalg_adjoint_eval(_linsolve, dy) else error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") end diff --git a/src/default.jl b/src/default.jl index 82166a247..86e730222 100644 --- a/src/default.jl +++ b/src/default.jl @@ -362,3 +362,61 @@ end end ex = Expr(:if, ex.args...) end + +""" +``` +elseif DefaultAlgorithmChoice.LUFactorization === cache.alg + (cache.cacheval.LUFactorization)' \\ dy +else + ... +end +``` +""" +@generated function defaultalg_adjoint_eval(cache::LinearCache, dy) + ex = :() + for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) + newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization, + DefaultAlgorithmChoice.AppleAccelerateLUFactorization, + DefaultAlgorithmChoice.RFLUFactorization)) + quote + getproperty(cache.cacheval,$(Meta.quot(alg)))[1]' \ dy + end + elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization, + DefaultAlgorithmChoice.QRFactorization, + DefaultAlgorithmChoice.KLUFactorization, + DefaultAlgorithmChoice.UMFPACKFactorization, + DefaultAlgorithmChoice.LDLtFactorization, + DefaultAlgorithmChoice.SparspakFactorization, + DefaultAlgorithmChoice.BunchKaufmanFactorization, + DefaultAlgorithmChoice.CHOLMODFactorization, + DefaultAlgorithmChoice.SVDFactorization, + DefaultAlgorithmChoice.CholeskyFactorization, + DefaultAlgorithmChoice.NormalCholeskyFactorization, + DefaultAlgorithmChoice.QRFactorizationPivoted, + DefaultAlgorithmChoice.GenericLUFactorization)) + quote + getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy + end + elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,)) + quote + invprob = LinearSolve.LinearProblem(transpose(cache.A), dy) + solve(invprob, cache.alg; + abstol = cache.val.abstol, + reltol = cache.val.reltol, + verbose = cache.val.verbose) + end + else + quote + error("Default linear solver with algorithm $(alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") + end + end + + ex = if ex == :() + Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg), newex, + :(error("Algorithm Choice not Allowed"))) + else + Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg), newex, ex) + end + end + ex = Expr(:if, ex.args...) +end \ No newline at end of file diff --git a/test/enzyme.jl b/test/enzyme.jl index a194450d0..02e071d41 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -26,6 +26,22 @@ db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) @test dA ≈ dA2 @test db1 ≈ db12 +A = rand(n, n); +dA = zeros(n, n); +b1 = rand(n); +db1 = zeros(n); + +_ff = (x,y) -> f(x,y; alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)) +_ff(copy(A), copy(b1)) + +Enzyme.autodiff(Reverse, (x,y) -> f(x,y; alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)), Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) + +dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) + +@test dA ≈ dA2 +@test db1 ≈ db12 + A = rand(n, n); dA = zeros(n, n); dA2 = zeros(n, n); @@ -33,6 +49,7 @@ b1 = rand(n); db1 = zeros(n); db12 = zeros(n); + # Batch test n = 4 A = rand(n, n); From 0263399da77dd4299925e4b4465c029e2369ef94 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 28 Oct 2023 14:10:22 +0100 Subject: [PATCH 2/2] forgot to save --- src/default.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/default.jl b/src/default.jl index 86e730222..05a46020a 100644 --- a/src/default.jl +++ b/src/default.jl @@ -412,10 +412,10 @@ end end ex = if ex == :() - Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg), newex, + Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex, :(error("Algorithm Choice not Allowed"))) else - Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg), newex, ex) + Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex, ex) end end ex = Expr(:if, ex.args...)