diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 7f0c255f0..4149ee942 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -83,13 +83,12 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dr.u for dr in dres) end - cache = (res, resvals) + cache = (res, resvals, deepcopy(linsolve.val)) return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} - y, dys = cache - _linsolve = linsolve.val + y, dys, _linsolve = cache @assert !(typeof(linsolve) <: Const) @assert !(typeof(linsolve) <: Active) @@ -113,9 +112,9 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s for (dA, db, dy) in zip(dAs, dbs, dys) z = if _linsolve.cacheval isa Factorization _linsolve.cacheval' \ dy - elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization + elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization _linsolve.cacheval[1]' \ dy - elseif linsolve.alg isa AbstractKrylovSubspaceMethod + elseif _linsolve.alg isa AbstractKrylovSubspaceMethod # Doesn't modify `A`, so it's safe to just reuse it invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy) solve(invprob; diff --git a/test/enzyme.jl b/test/enzyme.jl index ab651c508..1f2967913 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,4 +1,4 @@ -using Enzyme, FiniteDiff +using Enzyme, ForwardDiff using LinearSolve, LinearAlgebra, Test n = 4 @@ -20,8 +20,8 @@ f(A, b1) # Uses BLAS Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) -dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A)) -db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) +dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) @test dA ≈ dA2 @test db1 ≈ db12 @@ -35,8 +35,8 @@ db12 = zeros(n); @test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12))) -dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A)) -db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) +dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) @test_broken dA ≈ dA_2 @test_broken dA2 ≈ dA_2 @@ -45,9 +45,8 @@ db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1)) function f(A, b1, b2; alg = LUFactorization()) prob = LinearProblem(A, b1) - cache = init(prob, alg) - s1 = solve!(cache).u + s1 = copy(solve!(cache).u) cache.b = b2 s2 = solve!(cache).u norm(s1 + s2) @@ -60,11 +59,46 @@ db1 = zeros(n); b2 = rand(n); db2 = zeros(n); +f(A, b1, b2) Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) -dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A)) -db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1)) -db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2)) +dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1),eltype(x).(b2)), copy(A)) +db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x,eltype(x).(b2)), copy(b1)) +db22 = ForwardDiff.gradient(x->f(eltype(x).(A),eltype(x).(b1),x), copy(b2)) + +@test dA ≈ dA2 +@test db1 ≈ db12 +@test db2 ≈ db22 + +function f2(A, b1, b2; alg = RFLUFactorization()) + prob = LinearProblem(A, b1) + cache = init(prob, alg) + s1 = copy(solve!(cache).u) + cache.b = b2 + s2 = solve!(cache).u + norm(s1 + s2) +end + +f2(A, b1, b2) +dA = zeros(n, n); +db1 = zeros(n); +db2 = zeros(n); +Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) + +@test dA ≈ dA2 +@test db1 ≈ db12 +@test db2 ≈ db22 + +function f3(A, b1, b2; alg = KrylovJL_GMRES()) + prob = LinearProblem(A, b1) + cache = init(prob, alg) + s1 = solve!(cache).u + cache.b = b2 + s2 = solve!(cache).u + norm(s1 + s2) +end + +Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12