From 6b44b99bcfebd857c100e4f37c0fe0fce3c80d85 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Thu, 25 Apr 2024 14:45:57 -0500 Subject: [PATCH] Preconditioner support --- src/DiffKrylov.jl | 1 + src/EnzymeRules/enzymerules.jl | 46 +++++++++++++++++++++++----------- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/src/DiffKrylov.jl b/src/DiffKrylov.jl index bef9a09..9ff750a 100644 --- a/src/DiffKrylov.jl +++ b/src/DiffKrylov.jl @@ -2,6 +2,7 @@ module DiffKrylov using Krylov using SparseArrays +using LinearAlgebra include("ForwardDiff/forwarddiff.jl") include("EnzymeRules/enzymerules.jl") end diff --git a/src/EnzymeRules/enzymerules.jl b/src/EnzymeRules/enzymerules.jl index d7bd627..bddb9fb 100644 --- a/src/EnzymeRules/enzymerules.jl +++ b/src/EnzymeRules/enzymerules.jl @@ -11,28 +11,33 @@ for AMT in (:Matrix, :SparseMatrixCSC) func::Const{typeof(Krylov.$solver)}, ret::Type{RT}, _A::Annotation{MT}, - _b::Annotation{VT}, + _b::Annotation{VT}; + verbose = 0, + M = I, + N = I, options... ) where {RT, MT <: $AMT, VT <: Vector} psolver = $solver pamt = $AMT - # println("($psolver, $pamt) forward rule") + if verbose > 0 + @info "($psolver, $pamt) forward rule" + end A = _A.val b = _b.val dx = [] - x, stats = Krylov.$solver(A,b; options...) + x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...) if isa(_A, Duplicated) && isa(_b, Duplicated) dA = _A.dval db = _b.dval db -= dA*x - dx, dstats = Krylov.$solver(A,db; options...) + dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...) elseif isa(_A, Duplicated) && isa(_b, Const) dA = _A.dval db = -dA*x - dx, dstats = Krylov.$solver(A,db; options...) + dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...) elseif isa(_A, Const) && isa(_b, Duplicated) db = _b.dval - dx, dstats = Krylov.$solver(A,db; options...) + dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...) elseif isa(_A, Const) && isa(_b, Const) nothing else @@ -60,18 +65,28 @@ for AMT in (:Matrix, :SparseMatrixCSC) func::Const{typeof(Krylov.$solver)}, ret::Type{RT}, _A::Annotation{MT}, - _b::Annotation{VT} + _b::Annotation{VT}; + M=I, + N=I, + verbose=0, + options... ) where {RT, MT <: $AMT, VT <: Vector} psolver = $solver pamt = $AMT - # println("($psolver, $pamt) augmented forward") + if verbose > 0 + @info "($psolver, $pamt) augmented forward" + end A = _A.val b = _b.val - x, stats = Krylov.$solver(A,b) + x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...) bx = zeros(length(x)) bstats = deepcopy(stats) if needs_primal(config) - return AugmentedReturn((x, stats), (bx, bstats), (A,x, Ref(bx))) + return AugmentedReturn( + (x, stats), + (bx, bstats), + (A,x, Ref(bx), verbose, M, N) + ) else return AugmentedReturn(nothing, (bx, bstats), (A,x)) end @@ -83,13 +98,16 @@ for AMT in (:Matrix, :SparseMatrixCSC) dret::Type{RT}, cache, _A::Annotation{MT}, - _b::Annotation{<:Vector}, + _b::Annotation{<:Vector}; + options... ) where {RT, MT <: $AMT} + (A,x,bx,verbose,M,N) = cache psolver = $solver pamt = $AMT - # println("($psolver, $pamt) reverse") - (A,x,bx) = cache - _b.dval .= $solver(transpose(A), bx[])[1] + if verbose > 0 + @info "($psolver, $pamt) reverse" + end + _b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, N=N, verbose=verbose, options...)[1] _A.dval .= -x .* _b.dval' return (nothing, nothing) end