diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index b87f45c85..53cc7e43a 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -15,10 +15,44 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line else (func.val(dval, alg.val; kwargs...) for dval in prob.dval) end - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing) + d_A = if EnzymeRules.width(config) == 1 + dres.A + else + (dval.A for dval in dres) + end + d_b = if EnzymeRules.width(config) == 1 + dres.b + else + (dval.b for dval in dres) + end + return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b)) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} + d_A, d_b = cache + + if EnzymeRules.width(config) == 1 + if d_A !== prob.dval.A + prob.dval.A .+= d_A + d_A .= 0 + end + if d_b !== prob.dval.b + prob.dval.b .+= d_b + d_b .= 0 + end + else + for i in 1:EnzymeRules.width(config) + if d_A !== prob.dval.A + prob.dval.A[i] .+= d_A[i] + d_A[i] .= 0 + end + if d_b !== prob.dval.b + prob.dval.b[i] .+= d_b[i] + d_b[i] .= 0 + end + end + end + return (nothing, nothing) end