From c3b72c96af713f3eaf7f270c1478388049666c49 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 8 Nov 2023 03:13:46 +0530 Subject: [PATCH 1/9] Add ForwardDiff rule for solve! --- Project.toml | 3 +++ ext/LinearSolveForwardDiff.jl | 32 ++++++++++++++++++++++ src/common.jl | 9 +++++++ test/forwarddiff.jl | 51 +++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+) create mode 100644 ext/LinearSolveForwardDiff.jl create mode 100644 test/forwarddiff.jl diff --git a/Project.toml b/Project.toml index 71f0ca822..49207f55b 100644 --- a/Project.toml +++ b/Project.toml @@ -35,6 +35,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" @@ -48,6 +49,7 @@ LinearSolveBandedMatricesExt = "BandedMatrices" LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveEnzymeExt = "Enzyme" +LinearSolveForwardDiff = "ForwardDiff" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" @@ -66,6 +68,7 @@ DocStringExtensions = "0.9" EnumX = "1" EnzymeCore = "0.6" FastLapackInterface = "2" +ForwardDiff = "0.10" GPUArraysCore = "0.1" HYPRE = "1.4.0" InteractiveUtils = "1.6" diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl new file mode 100644 index 000000000..c17e150fe --- /dev/null +++ b/ext/LinearSolveForwardDiff.jl @@ -0,0 +1,32 @@ +module LinearSolveForwardDiff + +using LinearSolve +isdefined(Base, :get_extension) ? + (import ForwardDiff; using ForwardDiff: Dual) : + (import ..ForwardDiff; using ..ForwardDiff: Dual) + +function LinearSolve.solve!( + cache::LinearSolve.LinearCache{A_,B}, + alg::LinearSolve.AbstractFactorization; + kwargs... + ) where {T, V, P, A_<:AbstractArray{<:Real}, B<:AbstractArray{<:Dual{T,V,P}}} + @info "using solve! from LinearSolveForwardDiff.jl" + dA = eltype(cache.A) <: Dual ? ForwardDiff.partials.(cache.A) : zero(cache.A) + db = eltype(cache.b) <: Dual ? ForwardDiff.partials.(cache.b) : zero(cache.b) + @show typeof(cache.A) + @show typeof(cache.b) + @show typeof(cache.u) + A = eltype(cache.A) <: Dual ? ForwardDiff.value.(cache.A) : cache.A + b = eltype(cache.b) <: Dual ? ForwardDiff.value.(cache.b) : cache.b + u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u + @show typeof(A), size(A) + @show typeof(b), size(b) + @show typeof(u), size(u) + cache2 = remake(cache; A, b, u) + res = LinearSolve.solve!(cache2, alg, kwargs...) + dcache = remake(cache2; b = db - dA * res.u) + dres = LinearSolve.solve!(dcache, alg, kwargs...) + LinearSolve.SciMLBase.build_linear_solution(alg, Dual{T,V,P}.(res.u, dres.u), nothing, cache) +end + +end # module \ No newline at end of file diff --git a/src/common.jl b/src/common.jl index 791ab91c8..dc8b748b3 100644 --- a/src/common.jl +++ b/src/common.jl @@ -82,6 +82,15 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq} assumptions::OperatorAssumptions{issq} end +function SciMLBase.remake(cache::LinearCache; + A::TA=cache.A, b::TB=cache.b, u::TU=cache.u, p::TP=cache.p, alg::Talg=cache.alg, + cacheval::Tc=cache.cacheval, isfresh::Bool=cache.isfresh, Pl::Tl=cache.Pl, Pr::Tr=cache.Pr, + abstol::Ttol=cache.abstol, reltol::Ttol=cache.reltol, maxiters::Int=cache.maxiters, + verbose::Bool=cache.verbose, assumptions::OperatorAssumptions{issq}=cache.assumptions) where {TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq} + LinearCache{TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}(A,b,u,p,alg,cacheval,isfresh,Pl,Pr,abstol,reltol, + maxiters,verbose,assumptions) +end + function Base.setproperty!(cache::LinearCache, name::Symbol, x) if name === :A setfield!(cache, :isfresh, true) diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl new file mode 100644 index 000000000..1af13f6d3 --- /dev/null +++ b/test/forwarddiff.jl @@ -0,0 +1,51 @@ +using Test +using ForwardDiff +using LinearSolve +using FiniteDiff + +n = 4 +A = rand(n, n); +dA = zeros(n, n); +b1 = rand(n); +for alg in ( + LUFactorization(), + # RFLUFactorization(), + # KrylovJL_GMRES(), + ) + alg_str = string(alg) + @show alg_str + function fb(b) + prob = LinearProblem(A, b) + + sol1 = solve(prob, alg) + + sum(sol1.u) + end + fb(b1) + + fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec + @show fid_jac + + fod_jac = ForwardDiff.gradient(fb, b1) |> vec + @show fod_jac + + @test fod_jac ≈ fid_jac rtol=1e-6 + + function fA(A) + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + sum(sol1.u) + end + fA(A) + + fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec + @show fid_jac + + # @test_throws MethodError fod_jac = ForwardDiff.gradient(fA, A) |> vec + fod_jac = ForwardDiff.gradient(fA, A) |> vec + # @show fod_jac + + # @test fod_jac ≈ fid_jac rtol=1e-6 +end \ No newline at end of file From a67d7aa5a5f265498e0256786f4aed1d70961b64 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 14 Nov 2023 02:43:45 +0530 Subject: [PATCH 2/9] Refactor rule --- ext/LinearSolveForwardDiff.jl | 70 +++++++++++++++++++++++++---------- test/forwarddiff.jl | 34 +++++++++++++---- 2 files changed, 77 insertions(+), 27 deletions(-) diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl index c17e150fe..decbaeb2f 100644 --- a/ext/LinearSolveForwardDiff.jl +++ b/ext/LinearSolveForwardDiff.jl @@ -5,28 +5,60 @@ isdefined(Base, :get_extension) ? (import ForwardDiff; using ForwardDiff: Dual) : (import ..ForwardDiff; using ..ForwardDiff: Dual) +function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) + @assert !(eltype(first(dAs)) isa Dual) + @assert !(eltype(first(dbs)) isa Dual) + @assert !(eltype(A) isa Dual) + @assert !(eltype(b) isa Dual) + reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol + abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol + u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u + cacheval = eltype(cache.cacheval.factors) <: Dual ? begin + LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info) + end : cache.cacheval + cache2 = remake(cache; A, b, u, reltol, abstol, cacheval) + res = LinearSolve.solve!(cache2, alg, kwargs...) + dresus = reduce(hcat, map(dAs, dbs) do dA, db + cache2.b = db - dA * res.u + dres = LinearSolve.solve!(cache2, alg, kwargs...) + deepcopy(dres.u) + end) + # display(dresus) + d = Dual{T}.(res.u, Tuple.(eachrow(dresus))) + LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats) +end + function LinearSolve.solve!( - cache::LinearSolve.LinearCache{A_,B}, + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}}, alg::LinearSolve.AbstractFactorization; kwargs... - ) where {T, V, P, A_<:AbstractArray{<:Real}, B<:AbstractArray{<:Dual{T,V,P}}} - @info "using solve! from LinearSolveForwardDiff.jl" - dA = eltype(cache.A) <: Dual ? ForwardDiff.partials.(cache.A) : zero(cache.A) - db = eltype(cache.b) <: Dual ? ForwardDiff.partials.(cache.b) : zero(cache.b) - @show typeof(cache.A) - @show typeof(cache.b) - @show typeof(cache.u) - A = eltype(cache.A) <: Dual ? ForwardDiff.value.(cache.A) : cache.A - b = eltype(cache.b) <: Dual ? ForwardDiff.value.(cache.b) : cache.b - u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u - @show typeof(A), size(A) - @show typeof(b), size(b) - @show typeof(u), size(u) - cache2 = remake(cache; A, b, u) - res = LinearSolve.solve!(cache2, alg, kwargs...) - dcache = remake(cache2; b = db - dA * res.u) - dres = LinearSolve.solve!(dcache, alg, kwargs...) - LinearSolve.SciMLBase.build_linear_solution(alg, Dual{T,V,P}.(res.u, dres.u), nothing, cache) + ) where {T, V, P} + @info "using solve! df/dA" + dAs = begin + dAs_ = ForwardDiff.partials.(cache.A) + dAs_ = collect.(dAs_) + dAs_ = [getindex.(dAs_, i) for i in 1:length(first(dAs_))] + end + dbs = [zero(cache.b) for _=1:P] + A = ForwardDiff.value.(cache.A) + b = cache.b + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) +end +function LinearSolve.solve!( + cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}}, + alg::LinearSolve.AbstractFactorization; + kwargs... + ) where {T, V, P, A_} + @info "using solve! df/db" + dAs = [zero(cache.A) for _=1:P] + dbs = begin + dbs_ = ForwardDiff.partials.(cache.b) + dbs_ = collect.(dbs_) + dbs_ = [getindex.(dbs_, i) for i in 1:length(first(dbs_))] + end + A = cache.A + b = ForwardDiff.value.(cache.b) + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end end # module \ No newline at end of file diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index 1af13f6d3..d726b87b9 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -2,16 +2,20 @@ using Test using ForwardDiff using LinearSolve using FiniteDiff +using Enzyme +using Random +Random.seed!(1234) n = 4 A = rand(n, n); dA = zeros(n, n); b1 = rand(n); -for alg in ( - LUFactorization(), - # RFLUFactorization(), - # KrylovJL_GMRES(), - ) +alg = LUFactorization() +# for alg in ( +# LUFactorization(), +# # RFLUFactorization(), +# # KrylovJL_GMRES(), +# ) alg_str = string(alg) @show alg_str function fb(b) @@ -26,6 +30,14 @@ for alg in ( fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec @show fid_jac + # manual_jac = map(onehot(b1)) do db + # y = A \ b1 + # inv(A) * (db - dA*y) + # end |> collect + # display(manual_jac) + # @show sum(manual_jac) + # @show sum.(manual_jac) + fod_jac = ForwardDiff.gradient(fb, b1) |> vec @show fod_jac @@ -39,13 +51,19 @@ for alg in ( sum(sol1.u) end fA(A) + db = zero(b1) + manual_jac = map(onehot(A)) do dA + y = A \ b1 + sum(inv(A) * (db - dA*y)) + end |> collect + display(reduce(hcat, manual_jac)) fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec @show fid_jac # @test_throws MethodError fod_jac = ForwardDiff.gradient(fA, A) |> vec fod_jac = ForwardDiff.gradient(fA, A) |> vec - # @show fod_jac + @show fod_jac - # @test fod_jac ≈ fid_jac rtol=1e-6 -end \ No newline at end of file + @test fod_jac ≈ fid_jac rtol=1e-6 +# end \ No newline at end of file From 465e11c0950784568399c88006f5a95b656055a6 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 14 Nov 2023 03:29:32 +0530 Subject: [PATCH 3/9] Refactor code and add dispatch where both A and b are dual --- ext/LinearSolveForwardDiff.jl | 36 ++++++++++++++++------- test/forwarddiff.jl | 54 ++++++++++++++++++++++++++--------- 2 files changed, 66 insertions(+), 24 deletions(-) diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl index decbaeb2f..5bf5e42e5 100644 --- a/ext/LinearSolveForwardDiff.jl +++ b/ext/LinearSolveForwardDiff.jl @@ -17,27 +17,26 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info) end : cache.cacheval cache2 = remake(cache; A, b, u, reltol, abstol, cacheval) - res = LinearSolve.solve!(cache2, alg, kwargs...) + res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy dresus = reduce(hcat, map(dAs, dbs) do dA, db cache2.b = db - dA * res.u dres = LinearSolve.solve!(cache2, alg, kwargs...) deepcopy(dres.u) end) - # display(dresus) d = Dual{T}.(res.u, Tuple.(eachrow(dresus))) LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats) end + function LinearSolve.solve!( - cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}}, + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:AbstractFloat}}, alg::LinearSolve.AbstractFactorization; kwargs... ) where {T, V, P} @info "using solve! df/dA" dAs = begin - dAs_ = ForwardDiff.partials.(cache.A) - dAs_ = collect.(dAs_) - dAs_ = [getindex.(dAs_, i) for i in 1:length(first(dAs_))] + t = collect.(ForwardDiff.partials.(cache.A)) + [getindex.(t, i) for i in 1:P] end dbs = [zero(cache.b) for _=1:P] A = ForwardDiff.value.(cache.A) @@ -45,20 +44,37 @@ function LinearSolve.solve!( _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end function LinearSolve.solve!( - cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}}, + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:AbstractFloat},<:AbstractArray{<:Dual{T,V,P}}}, alg::LinearSolve.AbstractFactorization; kwargs... ) where {T, V, P, A_} @info "using solve! df/db" dAs = [zero(cache.A) for _=1:P] dbs = begin - dbs_ = ForwardDiff.partials.(cache.b) - dbs_ = collect.(dbs_) - dbs_ = [getindex.(dbs_, i) for i in 1:length(first(dbs_))] + t = collect.(ForwardDiff.partials.(cache.b)) + [getindex.(t, i) for i in 1:P] end A = cache.A b = ForwardDiff.value.(cache.b) _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end +function LinearSolve.solve!( + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}}, + alg::LinearSolve.AbstractFactorization; + kwargs... + ) where {T, V, P} + @info "using solve! df/dAb" + dAs = begin + t = collect.(ForwardDiff.partials.(cache.A)) + [getindex.(t, i) for i in 1:P] + end + dbs = begin + t = collect.(ForwardDiff.partials.(cache.b)) + [getindex.(t, i) for i in 1:P] + end + A = ForwardDiff.value.(cache.A) + b = ForwardDiff.value.(cache.b) + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) +end end # module \ No newline at end of file diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index d726b87b9..1ab24ac75 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -10,12 +10,12 @@ n = 4 A = rand(n, n); dA = zeros(n, n); b1 = rand(n); -alg = LUFactorization() -# for alg in ( -# LUFactorization(), -# # RFLUFactorization(), -# # KrylovJL_GMRES(), -# ) +# alg = LUFactorization() +for alg in ( + LUFactorization(), + RFLUFactorization(), + KrylovJL_GMRES(), + ) alg_str = string(alg) @show alg_str function fb(b) @@ -51,19 +51,45 @@ alg = LUFactorization() sum(sol1.u) end fA(A) - db = zero(b1) - manual_jac = map(onehot(A)) do dA - y = A \ b1 - sum(inv(A) * (db - dA*y)) - end |> collect - display(reduce(hcat, manual_jac)) + # db = zero(b1) + # manual_jac = map(onehot(A)) do dA + # y = A \ b1 + # t = inv(A) * (db - dA*y) + # end |> collect + # display(reduce(hcat, manual_jac)) fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec @show fid_jac - # @test_throws MethodError fod_jac = ForwardDiff.gradient(fA, A) |> vec fod_jac = ForwardDiff.gradient(fA, A) |> vec @show fod_jac @test fod_jac ≈ fid_jac rtol=1e-6 -# end \ No newline at end of file + + + # function fAb(Ab) + # A = Ab[:, 1:n] + # b1 = Ab[:, n+1] + # prob = LinearProblem(A, b1) + + # sol1 = solve(prob, alg) + + # sum(sol1.u) + # end + # fAb(hcat(A, b1)) + # # db = zero(b1) + # # manual_jac = map(onehot(A)) do dA + # # y = A \ b1 + # # t = inv(A) * (db - dA*y) + # # end |> collect + # # display(reduce(hcat, manual_jac)) + + # fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec + # @show fid_jac + + # fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec + # @show fod_jac + + # @test fod_jac ≈ fid_jac rtol=1e-6 + +end \ No newline at end of file From 1fd59f2f54884aff976d97560ddc6762432f83ef Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 19 Nov 2023 23:41:35 +0530 Subject: [PATCH 4/9] Fix multiple dispatch issues for all solvers except Krylov --- ext/LinearSolveForwardDiff.jl | 109 +++++++++++++++++++--------------- test/forwarddiff.jl | 2 +- 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl index 5bf5e42e5..bec3421d0 100644 --- a/ext/LinearSolveForwardDiff.jl +++ b/ext/LinearSolveForwardDiff.jl @@ -1,6 +1,7 @@ module LinearSolveForwardDiff using LinearSolve +using InteractiveUtils isdefined(Base, :get_extension) ? (import ForwardDiff; using ForwardDiff: Dual) : (import ..ForwardDiff; using ..ForwardDiff: Dual) @@ -13,9 +14,15 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u - cacheval = eltype(cache.cacheval.factors) <: Dual ? begin - LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info) - end : cache.cacheval + @show typeof(cache.cacheval) + @show cache.cacheval isa Tuple + cacheval = cache.cacheval isa Tuple ? cache.cacheval[1] : cache.cacheval + @show typeof(cacheval) + cacheval = eltype(cacheval.factors) <: Dual ? begin + LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cacheval.factors), cacheval.ipiv, cacheval.info) + end : cacheval + cacheval = cache.cacheval isa Tuple ? (cacheval, cache.cacheval[2]) : cacheval + cache2 = remake(cache; A, b, u, reltol, abstol, cacheval) res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy dresus = reduce(hcat, map(dAs, dbs) do dA, db @@ -28,53 +35,57 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end -function LinearSolve.solve!( - cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:AbstractFloat}}, - alg::LinearSolve.AbstractFactorization; - kwargs... - ) where {T, V, P} - @info "using solve! df/dA" - dAs = begin - t = collect.(ForwardDiff.partials.(cache.A)) - [getindex.(t, i) for i in 1:P] - end - dbs = [zero(cache.b) for _=1:P] - A = ForwardDiff.value.(cache.A) - b = cache.b - _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) -end -function LinearSolve.solve!( - cache::LinearSolve.LinearCache{<:AbstractMatrix{<:AbstractFloat},<:AbstractArray{<:Dual{T,V,P}}}, - alg::LinearSolve.AbstractFactorization; - kwargs... - ) where {T, V, P, A_} - @info "using solve! df/db" - dAs = [zero(cache.A) for _=1:P] - dbs = begin - t = collect.(ForwardDiff.partials.(cache.b)) - [getindex.(t, i) for i in 1:P] - end - A = cache.A - b = ForwardDiff.value.(cache.b) - _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) -end -function LinearSolve.solve!( - cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}}, - alg::LinearSolve.AbstractFactorization; - kwargs... - ) where {T, V, P} - @info "using solve! df/dAb" - dAs = begin - t = collect.(ForwardDiff.partials.(cache.A)) - [getindex.(t, i) for i in 1:P] - end - dbs = begin - t = collect.(ForwardDiff.partials.(cache.b)) - [getindex.(t, i) for i in 1:P] +for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization) + @eval begin + function LinearSolve.solve!( + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}, B}, + alg::$ALG, + kwargs... + ) where {T, V, P, B} + @info "using solve! df/dA" + dAs = begin + t = collect.(ForwardDiff.partials.(cache.A)) + [getindex.(t, i) for i in 1:P] + end + dbs = [zero(cache.b) for _=1:P] + A = ForwardDiff.value.(cache.A) + b = cache.b + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) + end + function LinearSolve.solve!( + cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}}, + alg::$ALG; + kwargs... + ) where {T, V, P, A_} + @info "using solve! df/db" + dAs = [zero(cache.A) for _=1:P] + dbs = begin + t = collect.(ForwardDiff.partials.(cache.b)) + [getindex.(t, i) for i in 1:P] + end + A = cache.A + b = ForwardDiff.value.(cache.b) + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) + end + function LinearSolve.solve!( + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}}, + alg::$ALG; + kwargs... + ) where {T, V, P} + @info "using solve! df/dAb" + dAs = begin + t = collect.(ForwardDiff.partials.(cache.A)) + [getindex.(t, i) for i in 1:P] + end + dbs = begin + t = collect.(ForwardDiff.partials.(cache.b)) + [getindex.(t, i) for i in 1:P] + end + A = ForwardDiff.value.(cache.A) + b = ForwardDiff.value.(cache.b) + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) + end end - A = ForwardDiff.value.(cache.A) - b = ForwardDiff.value.(cache.b) - _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end end # module \ No newline at end of file diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index 1ab24ac75..5c3fa6cd7 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -14,7 +14,7 @@ b1 = rand(n); for alg in ( LUFactorization(), RFLUFactorization(), - KrylovJL_GMRES(), + # KrylovJL_GMRES(), ) alg_str = string(alg) @show alg_str From 465bd47432a59f3b56d44520f60200c4424e8982 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 19 Nov 2023 23:42:22 +0530 Subject: [PATCH 5/9] Remove debug statements --- ext/LinearSolveForwardDiff.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl index bec3421d0..f3f316507 100644 --- a/ext/LinearSolveForwardDiff.jl +++ b/ext/LinearSolveForwardDiff.jl @@ -14,10 +14,7 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u - @show typeof(cache.cacheval) - @show cache.cacheval isa Tuple cacheval = cache.cacheval isa Tuple ? cache.cacheval[1] : cache.cacheval - @show typeof(cacheval) cacheval = eltype(cacheval.factors) <: Dual ? begin LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cacheval.factors), cacheval.ipiv, cacheval.info) end : cacheval From a85634de39d34a3c7615115c8ad374acbdd864ab Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 19 Nov 2023 23:43:01 +0530 Subject: [PATCH 6/9] Remove Enable fAB tests --- test/forwarddiff.jl | 48 ++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index 5c3fa6cd7..a8e12f387 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -67,29 +67,29 @@ for alg in ( @test fod_jac ≈ fid_jac rtol=1e-6 - # function fAb(Ab) - # A = Ab[:, 1:n] - # b1 = Ab[:, n+1] - # prob = LinearProblem(A, b1) - - # sol1 = solve(prob, alg) - - # sum(sol1.u) - # end - # fAb(hcat(A, b1)) - # # db = zero(b1) - # # manual_jac = map(onehot(A)) do dA - # # y = A \ b1 - # # t = inv(A) * (db - dA*y) - # # end |> collect - # # display(reduce(hcat, manual_jac)) - - # fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec - # @show fid_jac - - # fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec - # @show fod_jac - - # @test fod_jac ≈ fid_jac rtol=1e-6 + function fAb(Ab) + A = Ab[:, 1:n] + b1 = Ab[:, n+1] + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + sum(sol1.u) + end + fAb(hcat(A, b1)) + # db = zero(b1) + # manual_jac = map(onehot(A)) do dA + # y = A \ b1 + # t = inv(A) * (db - dA*y) + # end |> collect + # display(reduce(hcat, manual_jac)) + + fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec + @show fid_jac + + fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec + @show fod_jac + + @test fod_jac ≈ fid_jac rtol=1e-6 end \ No newline at end of file From 834cbdd06b263c78e5866f855855d24b8f527220 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 19 Nov 2023 23:43:56 +0530 Subject: [PATCH 7/9] Cleanup tests --- test/forwarddiff.jl | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index a8e12f387..14c5cc657 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -14,7 +14,7 @@ b1 = rand(n); for alg in ( LUFactorization(), RFLUFactorization(), - # KrylovJL_GMRES(), + # KrylovJL_GMRES(), dispatch fails ) alg_str = string(alg) @show alg_str @@ -30,14 +30,6 @@ for alg in ( fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec @show fid_jac - # manual_jac = map(onehot(b1)) do db - # y = A \ b1 - # inv(A) * (db - dA*y) - # end |> collect - # display(manual_jac) - # @show sum(manual_jac) - # @show sum.(manual_jac) - fod_jac = ForwardDiff.gradient(fb, b1) |> vec @show fod_jac @@ -51,12 +43,6 @@ for alg in ( sum(sol1.u) end fA(A) - # db = zero(b1) - # manual_jac = map(onehot(A)) do dA - # y = A \ b1 - # t = inv(A) * (db - dA*y) - # end |> collect - # display(reduce(hcat, manual_jac)) fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec @show fid_jac @@ -77,12 +63,6 @@ for alg in ( sum(sol1.u) end fAb(hcat(A, b1)) - # db = zero(b1) - # manual_jac = map(onehot(A)) do dA - # y = A \ b1 - # t = inv(A) * (db - dA*y) - # end |> collect - # display(reduce(hcat, manual_jac)) fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec @show fid_jac From 1a18393326f15a5438bc5ace495a579c721dd9ec Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 19 Nov 2023 23:44:24 +0530 Subject: [PATCH 8/9] Disable debug statements --- ext/LinearSolveForwardDiff.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl index f3f316507..4b386889c 100644 --- a/ext/LinearSolveForwardDiff.jl +++ b/ext/LinearSolveForwardDiff.jl @@ -39,7 +39,7 @@ for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization) alg::$ALG, kwargs... ) where {T, V, P, B} - @info "using solve! df/dA" + # @info "using solve! df/dA" dAs = begin t = collect.(ForwardDiff.partials.(cache.A)) [getindex.(t, i) for i in 1:P] @@ -54,7 +54,7 @@ for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization) alg::$ALG; kwargs... ) where {T, V, P, A_} - @info "using solve! df/db" + # @info "using solve! df/db" dAs = [zero(cache.A) for _=1:P] dbs = begin t = collect.(ForwardDiff.partials.(cache.b)) @@ -69,7 +69,7 @@ for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization) alg::$ALG; kwargs... ) where {T, V, P} - @info "using solve! df/dAb" + # @info "using solve! df/dAb" dAs = begin t = collect.(ForwardDiff.partials.(cache.A)) [getindex.(t, i) for i in 1:P] From 829a914a80c4b63cc423b84d203b95546c268fbd Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 19 Nov 2023 23:45:30 +0530 Subject: [PATCH 9/9] Cleanup tests --- test/forwarddiff.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index 14c5cc657..61568d262 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -10,7 +10,6 @@ n = 4 A = rand(n, n); dA = zeros(n, n); b1 = rand(n); -# alg = LUFactorization() for alg in ( LUFactorization(), RFLUFactorization(),