From d5e64bd9389232ab89e17fce9ded293772231f62 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 5 Nov 2023 23:45:18 +0530 Subject: [PATCH 01/11] Add forward enzyme rules for init and solve --- Project.toml | 4 ++- ext/LinearSolveEnzymeExt.jl | 41 ++++++++++++++++++++++++++ test/enzyme.jl | 58 ++++++++++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index e1b0e4204..f11951a80 100644 --- a/Project.toml +++ b/Project.toml @@ -66,6 +66,7 @@ DocStringExtensions = "0.9" EnumX = "1" EnzymeCore = "0.6" FastLapackInterface = "2" +EnzymeTestUtils = "0.1" GPUArraysCore = "0.1" HYPRE = "1.4.0" InteractiveUtils = "1.6" @@ -96,6 +97,7 @@ julia = "1.9" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" @@ -114,4 +116,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices"] +test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "EnzymeTestUtils", "FiniteDiff", "BandedMatrices"] diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 26d530d19..ba6b4dd9f 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -9,6 +9,47 @@ using Enzyme using EnzymeCore +function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} + @assert !(prob isa Const) + res = func.val(prob.val, alg.val; kwargs...) + if RT <: Const + return res + end + dres = func.val(prob.dval, alg.val; kwargs...) + dres.b .= res.b == dres.b ? zero(dres.b) : dres.b + dres.A .= res.A == dres.A ? zero(dres.A) : dres.A + if RT <: DuplicatedNoNeed + return dres + elseif RT <: Duplicated + return Duplicated(res, dres) + end +end + +function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} + @assert !(linsolve isa Const) + + A = deepcopy(linsolve.val.A) #mutates after function is applied + res = func.val(linsolve.val; kwargs...) + + if RT <: Const + return res + end + + dres = deepcopy(res) + invA = inv(A) + db = linsolve.dval.b + dA = linsolve.dval.A + dres.u .= invA * (db - dA * res.u) + + if RT <: DuplicatedNoNeed + return dres + elseif RT <: Duplicated + return Duplicated(res, dres) + end + + return Duplicated(res, dres) +end + function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} res = func.val(prob.val, alg.val; kwargs...) dres = if EnzymeRules.width(config) == 1 diff --git a/test/enzyme.jl b/test/enzyme.jl index 02e071d41..470d4a3a2 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,5 +1,6 @@ using Enzyme, ForwardDiff using LinearSolve, LinearAlgebra, Test +using FiniteDiff n = 4 A = rand(n, n); @@ -161,4 +162,59 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 @test db2 ≈ db22 -=# \ No newline at end of file +=# + + +A = rand(n, n); +dA = zeros(n, n); +b1 = rand(n); +function fb(b; alg = LUFactorization()) + prob = LinearProblem(A, b) + + sol1 = solve(prob, alg) + + sum(sol1.u) +end +fb(b1) + +manual_jac = map(onehot(b1)) do db + y = A \ b1 + sum(inv(A) * (db - dA*y)) +end |> collect +@show manual_jac + +fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec +@show fd_jac + +en_jac = map(onehot(b1)) do db1 + eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1)) + eres[1] +end |> collect +@show en_jac + +@test_broken en_jac ≈ manual_jac +@test_broken en_jac ≈ fd_jac + +function fA(A; alg = LUFactorization()) + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + sum(sol1.u) +end +fA(A) + +manual_jac = map(onehot(A)) do dA + y = A \ b1 + sum(inv(A) * (db1 - dA*y)) +end |> collect + +fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec + +en_jac = map(onehot(A)) do dA + eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA)) + eres[1] +end |> collect + +@test_broken en_jac ≈ manual_jac +@test_broken en_jac ≈ fd_jac \ No newline at end of file From 6cfe69a5fae092e1d82970df33ac6916ed3afcf2 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 5 Nov 2023 23:47:17 +0530 Subject: [PATCH 02/11] Remove EnzymeTestUtils as it is not used --- Project.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f11951a80..65b9d873e 100644 --- a/Project.toml +++ b/Project.toml @@ -97,7 +97,6 @@ julia = "1.9" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" @@ -116,4 +115,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "EnzymeTestUtils", "FiniteDiff", "BandedMatrices"] +test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices"] From f7b7c14cf35cf7234f081a520ba1d1b1bd6c608c Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 5 Nov 2023 23:59:54 +0530 Subject: [PATCH 03/11] Show computed jacobians in tests --- test/enzyme.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/enzyme.jl b/test/enzyme.jl index 470d4a3a2..d31a0fec4 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -208,13 +208,16 @@ manual_jac = map(onehot(A)) do dA y = A \ b1 sum(inv(A) * (db1 - dA*y)) end |> collect +@show manual_jac fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec +@show fd_jac en_jac = map(onehot(A)) do dA eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA)) eres[1] end |> collect +@show en_jac @test_broken en_jac ≈ manual_jac @test_broken en_jac ≈ fd_jac \ No newline at end of file From 8b8f8c434991cd129be48c04faa6f3970f4fb8bf Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 6 Nov 2023 01:06:23 +0530 Subject: [PATCH 04/11] Avoid inv --- ext/LinearSolveEnzymeExt.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index ba6b4dd9f..03d687c94 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -36,10 +36,9 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, end dres = deepcopy(res) - invA = inv(A) db = linsolve.dval.b dA = linsolve.dval.A - dres.u .= invA * (db - dA * res.u) + dres.u .= A \ (db - dA * res.u) if RT <: DuplicatedNoNeed return dres From 447fb61b4aee254918b84455d01b527f9658ab1b Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 6 Nov 2023 01:31:01 +0530 Subject: [PATCH 05/11] Avoid refactorization and cleanup tests to allow for numerical errors due to summation ordering --- ext/LinearSolveEnzymeExt.jl | 13 +++++++++---- test/enzyme.jl | 18 ++---------------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 03d687c94..6194d6230 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -28,17 +28,22 @@ end function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @assert !(linsolve isa Const) - A = deepcopy(linsolve.val.A) #mutates after function is applied + linsolve = deepcopy(linsolve) #mutates after function is applied res = func.val(linsolve.val; kwargs...) if RT <: Const return res end - - dres = deepcopy(res) + + b = deepcopy(linsolve.val.b) + db = linsolve.dval.b dA = linsolve.dval.A - dres.u .= A \ (db - dA * res.u) + + linsolve.val.b = db - dA * res.u + dres = func.val(linsolve.val; kwargs...) + + linsolve.val.b = b if RT <: DuplicatedNoNeed return dres diff --git a/test/enzyme.jl b/test/enzyme.jl index d31a0fec4..334399daa 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -177,12 +177,6 @@ function fb(b; alg = LUFactorization()) end fb(b1) -manual_jac = map(onehot(b1)) do db - y = A \ b1 - sum(inv(A) * (db - dA*y)) -end |> collect -@show manual_jac - fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec @show fd_jac @@ -192,8 +186,7 @@ en_jac = map(onehot(b1)) do db1 end |> collect @show en_jac -@test_broken en_jac ≈ manual_jac -@test_broken en_jac ≈ fd_jac +@test en_jac ≈ fd_jac atol=1e-6 function fA(A; alg = LUFactorization()) prob = LinearProblem(A, b1) @@ -204,12 +197,6 @@ function fA(A; alg = LUFactorization()) end fA(A) -manual_jac = map(onehot(A)) do dA - y = A \ b1 - sum(inv(A) * (db1 - dA*y)) -end |> collect -@show manual_jac - fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec @show fd_jac @@ -219,5 +206,4 @@ en_jac = map(onehot(A)) do dA end |> collect @show en_jac -@test_broken en_jac ≈ manual_jac -@test_broken en_jac ≈ fd_jac \ No newline at end of file +@test en_jac ≈ fd_jac atol=1e-6 \ No newline at end of file From 59edb4e8012eb3854acea5df04aae6d28588e23f Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 6 Nov 2023 01:33:04 +0530 Subject: [PATCH 06/11] Remove unneeded deepcopy --- ext/LinearSolveEnzymeExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 6194d6230..0196886de 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -28,7 +28,6 @@ end function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @assert !(linsolve isa Const) - linsolve = deepcopy(linsolve) #mutates after function is applied res = func.val(linsolve.val; kwargs...) if RT <: Const From 60afb5d57cb96a23fbedf793d02fa37d6740dfd7 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 6 Nov 2023 22:47:12 +0530 Subject: [PATCH 07/11] Switch to reltol from abstol for tests --- test/enzyme.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index 334399daa..e71c63357 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -186,7 +186,7 @@ en_jac = map(onehot(b1)) do db1 end |> collect @show en_jac -@test en_jac ≈ fd_jac atol=1e-6 +@test en_jac ≈ fd_jac rtol=1e-6 function fA(A; alg = LUFactorization()) prob = LinearProblem(A, b1) @@ -206,4 +206,4 @@ en_jac = map(onehot(A)) do dA end |> collect @show en_jac -@test en_jac ≈ fd_jac atol=1e-6 \ No newline at end of file +@test en_jac ≈ fd_jac rtol=1e-6 \ No newline at end of file From c14aeb20c2160a9875509ff729ef56494a9d6917 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 6 Nov 2023 22:59:54 +0530 Subject: [PATCH 08/11] Add tests for other algs and handle cases of algs currently not supported --- ext/LinearSolveEnzymeExt.jl | 4 ++- test/enzyme.jl | 66 +++++++++++++++++++++---------------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 0196886de..4075f940b 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -33,7 +33,9 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, if RT <: Const return res end - + if linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod + error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") + end b = deepcopy(linsolve.val.b) db = linsolve.dval.b diff --git a/test/enzyme.jl b/test/enzyme.jl index e71c63357..96c4d28d7 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,6 +1,7 @@ using Enzyme, ForwardDiff using LinearSolve, LinearAlgebra, Test using FiniteDiff +using SafeTestsets n = 4 A = rand(n, n); @@ -164,46 +165,53 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), @test db2 ≈ db22 =# - A = rand(n, n); dA = zeros(n, n); b1 = rand(n); -function fb(b; alg = LUFactorization()) - prob = LinearProblem(A, b) +for alg in ( + LUFactorization(), + RFLUFactorization(), + # KrylovJL_GMRES(), fails + ) + alg_str = string(alg) + @show alg_str + function fb(b) + prob = LinearProblem(A, b) - sol1 = solve(prob, alg) + sol1 = solve(prob, alg) - sum(sol1.u) -end -fb(b1) + sum(sol1.u) + end + fb(b1) -fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec -@show fd_jac + fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec + @show fd_jac -en_jac = map(onehot(b1)) do db1 - eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1)) - eres[1] -end |> collect -@show en_jac + en_jac = map(onehot(b1)) do db1 + eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1)) + eres[1] + end |> collect + @show en_jac -@test en_jac ≈ fd_jac rtol=1e-6 + @test en_jac ≈ fd_jac rtol=1e-6 -function fA(A; alg = LUFactorization()) - prob = LinearProblem(A, b1) + function fA(A) + prob = LinearProblem(A, b1) - sol1 = solve(prob, alg) + sol1 = solve(prob, alg) - sum(sol1.u) -end -fA(A) + sum(sol1.u) + end + fA(A) -fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec -@show fd_jac + fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec + @show fd_jac -en_jac = map(onehot(A)) do dA - eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA)) - eres[1] -end |> collect -@show en_jac + en_jac = map(onehot(A)) do dA + eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA)) + eres[1] + end |> collect + @show en_jac -@test en_jac ≈ fd_jac rtol=1e-6 \ No newline at end of file + @test en_jac ≈ fd_jac rtol=1e-6 +end \ No newline at end of file From 5abd3f28fe4fa5a734ec77dcb1d8384489eae487 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 7 Nov 2023 18:17:02 +0530 Subject: [PATCH 09/11] Error on unsupported return type and relax tolerence on tests to avoid random failures --- ext/LinearSolveEnzymeExt.jl | 1 + test/enzyme.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 4075f940b..2a9f031e7 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -23,6 +23,7 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, : elseif RT <: Duplicated return Duplicated(res, dres) end + error("Unsupported return type $RT") end function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} diff --git a/test/enzyme.jl b/test/enzyme.jl index 96c4d28d7..864fc5283 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -193,7 +193,7 @@ for alg in ( end |> collect @show en_jac - @test en_jac ≈ fd_jac rtol=1e-6 + @test en_jac ≈ fd_jac rtol=1e-4 function fA(A) prob = LinearProblem(A, b1) @@ -213,5 +213,5 @@ for alg in ( end |> collect @show en_jac - @test en_jac ≈ fd_jac rtol=1e-6 + @test en_jac ≈ fd_jac rtol=1e-4 end \ No newline at end of file From 42650531e683203e2525737c6fe374474e5c055c Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 7 Nov 2023 18:40:53 +0530 Subject: [PATCH 10/11] Fix bug in handling Keylov solvers --- ext/LinearSolveEnzymeExt.jl | 2 +- test/enzyme.jl | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 2a9f031e7..304e302d2 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -34,7 +34,7 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, if RT <: Const return res end - if linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod + if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") end b = deepcopy(linsolve.val.b) diff --git a/test/enzyme.jl b/test/enzyme.jl index 864fc5283..89903a858 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -173,8 +173,7 @@ for alg in ( RFLUFactorization(), # KrylovJL_GMRES(), fails ) - alg_str = string(alg) - @show alg_str + @show alg function fb(b) prob = LinearProblem(A, b) From 09359197302aa6ab607de7f189305c4c67ea1b13 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 8 Nov 2023 00:47:52 +0100 Subject: [PATCH 11/11] Update Project.toml --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 65b9d873e..e1b0e4204 100644 --- a/Project.toml +++ b/Project.toml @@ -66,7 +66,6 @@ DocStringExtensions = "0.9" EnumX = "1" EnzymeCore = "0.6" FastLapackInterface = "2" -EnzymeTestUtils = "0.1" GPUArraysCore = "0.1" HYPRE = "1.4.0" InteractiveUtils = "1.6"