-
-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ForwardDiff rules #434
Changes from all commits
c3b72c9
a67d7aa
465e11c
1fd59f2
465bd47
a85634d
834cbdd
1a18393
829a914
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
module LinearSolveForwardDiff | ||
|
||
using LinearSolve | ||
using InteractiveUtils | ||
isdefined(Base, :get_extension) ? | ||
(import ForwardDiff; using ForwardDiff: Dual) : | ||
(import ..ForwardDiff; using ..ForwardDiff: Dual) | ||
|
||
function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) | ||
@assert !(eltype(first(dAs)) isa Dual) | ||
@assert !(eltype(first(dbs)) isa Dual) | ||
@assert !(eltype(A) isa Dual) | ||
@assert !(eltype(b) isa Dual) | ||
reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol | ||
abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol | ||
u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u | ||
cacheval = cache.cacheval isa Tuple ? cache.cacheval[1] : cache.cacheval | ||
cacheval = eltype(cacheval.factors) <: Dual ? begin | ||
LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cacheval.factors), cacheval.ipiv, cacheval.info) | ||
end : cacheval | ||
cacheval = cache.cacheval isa Tuple ? (cacheval, cache.cacheval[2]) : cacheval | ||
|
||
cache2 = remake(cache; A, b, u, reltol, abstol, cacheval) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Being forced to remake cache in order to solve the non-dual version. Is there some other way we can replace Dual Array with a regular array? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you want to hook into There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or rather, it's just one |
||
res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy | ||
dresus = reduce(hcat, map(dAs, dbs) do dA, db | ||
cache2.b = db - dA * res.u | ||
dres = LinearSolve.solve!(cache2, alg, kwargs...) | ||
deepcopy(dres.u) | ||
end) | ||
Comment on lines
+24
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needing to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if you hook into init and do a single batched solve then this is handled. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any documentation on how to do batched solves? I am unable to find how to do this anywhere. The possi bly closest thing I could find was https://discourse.julialang.org/t/batched-lu-solves-or-factorizations-with-sparse-matrices/106019/2 -- however, couldn't find the right function call. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not entirely sure what you mean in the context of LinearSolve.jl. n = 4
A = rand(n, n)
B = rand(n, n)
A \ B # works
mapreduce(hcat, eachcol(B)) do b
A \ b
end # works
mapreduce(hcat, eachcol(B)) do b
prob = LinearProblem(A, b)
sol = solve(prob)
sol.u
end # works
begin
prob = LinearProblem(A, B)
sol = solve(prob) # errors
sol.u
end
Error: ERROR: MethodError: no method matching ldiv!(::Vector{Float64}, ::LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, ::Matrix{Float64})
Closest candidates are:
ldiv!(::Any, ::Sparspak.SpkSparseSolver.SparseSolver{IT, FT}, ::Any) where {IT, FT}
@ Sparspak ~/.julia/packages/Sparspak/oqBYl/src/SparseCSCInterface/SparseCSCInterface.jl:263
ldiv!(::Any, ::LinearSolve.InvPreconditioner, ::Any)
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/preconditioners.jl:30
ldiv!(::Any, ::LinearSolve.ComposePreconditioner, ::Any)
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/preconditioners.jl:17
...
Stacktrace:
[1] _ldiv!(x::Vector{Float64}, A::LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, b::Matrix{Float64})
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/factorization.jl:11
[2] macro expansion
@ ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:135 [inlined]
[3] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, alg::LUFactorization{LinearAlgebra.RowMaximum}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:127
[4] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, alg::LUFactorization{LinearAlgebra.RowMaximum})
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:127
[5] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:218
[6] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool})
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:217
[7] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Matrix{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::LUFactorization{LinearAlgebra.RowMaximum}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:214
[8] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Matrix{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::LUFactorization{LinearAlgebra.RowMaximum})
@ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:211
[9] top-level scope
@ REPL[24]:3 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @avik-pal I thought you handled something with this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @avik-pal A ping on this. Is there another way to do this if we do not yet have batch dispatch? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not for this case, but a case where |
||
d = Dual{T}.(res.u, Tuple.(eachrow(dresus))) | ||
LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats) | ||
end | ||
|
||
|
||
for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization) | ||
@eval begin | ||
function LinearSolve.solve!( | ||
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}, B}, | ||
alg::$ALG, | ||
kwargs... | ||
) where {T, V, P, B} | ||
# @info "using solve! df/dA" | ||
dAs = begin | ||
t = collect.(ForwardDiff.partials.(cache.A)) | ||
[getindex.(t, i) for i in 1:P] | ||
end | ||
dbs = [zero(cache.b) for _=1:P] | ||
A = ForwardDiff.value.(cache.A) | ||
b = cache.b | ||
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) | ||
end | ||
function LinearSolve.solve!( | ||
cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}}, | ||
alg::$ALG; | ||
kwargs... | ||
) where {T, V, P, A_} | ||
# @info "using solve! df/db" | ||
dAs = [zero(cache.A) for _=1:P] | ||
dbs = begin | ||
t = collect.(ForwardDiff.partials.(cache.b)) | ||
[getindex.(t, i) for i in 1:P] | ||
end | ||
A = cache.A | ||
b = ForwardDiff.value.(cache.b) | ||
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) | ||
end | ||
function LinearSolve.solve!( | ||
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}}, | ||
alg::$ALG; | ||
kwargs... | ||
) where {T, V, P} | ||
# @info "using solve! df/dAb" | ||
dAs = begin | ||
t = collect.(ForwardDiff.partials.(cache.A)) | ||
[getindex.(t, i) for i in 1:P] | ||
end | ||
dbs = begin | ||
t = collect.(ForwardDiff.partials.(cache.b)) | ||
[getindex.(t, i) for i in 1:P] | ||
end | ||
A = ForwardDiff.value.(cache.A) | ||
b = ForwardDiff.value.(cache.b) | ||
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...) | ||
end | ||
end | ||
end | ||
|
||
end # module |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,6 +82,15 @@ | |
assumptions::OperatorAssumptions{issq} | ||
end | ||
|
||
function SciMLBase.remake(cache::LinearCache; | ||
A::TA=cache.A, b::TB=cache.b, u::TU=cache.u, p::TP=cache.p, alg::Talg=cache.alg, | ||
cacheval::Tc=cache.cacheval, isfresh::Bool=cache.isfresh, Pl::Tl=cache.Pl, Pr::Tr=cache.Pr, | ||
abstol::Ttol=cache.abstol, reltol::Ttol=cache.reltol, maxiters::Int=cache.maxiters, | ||
verbose::Bool=cache.verbose, assumptions::OperatorAssumptions{issq}=cache.assumptions) where {TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq} | ||
LinearCache{TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}(A,b,u,p,alg,cacheval,isfresh,Pl,Pr,abstol,reltol, | ||
maxiters,verbose,assumptions) | ||
end | ||
|
||
Comment on lines
+85
to
+93
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to check if there is a way to avoid redefining this by providing a better constructor for |
||
function Base.setproperty!(cache::LinearCache, name::Symbol, x) | ||
if name === :A | ||
setfield!(cache, :isfresh, true) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
using Test | ||
using ForwardDiff | ||
using LinearSolve | ||
using FiniteDiff | ||
using Enzyme | ||
using Random | ||
Random.seed!(1234) | ||
|
||
n = 4 | ||
A = rand(n, n); | ||
dA = zeros(n, n); | ||
b1 = rand(n); | ||
for alg in ( | ||
LUFactorization(), | ||
RFLUFactorization(), | ||
# KrylovJL_GMRES(), dispatch fails | ||
) | ||
alg_str = string(alg) | ||
@show alg_str | ||
function fb(b) | ||
prob = LinearProblem(A, b) | ||
|
||
sol1 = solve(prob, alg) | ||
|
||
sum(sol1.u) | ||
end | ||
fb(b1) | ||
|
||
fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec | ||
@show fid_jac | ||
|
||
fod_jac = ForwardDiff.gradient(fb, b1) |> vec | ||
@show fod_jac | ||
|
||
@test fod_jac ≈ fid_jac rtol=1e-6 | ||
|
||
function fA(A) | ||
prob = LinearProblem(A, b1) | ||
|
||
sol1 = solve(prob, alg) | ||
|
||
sum(sol1.u) | ||
end | ||
fA(A) | ||
|
||
fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec | ||
@show fid_jac | ||
|
||
fod_jac = ForwardDiff.gradient(fA, A) |> vec | ||
@show fod_jac | ||
|
||
@test fod_jac ≈ fid_jac rtol=1e-6 | ||
|
||
|
||
function fAb(Ab) | ||
A = Ab[:, 1:n] | ||
b1 = Ab[:, n+1] | ||
prob = LinearProblem(A, b1) | ||
|
||
sol1 = solve(prob, alg) | ||
|
||
sum(sol1.u) | ||
end | ||
fAb(hcat(A, b1)) | ||
|
||
fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec | ||
@show fid_jac | ||
|
||
fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec | ||
@show fod_jac | ||
|
||
@test fod_jac ≈ fid_jac rtol=1e-6 | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only 1.9+ is supported now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure I understand. What do you mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
basically you dont need to do this anymore, just the first import line works