Skip to content

Commit

Permalink
Add ForwardDiff rule for solve!
Browse files Browse the repository at this point in the history
  • Loading branch information
sharanry committed Nov 13, 2023
1 parent 9aaf9b3 commit c3b72c9
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand All @@ -48,6 +49,7 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
LinearSolveForwardDiff = "ForwardDiff"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand All @@ -66,6 +68,7 @@ DocStringExtensions = "0.9"
EnumX = "1"
EnzymeCore = "0.6"
FastLapackInterface = "2"
ForwardDiff = "0.10"
GPUArraysCore = "0.1"
HYPRE = "1.4.0"
InteractiveUtils = "1.6"
Expand Down
32 changes: 32 additions & 0 deletions ext/LinearSolveForwardDiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module LinearSolveForwardDiff

using LinearSolve
isdefined(Base, :get_extension) ?
(import ForwardDiff; using ForwardDiff: Dual) :
(import ..ForwardDiff; using ..ForwardDiff: Dual)

function LinearSolve.solve!(
cache::LinearSolve.LinearCache{A_,B},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P, A_<:AbstractArray{<:Real}, B<:AbstractArray{<:Dual{T,V,P}}}
@info "using solve! from LinearSolveForwardDiff.jl"
dA = eltype(cache.A) <: Dual ? ForwardDiff.partials.(cache.A) : zero(cache.A)
db = eltype(cache.b) <: Dual ? ForwardDiff.partials.(cache.b) : zero(cache.b)
@show typeof(cache.A)
@show typeof(cache.b)
@show typeof(cache.u)
A = eltype(cache.A) <: Dual ? ForwardDiff.value.(cache.A) : cache.A
b = eltype(cache.b) <: Dual ? ForwardDiff.value.(cache.b) : cache.b
u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u
@show typeof(A), size(A)
@show typeof(b), size(b)
@show typeof(u), size(u)
cache2 = remake(cache; A, b, u)
res = LinearSolve.solve!(cache2, alg, kwargs...)
dcache = remake(cache2; b = db - dA * res.u)
dres = LinearSolve.solve!(dcache, alg, kwargs...)
LinearSolve.SciMLBase.build_linear_solution(alg, Dual{T,V,P}.(res.u, dres.u), nothing, cache)
end

end # module
9 changes: 9 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
assumptions::OperatorAssumptions{issq}
end

function SciMLBase.remake(cache::LinearCache;
A::TA=cache.A, b::TB=cache.b, u::TU=cache.u, p::TP=cache.p, alg::Talg=cache.alg,
cacheval::Tc=cache.cacheval, isfresh::Bool=cache.isfresh, Pl::Tl=cache.Pl, Pr::Tr=cache.Pr,
abstol::Ttol=cache.abstol, reltol::Ttol=cache.reltol, maxiters::Int=cache.maxiters,
verbose::Bool=cache.verbose, assumptions::OperatorAssumptions{issq}=cache.assumptions) where {TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}
LinearCache{TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}(A,b,u,p,alg,cacheval,isfresh,Pl,Pr,abstol,reltol,
maxiters,verbose,assumptions)
end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
if name === :A
setfield!(cache, :isfresh, true)
Expand Down
51 changes: 51 additions & 0 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using Test
using ForwardDiff
using LinearSolve
using FiniteDiff

n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
for alg in (
LUFactorization(),
# RFLUFactorization(),
# KrylovJL_GMRES(),
)
alg_str = string(alg)
@show alg_str
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fb(b1)

fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fid_jac

fod_jac = ForwardDiff.gradient(fb, b1) |> vec
@show fod_jac

@test fod_jac fid_jac rtol=1e-6

function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fA(A)

fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fid_jac

# @test_throws MethodError fod_jac = ForwardDiff.gradient(fA, A) |> vec
fod_jac = ForwardDiff.gradient(fA, A) |> vec
# @show fod_jac

# @test fod_jac ≈ fid_jac rtol=1e-6
end

0 comments on commit c3b72c9

Please sign in to comment.