Skip to content

Commit

Permalink
Preconditioner support
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Apr 25, 2024
1 parent 3ca4968 commit 6b44b99
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
1 change: 1 addition & 0 deletions src/DiffKrylov.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module DiffKrylov

using Krylov
using SparseArrays
using LinearAlgebra
include("ForwardDiff/forwarddiff.jl")
include("EnzymeRules/enzymerules.jl")
end
46 changes: 32 additions & 14 deletions src/EnzymeRules/enzymerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,33 @@ for AMT in (:Matrix, :SparseMatrixCSC)
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT},
_b::Annotation{VT};
verbose = 0,
M = I,
N = I,
options...
) where {RT, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
# println("($psolver, $pamt) forward rule")
if verbose > 0
@info "($psolver, $pamt) forward rule"
end
A = _A.val
b = _b.val
dx = []
x, stats = Krylov.$solver(A,b; options...)
x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
if isa(_A, Duplicated) && isa(_b, Duplicated)
dA = _A.dval
db = _b.dval
db -= dA*x
dx, dstats = Krylov.$solver(A,db; options...)
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Duplicated) && isa(_b, Const)
dA = _A.dval
db = -dA*x
dx, dstats = Krylov.$solver(A,db; options...)
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Duplicated)
db = _b.dval
dx, dstats = Krylov.$solver(A,db; options...)
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Const)
nothing
else
Expand Down Expand Up @@ -60,18 +65,28 @@ for AMT in (:Matrix, :SparseMatrixCSC)
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT}
_b::Annotation{VT};
M=I,
N=I,
verbose=0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
# println("($psolver, $pamt) augmented forward")
if verbose > 0
@info "($psolver, $pamt) augmented forward"
end
A = _A.val
b = _b.val
x, stats = Krylov.$solver(A,b)
x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
bx = zeros(length(x))
bstats = deepcopy(stats)
if needs_primal(config)
return AugmentedReturn((x, stats), (bx, bstats), (A,x, Ref(bx)))
return AugmentedReturn(
(x, stats),
(bx, bstats),
(A,x, Ref(bx), verbose, M, N)
)
else
return AugmentedReturn(nothing, (bx, bstats), (A,x))
end
Expand All @@ -83,13 +98,16 @@ for AMT in (:Matrix, :SparseMatrixCSC)
dret::Type{RT},
cache,
_A::Annotation{MT},
_b::Annotation{<:Vector},
_b::Annotation{<:Vector};
options...
) where {RT, MT <: $AMT}
(A,x,bx,verbose,M,N) = cache
psolver = $solver
pamt = $AMT
# println("($psolver, $pamt) reverse")
(A,x,bx) = cache
_b.dval .= $solver(transpose(A), bx[])[1]
if verbose > 0
@info "($psolver, $pamt) reverse"
end
_b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, N=N, verbose=verbose, options...)[1]
_A.dval .= -x .* _b.dval'
return (nothing, nothing)
end
Expand Down

0 comments on commit 6b44b99

Please sign in to comment.