From bb6d6234a3d8c7bad03b36b33c89ed567b008111 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 21 Sep 2023 20:42:18 -0500 Subject: [PATCH 01/20] Add Enzyme extension --- Project.toml | 3 +++ src/init.jl | 3 +++ 2 files changed, 6 insertions(+) diff --git a/Project.toml b/Project.toml index 2cabbc4f3..f56259229 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -30,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" @@ -42,6 +44,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" [extensions] LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" +LinearSolveEnzymeExt = "Enzyme" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" diff --git a/src/init.jl b/src/init.jl index 2dccda626..360a2c86e 100644 --- a/src/init.jl +++ b/src/init.jl @@ -15,5 +15,8 @@ function __init__() @require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin include("../ext/LinearSolveMKLExt.jl") end + @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/LinearSolveEnzymeExt.jl") + end end end From 9f8d18fbb6a13fbb592579c7b809b1ff14acc282 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 21 Sep 2023 20:49:13 -0500 Subject: [PATCH 02/20] Add actual file --- ext/LinearSolveEnzymeExt.jl | 38 +++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 ext/LinearSolveEnzymeExt.jl diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl new file mode 100644 index 000000000..f38cf56e2 --- /dev/null +++ b/ext/LinearSolveEnzymeExt.jl @@ -0,0 +1,38 @@ +module LinearSolveEnzymeExt + +using LinearSolve +isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) + + +using Enzyme + +using EnzymeCore + +# y=inv(A) B +# dA −= z y^T +# dB += z, where z = inv(A^T) dy +function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem} + res = func.val(prob.val, alg.val; kwargs...) + dres = deepcopy(res) + dres.u .= 0 + cache = (copy(prob.val.A), res, dres.u) + return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem} + A, y, dy = cache + + dA = prob.dval.A + db = prob.dval.b + + invprob = LinearProblem(transpose(A), dy) + + z = func.val(invprob, alg; kwargs...) + + dA .-= z * transpose(y) + db .+= z + dy .= 0 + return (nothing, nothing) +end + +end \ No newline at end of file From 391b6029d96bfe0d1c49185212e81ce61b918501 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 22 Sep 2023 10:18:44 -0400 Subject: [PATCH 03/20] fix typo --- test/basictests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basictests.jl b/test/basictests.jl index 42f283173..888f6322e 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -202,7 +202,7 @@ end end end - test_algs = if VERISON >= v"1.9" + test_algs = if VERSION >= v"1.9" (LUFactorization(), QRFactorization(), SVDFactorization(), From ce7ffc0b9db0fa423464f680ee918cf18fffc4f3 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 22 Sep 2023 10:39:47 -0400 Subject: [PATCH 04/20] more v1.9 --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 036bcf97e..a28cc0e7f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,7 @@ const HAS_EXTENSIONS = isdefined(Base, :get_extension) if GROUP == "All" || GROUP == "Core" @time @safetestset "Basic Tests" include("basictests.jl") - @time @safetestset "Re-solve" include("resolve.jl") + VERSION >= v"1.9" && @time @safetestset "Re-solve" include("resolve.jl") @time @safetestset "Zero Initialization Tests" include("zeroinittests.jl") @time @safetestset "Non-Square Tests" include("nonsquare.jl") @time @safetestset "SparseVector b Tests" include("sparse_vector.jl") From a08386deee59135da2b973c4228cc2e0be9379fb Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 22 Sep 2023 11:05:44 -0400 Subject: [PATCH 05/20] add a test for Enzyme rule correctness --- Project.toml | 4 +++- test/enzyme.jl | 30 ++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 test/enzyme.jl diff --git a/Project.toml b/Project.toml index f56259229..a9c2f44b2 100644 --- a/Project.toml +++ b/Project.toml @@ -81,6 +81,8 @@ julia = "1.6" [extras] BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -98,4 +100,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"] +test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"] diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 000000000..8f6d213c0 --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,30 @@ +using Enzyme, FiniteDiff +using LinearSolve, LinearAlgebra, Test + +n = 4 +A = rand(n, n); +dA = zeros(n, n); +b1 = rand(n); +db1 = zeros(n); +b2 = rand(n); +db2 = zeros(n); + +function f(A, b1, b2; alg = LUFactorization()) + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + s1 = sol1.u + norm(s1) +end + +f(A, b1, b2) # Uses BLAS + +Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) + +dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A)) +db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1)) + +@test dA ≈ dA2 +@test db1 ≈ db12 +@test db2 == zeros(4) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index a28cc0e7f..4f2e78feb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ if GROUP == "All" || GROUP == "Core" @time @safetestset "Non-Square Tests" include("nonsquare.jl") @time @safetestset "SparseVector b Tests" include("sparse_vector.jl") @time @safetestset "Default Alg Tests" include("default_algs.jl") + VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") @time @safetestset "Traits" include("traits.jl") end From 9273a2089f777b0e72b8fbefeb658d18d6609d94 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 22 Sep 2023 16:34:45 -0500 Subject: [PATCH 06/20] Extend --- ext/LinearSolveEnzymeExt.jl | 85 ++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 16 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index f38cf56e2..b87f45c85 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -8,31 +8,84 @@ using Enzyme using EnzymeCore +function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} + res = func.val(prob.val, alg.val; kwargs...) + dres = if EnzymeRules.width(config) == 1 + func.val(prob.dval, alg.val; kwargs...) + else + (func.val(dval, alg.val; kwargs...) for dval in prob.dval) + end + return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} + return (nothing, nothing) +end + # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem} - res = func.val(prob.val, alg.val; kwargs...) - dres = deepcopy(res) - dres.u .= 0 - cache = (copy(prob.val.A), res, dres.u) - return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache) +function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} + res = func.val(linsolve.val; kwargs...) + dres = if EnzymeRules.width(config) == 1 + deepcopy(res) + else + (deepcopy(res) for dval in linsolve.dval) + end + + if EnzymeRules.width(config) == 1 + dres.u .= 0 + else + for dr in dres + dr.u .= 0 + end + end + + resvals = if EnzymeRules.width(config) == 1 + dres.u + else + (dr.u for dr in dres) + end + + cache = (copy(linsolve.val.A), res, resvals) + return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end -function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem} - A, y, dy = cache +function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} + A, y, dys = cache - dA = prob.dval.A - db = prob.dval.b + @assert !(typeof(linsolve) <: Const) + @assert !(typeof(linsolve) <: Active) - invprob = LinearProblem(transpose(A), dy) + if EnzymeRules.width(config) == 1 + dys = (dys,) + end - z = func.val(invprob, alg; kwargs...) + dAs = if EnzymeRules.width(config) == 1 + (linsolve.dval.A,) + else + (dval.A for dval in linsolve.dval) + end - dA .-= z * transpose(y) - db .+= z - dy .= 0 - return (nothing, nothing) + dbs = if EnzymeRules.width(config) == 1 + (linsolve.dval.b,) + else + (dval.b for dval in linsolve.dval) + end + + for (dA, db, dy) in zip(dAs, dbs, dys) + invprob = LinearSolve.LinearProblem(transpose(A), dy) + z = solve(invprob; + abstol = linsolve.val.abstol, + reltol = linsolve.val.reltol, + verbose = linsolve.val.verbose) + + dA .-= z * transpose(y) + db .+= z + dy .= eltype(dy)(0) + end + + return (nothing,) end end \ No newline at end of file From 84c5196d9c49ada2ac69f34e702300d7ce1d498e Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 22 Sep 2023 19:40:44 -0400 Subject: [PATCH 07/20] add some batch tests --- test/enzyme.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index 8f6d213c0..dbacb70f1 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -27,4 +27,18 @@ db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1)) @test dA ≈ dA2 @test db1 ≈ db12 -@test db2 == zeros(4) \ No newline at end of file +@test db2 == zeros(4) + +A = rand(n, n); +dA = zeros(n, n); +dA2 = zeros(n, n); +b1 = rand(n); +db1 = zeros(n); +db12 = zeros(n); + +b2 = rand(n); +db2 = zeros(n); +db22 = zeros(n); + +@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)), BatchDuplicated(copy(b2), (db2, db22))) +@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) \ No newline at end of file From bb93d68523b3111074e352e2fa15306739f4730d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 22 Sep 2023 23:41:49 -0500 Subject: [PATCH 08/20] Fix --- ext/LinearSolveEnzymeExt.jl | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index b87f45c85..53cc7e43a 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -15,10 +15,44 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line else (func.val(dval, alg.val; kwargs...) for dval in prob.dval) end - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing) + d_A = if EnzymeRules.width(config) == 1 + dres.A + else + (dval.A for dval in dres) + end + d_b = if EnzymeRules.width(config) == 1 + dres.b + else + (dval.b for dval in dres) + end + return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b)) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} + d_A, d_b = cache + + if EnzymeRules.width(config) == 1 + if d_A !== prob.dval.A + prob.dval.A .+= d_A + d_A .= 0 + end + if d_b !== prob.dval.b + prob.dval.b .+= d_b + d_b .= 0 + end + else + for i in 1:EnzymeRules.width(config) + if d_A !== prob.dval.A + prob.dval.A[i] .+= d_A[i] + d_A[i] .= 0 + end + if d_b !== prob.dval.b + prob.dval.b[i] .+= d_b[i] + d_b[i] .= 0 + end + end + end + return (nothing, nothing) end From 9d19db22010178e4a5d77df5149847a802d94466 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 23 Sep 2023 14:51:26 -0500 Subject: [PATCH 09/20] Cache before LU in place --- ext/LinearSolveEnzymeExt.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 53cc7e43a..8738bf141 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -60,7 +60,11 @@ end # dA −= z y^T # dB += z, where z = inv(A^T) dy function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} + @assert linsolve.val.isfresh + A_cache = copy(linsolve.val.A) + res = func.val(linsolve.val; kwargs...) + dres = if EnzymeRules.width(config) == 1 deepcopy(res) else @@ -81,7 +85,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dr.u for dr in dres) end - cache = (copy(linsolve.val.A), res, resvals) + cache = (A_cache, res, resvals) return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end From f9b078469cf0efd08e37b6cdd2e6caed8eff85bc Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 23 Sep 2023 14:54:11 -0500 Subject: [PATCH 10/20] simplify test --- test/enzyme.jl | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index dbacb70f1..e00553d69 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -6,10 +6,8 @@ A = rand(n, n); dA = zeros(n, n); b1 = rand(n); db1 = zeros(n); -b2 = rand(n); -db2 = zeros(n); -function f(A, b1, b2; alg = LUFactorization()) +function f(A, b1; alg = LUFactorization()) prob = LinearProblem(A, b1) sol1 = solve(prob, alg) @@ -18,16 +16,15 @@ function f(A, b1, b2; alg = LUFactorization()) norm(s1) end -f(A, b1, b2) # Uses BLAS +f(A, b1) # Uses BLAS -Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) +Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A)) db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1)) @test dA ≈ dA2 @test db1 ≈ db12 -@test db2 == zeros(4) A = rand(n, n); dA = zeros(n, n); @@ -36,9 +33,6 @@ b1 = rand(n); db1 = zeros(n); db12 = zeros(n); -b2 = rand(n); -db2 = zeros(n); -db22 = zeros(n); - -@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)), BatchDuplicated(copy(b2), (db2, db22))) -@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) \ No newline at end of file +# This is not legal, all args need to be batch'd at the same size +@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12))) +@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1)) From 3b39753c0606d7e1c0962ef1e1e47f0ebae33cf5 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 23 Sep 2023 22:29:40 -0400 Subject: [PATCH 11/20] fix multiple solve handling --- ext/LinearSolveEnzymeExt.jl | 28 +++++++++++++++--------- test/enzyme.jl | 43 ++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 8738bf141..45e9231d1 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -1,6 +1,7 @@ module LinearSolveEnzymeExt using LinearSolve +using LinearSolve.LinearAlgebra isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) @@ -60,9 +61,6 @@ end # dA −= z y^T # dB += z, where z = inv(A^T) dy function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} - @assert linsolve.val.isfresh - A_cache = copy(linsolve.val.A) - res = func.val(linsolve.val; kwargs...) dres = if EnzymeRules.width(config) == 1 @@ -85,12 +83,12 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dr.u for dr in dres) end - cache = (A_cache, res, resvals) + cache = (res, resvals, linsolve.val) return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} - A, y, dys = cache + y, dys, _linsolve = cache @assert !(typeof(linsolve) <: Const) @assert !(typeof(linsolve) <: Active) @@ -112,11 +110,21 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s end for (dA, db, dy) in zip(dAs, dbs, dys) - invprob = LinearSolve.LinearProblem(transpose(A), dy) - z = solve(invprob; - abstol = linsolve.val.abstol, - reltol = linsolve.val.reltol, - verbose = linsolve.val.verbose) + z = if linsolve.cacheval isa Factorization + linsolve.cacheval' \ dy + elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization + linsolve.cacheval[1]' \ dy + elseif linsolve.alg isa AbstractKrylovSubspaceMethod + # Doesn't modify `A`, so it's safe to just reuse it + invprob = LinearSolve.LinearProblem(transpose(linsolve.A), dy) + solve(invprob; + abstol = linsolve.val.abstol, + reltol = linsolve.val.reltol, + verbose = linsolve.val.verbose, + isfresh = freshbefore) + else + error("Algorithm $(linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") + end dA .-= z * transpose(y) db .+= z diff --git a/test/enzyme.jl b/test/enzyme.jl index e00553d69..ab651c508 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -20,8 +20,8 @@ f(A, b1) # Uses BLAS Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) -dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A)) -db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1)) +dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A)) +db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) @test dA ≈ dA2 @test db1 ≈ db12 @@ -33,6 +33,39 @@ b1 = rand(n); db1 = zeros(n); db12 = zeros(n); -# This is not legal, all args need to be batch'd at the same size -@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12))) -@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1)) +@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12))) + +dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A)) +db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) + +@test_broken dA ≈ dA_2 +@test_broken dA2 ≈ dA_2 +@test_broken db1 ≈ db1_2 +@test_broken db12 ≈ db1_2 + +function f(A, b1, b2; alg = LUFactorization()) + prob = LinearProblem(A, b1) + + cache = init(prob, alg) + s1 = solve!(cache).u + cache.b = b2 + s2 = solve!(cache).u + norm(s1 + s2) +end + +A = rand(n, n); +dA = zeros(n, n); +b1 = rand(n); +db1 = zeros(n); +b2 = rand(n); +db2 = zeros(n); + +Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) + +dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A)) +db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1)) +db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2)) + +@test dA ≈ dA2 atol=5e-5 +@test db1 ≈ db12 +@test db2 ≈ db22 \ No newline at end of file From cbb5f1d1b1bb57c6f8d594670dc552f72de4c5fe Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 23 Sep 2023 22:30:51 -0400 Subject: [PATCH 12/20] fix multiple solve handling --- ext/LinearSolveEnzymeExt.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 45e9231d1..c3e474c71 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -83,12 +83,13 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dr.u for dr in dres) end - cache = (res, resvals, linsolve.val) + cache = (res, resvals) return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} - y, dys, _linsolve = cache + y, dys = cache + _linsolve = linsolve.val @assert !(typeof(linsolve) <: Const) @assert !(typeof(linsolve) <: Active) @@ -110,8 +111,8 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s end for (dA, db, dy) in zip(dAs, dbs, dys) - z = if linsolve.cacheval isa Factorization - linsolve.cacheval' \ dy + z = if _linsolve.cacheval isa Factorization + _linsolve.cacheval' \ dy elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization linsolve.cacheval[1]' \ dy elseif linsolve.alg isa AbstractKrylovSubspaceMethod From 9630121f955d1e0d32b14e28196a5ff2e849168a Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 23 Sep 2023 22:32:14 -0400 Subject: [PATCH 13/20] fix other algorithms --- ext/LinearSolveEnzymeExt.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index c3e474c71..7f0c255f0 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -114,17 +114,16 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s z = if _linsolve.cacheval isa Factorization _linsolve.cacheval' \ dy elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization - linsolve.cacheval[1]' \ dy + _linsolve.cacheval[1]' \ dy elseif linsolve.alg isa AbstractKrylovSubspaceMethod # Doesn't modify `A`, so it's safe to just reuse it - invprob = LinearSolve.LinearProblem(transpose(linsolve.A), dy) + invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy) solve(invprob; - abstol = linsolve.val.abstol, - reltol = linsolve.val.reltol, - verbose = linsolve.val.verbose, - isfresh = freshbefore) + abstol = _linsolve.val.abstol, + reltol = _linsolve.val.reltol, + verbose = _linsolve.val.verbose) else - error("Algorithm $(linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") + error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") end dA .-= z * transpose(y) From b0d228d43def8916caa1f1704c603b63ac812a76 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 23 Sep 2023 23:00:05 -0400 Subject: [PATCH 14/20] getting very close --- ext/LinearSolveEnzymeExt.jl | 9 +++---- test/enzyme.jl | 54 ++++++++++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 7f0c255f0..4149ee942 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -83,13 +83,12 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dr.u for dr in dres) end - cache = (res, resvals) + cache = (res, resvals, deepcopy(linsolve.val)) return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} - y, dys = cache - _linsolve = linsolve.val + y, dys, _linsolve = cache @assert !(typeof(linsolve) <: Const) @assert !(typeof(linsolve) <: Active) @@ -113,9 +112,9 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s for (dA, db, dy) in zip(dAs, dbs, dys) z = if _linsolve.cacheval isa Factorization _linsolve.cacheval' \ dy - elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization + elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization _linsolve.cacheval[1]' \ dy - elseif linsolve.alg isa AbstractKrylovSubspaceMethod + elseif _linsolve.alg isa AbstractKrylovSubspaceMethod # Doesn't modify `A`, so it's safe to just reuse it invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy) solve(invprob; diff --git a/test/enzyme.jl b/test/enzyme.jl index ab651c508..1f2967913 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,4 +1,4 @@ -using Enzyme, FiniteDiff +using Enzyme, ForwardDiff using LinearSolve, LinearAlgebra, Test n = 4 @@ -20,8 +20,8 @@ f(A, b1) # Uses BLAS Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) -dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A)) -db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) +dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) @test dA ≈ dA2 @test db1 ≈ db12 @@ -35,8 +35,8 @@ db12 = zeros(n); @test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12))) -dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A)) -db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) +dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) @test_broken dA ≈ dA_2 @test_broken dA2 ≈ dA_2 @@ -45,9 +45,8 @@ db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) function f(A, b1, b2; alg = LUFactorization()) prob = LinearProblem(A, b1) - cache = init(prob, alg) - s1 = solve!(cache).u + s1 = copy(solve!(cache).u) cache.b = b2 s2 = solve!(cache).u norm(s1 + s2) @@ -60,11 +59,46 @@ db1 = zeros(n); b2 = rand(n); db2 = zeros(n); +f(A, b1, b2) Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) -dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A)) -db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1)) -db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2)) +dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1),eltype(x).(b2)), copy(A)) +db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x,eltype(x).(b2)), copy(b1)) +db22 = ForwardDiff.gradient(x->f(eltype(x).(A),eltype(x).(b1),x), copy(b2)) + +@test dA ≈ dA2 +@test db1 ≈ db12 +@test db2 ≈ db22 + +function f2(A, b1, b2; alg = RFLUFactorization()) + prob = LinearProblem(A, b1) + cache = init(prob, alg) + s1 = copy(solve!(cache).u) + cache.b = b2 + s2 = solve!(cache).u + norm(s1 + s2) +end + +f2(A, b1, b2) +dA = zeros(n, n); +db1 = zeros(n); +db2 = zeros(n); +Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) + +@test dA ≈ dA2 +@test db1 ≈ db12 +@test db2 ≈ db22 + +function f3(A, b1, b2; alg = KrylovJL_GMRES()) + prob = LinearProblem(A, b1) + cache = init(prob, alg) + s1 = solve!(cache).u + cache.b = b2 + s2 = solve!(cache).u + norm(s1 + s2) +end + +Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 From c2ad2dbc7f8b6a7820c7f2d5d82d876addf6bf57 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 23 Sep 2023 23:06:33 -0400 Subject: [PATCH 15/20] push batch test updates --- test/enzyme.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index 1f2967913..46dbba438 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -33,7 +33,20 @@ b1 = rand(n); db1 = zeros(n); db12 = zeros(n); -@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12))) +function fbatch(y, A, b1; alg = LUFactorization()) + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + s1 = sol1.u + y[1] = norm(s1) + nothing +end + +y = [0.0] +dy1 = [1.0] +dy2 = [1.0] +Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12))) dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) @@ -92,7 +105,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), function f3(A, b1, b2; alg = KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) - s1 = solve!(cache).u + s1 = copy(solve!(cache).u) cache.b = b2 s2 = solve!(cache).u norm(s1 + s2) From 54f0722360ac07231dadca7d1edd9e541b972bf0 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 23 Sep 2023 22:12:52 -0500 Subject: [PATCH 16/20] type stable --- ext/LinearSolveEnzymeExt.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 4149ee942..b4bddf371 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -14,7 +14,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line dres = if EnzymeRules.width(config) == 1 func.val(prob.dval, alg.val; kwargs...) else - (func.val(dval, alg.val; kwargs...) for dval in prob.dval) + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + func.val(prob.dval[i], alg.val; kwargs...) + end end d_A = if EnzymeRules.width(config) == 1 dres.A @@ -66,7 +69,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line dres = if EnzymeRules.width(config) == 1 deepcopy(res) else - (deepcopy(res) for dval in linsolve.dval) + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + deepcopy(res) + end end if EnzymeRules.width(config) == 1 From e4f0785cf8cbc7cfb5d56184cbd0099dfd596005 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 23 Sep 2023 23:26:18 -0500 Subject: [PATCH 17/20] fix mutated db --- ext/LinearSolveEnzymeExt.jl | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index b4bddf371..9db9a1770 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -29,6 +29,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line else (dval.b for dval in dres) end + return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b)) end @@ -89,20 +90,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dr.u for dr in dres) end - cache = (res, resvals, deepcopy(linsolve.val)) - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} - y, dys, _linsolve = cache - - @assert !(typeof(linsolve) <: Const) - @assert !(typeof(linsolve) <: Active) - - if EnzymeRules.width(config) == 1 - dys = (dys,) - end - dAs = if EnzymeRules.width(config) == 1 (linsolve.dval.A,) else @@ -115,6 +102,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s (dval.b for dval in linsolve.dval) end + cache = (res, resvals, deepcopy(linsolve.val), dAs, dbs) + return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} + y, dys, _linsolve, dAs, dbs = cache + + @assert !(typeof(linsolve) <: Const) + @assert !(typeof(linsolve) <: Active) + + if EnzymeRules.width(config) == 1 + dys = (dys,) + end + for (dA, db, dy) in zip(dAs, dbs, dys) z = if _linsolve.cacheval isa Factorization _linsolve.cacheval' \ dy From d69af77971e639b47ff34bba1849efa949842a67 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 03:39:46 -0500 Subject: [PATCH 18/20] More caching --- ext/LinearSolveEnzymeExt.jl | 51 +++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 9db9a1770..607189e10 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -30,29 +30,41 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dval.b for dval in dres) end - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b)) + + prob_d_A = if EnzymeRules.width(config) == 1 + prob.dval.A + else + (dval.A for dval in prob.dval) + end + prob_d_b = if EnzymeRules.width(config) == 1 + prob.dval.b + else + (dval.b for dval in prob.dval) + end + + return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b)) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} - d_A, d_b = cache + d_A, d_b, prob_d_A, prob_d_b = cache if EnzymeRules.width(config) == 1 - if d_A !== prob.dval.A - prob.dval.A .+= d_A + if d_A !== prob_d_A + prob_d_A .+= d_A d_A .= 0 end - if d_b !== prob.dval.b - prob.dval.b .+= d_b + if d_b !== prob_d_b + prob_d_b .+= d_b d_b .= 0 end else for i in 1:EnzymeRules.width(config) - if d_A !== prob.dval.A - prob.dval.A[i] .+= d_A[i] + if d_A !== prob_d_A[i] + prob_d_A[i] .+= d_A[i] d_A[i] .= 0 end - if d_b !== prob.dval.b - prob.dval.b[i] .+= d_b[i] + if d_b !== prob_d_b[i] + prob_d_b[i] .+= d_b[i] d_b[i] .= 0 end end @@ -87,22 +99,33 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line resvals = if EnzymeRules.width(config) == 1 dres.u else - (dr.u for dr in dres) + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + dres[i].u + end end dAs = if EnzymeRules.width(config) == 1 (linsolve.dval.A,) else - (dval.A for dval in linsolve.dval) + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + linsolve.dval[i].A + end end dbs = if EnzymeRules.width(config) == 1 (linsolve.dval.b,) else - (dval.b for dval in linsolve.dval) + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + linsolve.dval[i].b + end end - cache = (res, resvals, deepcopy(linsolve.val), dAs, dbs) + cachesolve = deepcopy(linsolve.val) + + cache = (copy(res.u), resvals, cachesolve, dAs, dbs) return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end From be91ba2fb0011e62381d02288d6b4c44a32e8105 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 24 Sep 2023 08:24:34 -0400 Subject: [PATCH 19/20] Remove batch test --- test/enzyme.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/enzyme.jl b/test/enzyme.jl index 46dbba438..1e72b0678 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -33,6 +33,10 @@ b1 = rand(n); db1 = zeros(n); db12 = zeros(n); +#= +# Batch test fails +# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075 + function fbatch(y, A, b1; alg = LUFactorization()) prob = LinearProblem(A, b1) @@ -55,6 +59,7 @@ db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) @test_broken dA2 ≈ dA_2 @test_broken db1 ≈ db1_2 @test_broken db12 ≈ db1_2 +=# function f(A, b1, b2; alg = LUFactorization()) prob = LinearProblem(A, b1) From 89e10dfccd576493f97985e9aab4d518ad992645 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 24 Sep 2023 08:40:15 -0400 Subject: [PATCH 20/20] Remove Krylov test for now --- test/enzyme.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index 1e72b0678..62904c055 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -107,6 +107,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), @test db1 ≈ db12 @test db2 ≈ db22 +#= function f3(A, b1, b2; alg = KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) @@ -120,4 +121,5 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 -@test db2 ≈ db22 \ No newline at end of file +@test db2 ≈ db22 +=# \ No newline at end of file