Skip to content

Commit

Permalink
Refactor rule
Browse files Browse the repository at this point in the history
  • Loading branch information
sharanry committed Nov 13, 2023
1 parent c3b72c9 commit a67d7aa
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 27 deletions.
70 changes: 51 additions & 19 deletions ext/LinearSolveForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,60 @@ isdefined(Base, :get_extension) ?
(import ForwardDiff; using ForwardDiff: Dual) :
(import ..ForwardDiff; using ..ForwardDiff: Dual)

function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
@assert !(eltype(first(dAs)) isa Dual)
@assert !(eltype(first(dbs)) isa Dual)
@assert !(eltype(A) isa Dual)
@assert !(eltype(b) isa Dual)
reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol
abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol
u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u
cacheval = eltype(cache.cacheval.factors) <: Dual ? begin
LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info)
end : cache.cacheval
cache2 = remake(cache; A, b, u, reltol, abstol, cacheval)
res = LinearSolve.solve!(cache2, alg, kwargs...)
dresus = reduce(hcat, map(dAs, dbs) do dA, db
cache2.b = db - dA * res.u
dres = LinearSolve.solve!(cache2, alg, kwargs...)
deepcopy(dres.u)
end)
# display(dresus)
d = Dual{T}.(res.u, Tuple.(eachrow(dresus)))
LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats)
end

function LinearSolve.solve!(
cache::LinearSolve.LinearCache{A_,B},
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P, A_<:AbstractArray{<:Real}, B<:AbstractArray{<:Dual{T,V,P}}}
@info "using solve! from LinearSolveForwardDiff.jl"
dA = eltype(cache.A) <: Dual ? ForwardDiff.partials.(cache.A) : zero(cache.A)
db = eltype(cache.b) <: Dual ? ForwardDiff.partials.(cache.b) : zero(cache.b)
@show typeof(cache.A)
@show typeof(cache.b)
@show typeof(cache.u)
A = eltype(cache.A) <: Dual ? ForwardDiff.value.(cache.A) : cache.A
b = eltype(cache.b) <: Dual ? ForwardDiff.value.(cache.b) : cache.b
u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u
@show typeof(A), size(A)
@show typeof(b), size(b)
@show typeof(u), size(u)
cache2 = remake(cache; A, b, u)
res = LinearSolve.solve!(cache2, alg, kwargs...)
dcache = remake(cache2; b = db - dA * res.u)
dres = LinearSolve.solve!(dcache, alg, kwargs...)
LinearSolve.SciMLBase.build_linear_solution(alg, Dual{T,V,P}.(res.u, dres.u), nothing, cache)
) where {T, V, P}
@info "using solve! df/dA"
dAs = begin
dAs_ = ForwardDiff.partials.(cache.A)
dAs_ = collect.(dAs_)
dAs_ = [getindex.(dAs_, i) for i in 1:length(first(dAs_))]
end
dbs = [zero(cache.b) for _=1:P]
A = ForwardDiff.value.(cache.A)
b = cache.b
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
end
function LinearSolve.solve!(
cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P, A_}
@info "using solve! df/db"
dAs = [zero(cache.A) for _=1:P]
dbs = begin
dbs_ = ForwardDiff.partials.(cache.b)
dbs_ = collect.(dbs_)
dbs_ = [getindex.(dbs_, i) for i in 1:length(first(dbs_))]
end
A = cache.A
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
end

end # module
34 changes: 26 additions & 8 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@ using Test
using ForwardDiff
using LinearSolve
using FiniteDiff
using Enzyme
using Random
Random.seed!(1234)

n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
for alg in (
LUFactorization(),
# RFLUFactorization(),
# KrylovJL_GMRES(),
)
alg = LUFactorization()
# for alg in (
# LUFactorization(),
# # RFLUFactorization(),
# # KrylovJL_GMRES(),
# )
alg_str = string(alg)
@show alg_str
function fb(b)
Expand All @@ -26,6 +30,14 @@ for alg in (
fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fid_jac

# manual_jac = map(onehot(b1)) do db
# y = A \ b1
# inv(A) * (db - dA*y)
# end |> collect
# display(manual_jac)
# @show sum(manual_jac)
# @show sum.(manual_jac)

fod_jac = ForwardDiff.gradient(fb, b1) |> vec
@show fod_jac

Expand All @@ -39,13 +51,19 @@ for alg in (
sum(sol1.u)
end
fA(A)
db = zero(b1)
manual_jac = map(onehot(A)) do dA
y = A \ b1
sum(inv(A) * (db - dA*y))
end |> collect
display(reduce(hcat, manual_jac))

fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fid_jac

# @test_throws MethodError fod_jac = ForwardDiff.gradient(fA, A) |> vec
fod_jac = ForwardDiff.gradient(fA, A) |> vec
# @show fod_jac
@show fod_jac

# @test fod_jac ≈ fid_jac rtol=1e-6
end
@test fod_jac fid_jac rtol=1e-6
# end

0 comments on commit a67d7aa

Please sign in to comment.