Skip to content

Commit

Permalink
Merge pull request #406 from SciML/enzyme_default
Browse files Browse the repository at this point in the history
handle the default algorithm with enzyme adjoints
  • Loading branch information
ChrisRackauckas authored Oct 28, 2023
2 parents 6f3f2cd + 0263399 commit dc9a3b4
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 1 deletion.
4 changes: 3 additions & 1 deletion ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
# Doesn't modify `A`, so it's safe to just reuse it
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
solve(invprob;
solve(invprob, _linearsolve.alg;
abstol = _linsolve.val.abstol,
reltol = _linsolve.val.reltol,
verbose = _linsolve.val.verbose)
elseif _linsolve.alg isa LinearSolve.DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(_linsolve, dy)
else
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end
Expand Down
58 changes: 58 additions & 0 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,61 @@ end
end
ex = Expr(:if, ex.args...)
end

"""
```
elseif DefaultAlgorithmChoice.LUFactorization === cache.alg
(cache.cacheval.LUFactorization)' \\ dy
else
...
end
```
"""
@generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
ex = :()
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization,
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
DefaultAlgorithmChoice.RFLUFactorization))
quote
getproperty(cache.cacheval,$(Meta.quot(alg)))[1]' \ dy
end
elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization,
DefaultAlgorithmChoice.QRFactorization,
DefaultAlgorithmChoice.KLUFactorization,
DefaultAlgorithmChoice.UMFPACKFactorization,
DefaultAlgorithmChoice.LDLtFactorization,
DefaultAlgorithmChoice.SparspakFactorization,
DefaultAlgorithmChoice.BunchKaufmanFactorization,
DefaultAlgorithmChoice.CHOLMODFactorization,
DefaultAlgorithmChoice.SVDFactorization,
DefaultAlgorithmChoice.CholeskyFactorization,
DefaultAlgorithmChoice.NormalCholeskyFactorization,
DefaultAlgorithmChoice.QRFactorizationPivoted,
DefaultAlgorithmChoice.GenericLUFactorization))
quote
getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy
end
elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,))
quote
invprob = LinearSolve.LinearProblem(transpose(cache.A), dy)
solve(invprob, cache.alg;
abstol = cache.val.abstol,
reltol = cache.val.reltol,
verbose = cache.val.verbose)
end
else
quote
error("Default linear solver with algorithm $(alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end
end

ex = if ex == :()
Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex,
:(error("Algorithm Choice not Allowed")))
else
Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex, ex)
end
end
ex = Expr(:if, ex.args...)
end
17 changes: 17 additions & 0 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,30 @@ db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
@test dA dA2
@test db1 db12

A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);

_ff = (x,y) -> f(x,y; alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization))
_ff(copy(A), copy(b1))

Enzyme.autodiff(Reverse, (x,y) -> f(x,y; alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)), Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))

@test dA dA2
@test db1 db12

A = rand(n, n);
dA = zeros(n, n);
dA2 = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
db12 = zeros(n);


# Batch test
n = 4
A = rand(n, n);
Expand Down

0 comments on commit dc9a3b4

Please sign in to comment.