Skip to content

Commit

Permalink
Add forward enzyme rules for init and solve
Browse files Browse the repository at this point in the history
  • Loading branch information
sharanry committed Nov 5, 2023
1 parent d863895 commit e89be05
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 4 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,18 @@ ConcreteStructs = "0.2"
DocStringExtensions = "0.8, 0.9"
EnumX = "1"
EnzymeCore = "0.5, 0.6"
EnzymeTestUtils = "0.1"
FastLapackInterface = "1, 2"
GPUArraysCore = "0.1"
HYPRE = "1.4.0"
InteractiveUtils = "1.6"
IterativeSolvers = "0.9.3"
Libdl = "1.6"
LinearAlgebra = "1.6"
KLU = "0.3.0, 0.4"
KernelAbstractions = "0.9"
Krylov = "0.9"
KrylovKit = "0.5, 0.6"
Libdl = "1.6"
LinearAlgebra = "1.6"
MKL = "0.6"
PrecompileTools = "1"
Preferences = "1"
Expand All @@ -96,6 +97,7 @@ julia = "1.9"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
Expand All @@ -114,4 +116,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "EnzymeTestUtils", "FiniteDiff", "BandedMatrices"]
41 changes: 41 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,47 @@ using Enzyme

using EnzymeCore

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@assert !(prob isa Const)
res = func.val(prob.val, alg.val; kwargs...)
if RT <: Const
return res

Check warning on line 16 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L12-L16

Added lines #L12 - L16 were not covered by tests
end
dres = func.val(prob.dval, alg.val; kwargs...)
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
return Duplicated(res, dres)

Check warning on line 24 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L18-L24

Added lines #L18 - L24 were not covered by tests
end
end

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
@assert !(linsolve isa Const)

Check warning on line 29 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L28-L29

Added lines #L28 - L29 were not covered by tests

A = deepcopy(linsolve.val.A) #mutates after function is applied
res = func.val(linsolve.val; kwargs...)

Check warning on line 32 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L31-L32

Added lines #L31 - L32 were not covered by tests

if RT <: Const
return res

Check warning on line 35 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
end

dres = deepcopy(res)
invA = inv(A)
db = linsolve.dval.b
dA = linsolve.dval.A
dres.u .= invA * (db - dA * res.u)

Check warning on line 42 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L38-L42

Added lines #L38 - L42 were not covered by tests

if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
return Duplicated(res, dres)

Check warning on line 47 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L44-L47

Added lines #L44 - L47 were not covered by tests
end

return Duplicated(res, dres)

Check warning on line 50 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L50

Added line #L50 was not covered by tests
end

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
Expand Down
58 changes: 57 additions & 1 deletion test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff

n = 4
A = rand(n, n);
Expand Down Expand Up @@ -161,4 +162,59 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
@test dA ≈ dA2 atol=5e-5
@test db1 ≈ db12
@test db2 ≈ db22
=#
=#


A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
function fb(b; alg = LUFactorization())
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fb(b1)

manual_jac = map(onehot(b1)) do db
y = A \ b1
sum(inv(A) * (db - dA*y))
end |> collect
@show manual_jac

fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fd_jac

en_jac = map(onehot(b1)) do db1
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
eres[1]
end |> collect
@show en_jac

@test_broken en_jac manual_jac
@test_broken en_jac fd_jac

function fA(A; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fA(A)

manual_jac = map(onehot(A)) do dA
y = A \ b1
sum(inv(A) * (db1 - dA*y))
end |> collect

fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec

en_jac = map(onehot(A)) do dA
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
eres[1]
end |> collect

@test_broken en_jac manual_jac
@test_broken en_jac fd_jac

0 comments on commit e89be05

Please sign in to comment.