From 6e8cf5de40128f13bb0f3f7786de7ea6387b9914 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 14 Nov 2023 03:29:32 +0530 Subject: [PATCH] Refactor code and add dispatch where both A and b are dual --- ext/LinearSolveForwardDiff.jl | 36 ++++++++++++++++------- test/forwarddiff.jl | 54 ++++++++++++++++++++++++++--------- 2 files changed, 66 insertions(+), 24 deletions(-) diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl index decbaeb2f..5bf5e42e5 100644 --- a/ext/LinearSolveForwardDiff.jl +++ b/ext/LinearSolveForwardDiff.jl @@ -17,27 +17,26 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info) end : cache.cacheval cache2 = remake(cache; A, b, u, reltol, abstol, cacheval) - res = LinearSolve.solve!(cache2, alg, kwargs...) + res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy dresus = reduce(hcat, map(dAs, dbs) do dA, db cache2.b = db - dA * res.u dres = LinearSolve.solve!(cache2, alg, kwargs...) deepcopy(dres.u) end) - # display(dresus) d = Dual{T}.(res.u, Tuple.(eachrow(dresus))) LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats) end + function LinearSolve.solve!( - cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}}, + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:AbstractFloat}}, alg::LinearSolve.AbstractFactorization; kwargs... ) where {T, V, P} @info "using solve! df/dA" dAs = begin - dAs_ = ForwardDiff.partials.(cache.A) - dAs_ = collect.(dAs_) - dAs_ = [getindex.(dAs_, i) for i in 1:length(first(dAs_))] + t = collect.(ForwardDiff.partials.(cache.A)) + [getindex.(t, i) for i in 1:P] end dbs = [zero(cache.b) for _=1:P] A = ForwardDiff.value.(cache.A) @@ -45,20 +44,37 @@ function LinearSolve.solve!( _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end function LinearSolve.solve!( - cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}}, + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:AbstractFloat},<:AbstractArray{<:Dual{T,V,P}}}, alg::LinearSolve.AbstractFactorization; kwargs... ) where {T, V, P, A_} @info "using solve! df/db" dAs = [zero(cache.A) for _=1:P] dbs = begin - dbs_ = ForwardDiff.partials.(cache.b) - dbs_ = collect.(dbs_) - dbs_ = [getindex.(dbs_, i) for i in 1:length(first(dbs_))] + t = collect.(ForwardDiff.partials.(cache.b)) + [getindex.(t, i) for i in 1:P] end A = cache.A b = ForwardDiff.value.(cache.b) _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end +function LinearSolve.solve!( + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}}, + alg::LinearSolve.AbstractFactorization; + kwargs... + ) where {T, V, P} + @info "using solve! df/dAb" + dAs = begin + t = collect.(ForwardDiff.partials.(cache.A)) + [getindex.(t, i) for i in 1:P] + end + dbs = begin + t = collect.(ForwardDiff.partials.(cache.b)) + [getindex.(t, i) for i in 1:P] + end + A = ForwardDiff.value.(cache.A) + b = ForwardDiff.value.(cache.b) + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) +end end # module \ No newline at end of file diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index d726b87b9..1ab24ac75 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -10,12 +10,12 @@ n = 4 A = rand(n, n); dA = zeros(n, n); b1 = rand(n); -alg = LUFactorization() -# for alg in ( -# LUFactorization(), -# # RFLUFactorization(), -# # KrylovJL_GMRES(), -# ) +# alg = LUFactorization() +for alg in ( + LUFactorization(), + RFLUFactorization(), + KrylovJL_GMRES(), + ) alg_str = string(alg) @show alg_str function fb(b) @@ -51,19 +51,45 @@ alg = LUFactorization() sum(sol1.u) end fA(A) - db = zero(b1) - manual_jac = map(onehot(A)) do dA - y = A \ b1 - sum(inv(A) * (db - dA*y)) - end |> collect - display(reduce(hcat, manual_jac)) + # db = zero(b1) + # manual_jac = map(onehot(A)) do dA + # y = A \ b1 + # t = inv(A) * (db - dA*y) + # end |> collect + # display(reduce(hcat, manual_jac)) fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec @show fid_jac - # @test_throws MethodError fod_jac = ForwardDiff.gradient(fA, A) |> vec fod_jac = ForwardDiff.gradient(fA, A) |> vec @show fod_jac @test fod_jac ≈ fid_jac rtol=1e-6 -# end \ No newline at end of file + + + # function fAb(Ab) + # A = Ab[:, 1:n] + # b1 = Ab[:, n+1] + # prob = LinearProblem(A, b1) + + # sol1 = solve(prob, alg) + + # sum(sol1.u) + # end + # fAb(hcat(A, b1)) + # # db = zero(b1) + # # manual_jac = map(onehot(A)) do dA + # # y = A \ b1 + # # t = inv(A) * (db - dA*y) + # # end |> collect + # # display(reduce(hcat, manual_jac)) + + # fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec + # @show fid_jac + + # fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec + # @show fod_jac + + # @test fod_jac ≈ fid_jac rtol=1e-6 + +end \ No newline at end of file