Skip to content

Commit

Permalink
Fix multiple dispatch issues for all solvers except Krylov
Browse files Browse the repository at this point in the history
  • Loading branch information
sharanry committed Nov 19, 2023
1 parent 465e11c commit 1fd59f2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 50 deletions.
109 changes: 60 additions & 49 deletions ext/LinearSolveForwardDiff.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LinearSolveForwardDiff

using LinearSolve
using InteractiveUtils
isdefined(Base, :get_extension) ?
(import ForwardDiff; using ForwardDiff: Dual) :
(import ..ForwardDiff; using ..ForwardDiff: Dual)
Expand All @@ -13,9 +14,15 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol
abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol
u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u
cacheval = eltype(cache.cacheval.factors) <: Dual ? begin
LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info)
end : cache.cacheval
@show typeof(cache.cacheval)
@show cache.cacheval isa Tuple
cacheval = cache.cacheval isa Tuple ? cache.cacheval[1] : cache.cacheval
@show typeof(cacheval)
cacheval = eltype(cacheval.factors) <: Dual ? begin
LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cacheval.factors), cacheval.ipiv, cacheval.info)

Check warning on line 22 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L9-L22

Added lines #L9 - L22 were not covered by tests
end : cacheval
cacheval = cache.cacheval isa Tuple ? (cacheval, cache.cacheval[2]) : cacheval

Check warning on line 24 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L24

Added line #L24 was not covered by tests

cache2 = remake(cache; A, b, u, reltol, abstol, cacheval)
res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy
dresus = reduce(hcat, map(dAs, dbs) do dA, db
Expand All @@ -28,53 +35,57 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
end


function LinearSolve.solve!(
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:AbstractFloat}},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P}
@info "using solve! df/dA"
dAs = begin
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]
end
dbs = [zero(cache.b) for _=1:P]
A = ForwardDiff.value.(cache.A)
b = cache.b
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
end
function LinearSolve.solve!(
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:AbstractFloat},<:AbstractArray{<:Dual{T,V,P}}},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P, A_}
@info "using solve! df/db"
dAs = [zero(cache.A) for _=1:P]
dbs = begin
t = collect.(ForwardDiff.partials.(cache.b))
[getindex.(t, i) for i in 1:P]
end
A = cache.A
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
end
function LinearSolve.solve!(
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P}
@info "using solve! df/dAb"
dAs = begin
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]
end
dbs = begin
t = collect.(ForwardDiff.partials.(cache.b))
[getindex.(t, i) for i in 1:P]
for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization)
@eval begin
function LinearSolve.solve!(

Check warning on line 40 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L40

Added line #L40 was not covered by tests
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}, B},
alg::$ALG,
kwargs...
) where {T, V, P, B}
@info "using solve! df/dA"
dAs = begin
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]

Check warning on line 48 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L45-L48

Added lines #L45 - L48 were not covered by tests
end
dbs = [zero(cache.b) for _=1:P]
A = ForwardDiff.value.(cache.A)
b = cache.b
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

Check warning on line 53 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L50-L53

Added lines #L50 - L53 were not covered by tests
end
function LinearSolve.solve!(

Check warning on line 55 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L55

Added line #L55 was not covered by tests
cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}},
alg::$ALG;
kwargs...
) where {T, V, P, A_}
@info "using solve! df/db"
dAs = [zero(cache.A) for _=1:P]
dbs = begin
t = collect.(ForwardDiff.partials.(cache.b))
[getindex.(t, i) for i in 1:P]

Check warning on line 64 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L60-L64

Added lines #L60 - L64 were not covered by tests
end
A = cache.A
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

Check warning on line 68 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L66-L68

Added lines #L66 - L68 were not covered by tests
end
function LinearSolve.solve!(

Check warning on line 70 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L70

Added line #L70 was not covered by tests
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}},
alg::$ALG;
kwargs...
) where {T, V, P}
@info "using solve! df/dAb"
dAs = begin
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]

Check warning on line 78 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L75-L78

Added lines #L75 - L78 were not covered by tests
end
dbs = begin
t = collect.(ForwardDiff.partials.(cache.b))
[getindex.(t, i) for i in 1:P]

Check warning on line 82 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L80-L82

Added lines #L80 - L82 were not covered by tests
end
A = ForwardDiff.value.(cache.A)
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

Check warning on line 86 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L84-L86

Added lines #L84 - L86 were not covered by tests
end
end
A = ForwardDiff.value.(cache.A)
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
end

end # module
2 changes: 1 addition & 1 deletion test/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ b1 = rand(n);
for alg in (
LUFactorization(),
RFLUFactorization(),
KrylovJL_GMRES(),
# KrylovJL_GMRES(),
)
alg_str = string(alg)
@show alg_str
Expand Down

0 comments on commit 1fd59f2

Please sign in to comment.