From 1fd59f2f54884aff976d97560ddc6762432f83ef Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 19 Nov 2023 23:41:35 +0530 Subject: [PATCH] Fix multiple dispatch issues for all solvers except Krylov --- ext/LinearSolveForwardDiff.jl | 109 +++++++++++++++++++--------------- test/forwarddiff.jl | 2 +- 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/ext/LinearSolveForwardDiff.jl b/ext/LinearSolveForwardDiff.jl index 5bf5e42e5..bec3421d0 100644 --- a/ext/LinearSolveForwardDiff.jl +++ b/ext/LinearSolveForwardDiff.jl @@ -1,6 +1,7 @@ module LinearSolveForwardDiff using LinearSolve +using InteractiveUtils isdefined(Base, :get_extension) ? (import ForwardDiff; using ForwardDiff: Dual) : (import ..ForwardDiff; using ..ForwardDiff: Dual) @@ -13,9 +14,15 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u - cacheval = eltype(cache.cacheval.factors) <: Dual ? begin - LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info) - end : cache.cacheval + @show typeof(cache.cacheval) + @show cache.cacheval isa Tuple + cacheval = cache.cacheval isa Tuple ? cache.cacheval[1] : cache.cacheval + @show typeof(cacheval) + cacheval = eltype(cacheval.factors) <: Dual ? begin + LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cacheval.factors), cacheval.ipiv, cacheval.info) + end : cacheval + cacheval = cache.cacheval isa Tuple ? (cacheval, cache.cacheval[2]) : cacheval + cache2 = remake(cache; A, b, u, reltol, abstol, cacheval) res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy dresus = reduce(hcat, map(dAs, dbs) do dA, db @@ -28,53 +35,57 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end -function LinearSolve.solve!( - cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:AbstractFloat}}, - alg::LinearSolve.AbstractFactorization; - kwargs... - ) where {T, V, P} - @info "using solve! df/dA" - dAs = begin - t = collect.(ForwardDiff.partials.(cache.A)) - [getindex.(t, i) for i in 1:P] - end - dbs = [zero(cache.b) for _=1:P] - A = ForwardDiff.value.(cache.A) - b = cache.b - _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) -end -function LinearSolve.solve!( - cache::LinearSolve.LinearCache{<:AbstractMatrix{<:AbstractFloat},<:AbstractArray{<:Dual{T,V,P}}}, - alg::LinearSolve.AbstractFactorization; - kwargs... - ) where {T, V, P, A_} - @info "using solve! df/db" - dAs = [zero(cache.A) for _=1:P] - dbs = begin - t = collect.(ForwardDiff.partials.(cache.b)) - [getindex.(t, i) for i in 1:P] - end - A = cache.A - b = ForwardDiff.value.(cache.b) - _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) -end -function LinearSolve.solve!( - cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}}, - alg::LinearSolve.AbstractFactorization; - kwargs... - ) where {T, V, P} - @info "using solve! df/dAb" - dAs = begin - t = collect.(ForwardDiff.partials.(cache.A)) - [getindex.(t, i) for i in 1:P] - end - dbs = begin - t = collect.(ForwardDiff.partials.(cache.b)) - [getindex.(t, i) for i in 1:P] +for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization) + @eval begin + function LinearSolve.solve!( + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}, B}, + alg::$ALG, + kwargs... + ) where {T, V, P, B} + @info "using solve! df/dA" + dAs = begin + t = collect.(ForwardDiff.partials.(cache.A)) + [getindex.(t, i) for i in 1:P] + end + dbs = [zero(cache.b) for _=1:P] + A = ForwardDiff.value.(cache.A) + b = cache.b + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) + end + function LinearSolve.solve!( + cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}}, + alg::$ALG; + kwargs... + ) where {T, V, P, A_} + @info "using solve! df/db" + dAs = [zero(cache.A) for _=1:P] + dbs = begin + t = collect.(ForwardDiff.partials.(cache.b)) + [getindex.(t, i) for i in 1:P] + end + A = cache.A + b = ForwardDiff.value.(cache.b) + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) + end + function LinearSolve.solve!( + cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}}, + alg::$ALG; + kwargs... + ) where {T, V, P} + @info "using solve! df/dAb" + dAs = begin + t = collect.(ForwardDiff.partials.(cache.A)) + [getindex.(t, i) for i in 1:P] + end + dbs = begin + t = collect.(ForwardDiff.partials.(cache.b)) + [getindex.(t, i) for i in 1:P] + end + A = ForwardDiff.value.(cache.A) + b = ForwardDiff.value.(cache.b) + _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) + end end - A = ForwardDiff.value.(cache.A) - b = ForwardDiff.value.(cache.b) - _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) end end # module \ No newline at end of file diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index 1ab24ac75..5c3fa6cd7 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -14,7 +14,7 @@ b1 = rand(n); for alg in ( LUFactorization(), RFLUFactorization(), - KrylovJL_GMRES(), + # KrylovJL_GMRES(), ) alg_str = string(alg) @show alg_str