diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 607189e10..f04d01084 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -4,11 +4,13 @@ using LinearSolve using LinearSolve.LinearAlgebra isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) - using Enzyme using EnzymeCore +@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.KrylovJL}) = true +@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.Krylov.GmresSolver}) = true + function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} res = func.val(prob.val, alg.val; kwargs...) dres = if EnzymeRules.width(config) == 1 diff --git a/test/enzyme.jl b/test/enzyme.jl index 62904c055..7196f38f2 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -107,7 +107,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), @test db1 ≈ db12 @test db2 ≈ db22 -#= + function f3(A, b1, b2; alg = KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) @@ -117,9 +117,11 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES()) norm(s1 + s2) end +dA = zeros(n, n); +db1 = zeros(n); +db2 = zeros(n); Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 -@test db2 ≈ db22 -=# \ No newline at end of file +@test db2 ≈ db22 \ No newline at end of file