Skip to content

Commit

Permalink
Merge pull request #484 from mohamed82008/mt/lazy_rrule
Browse files Browse the repository at this point in the history
Make the rrule's outer product lazy
  • Loading branch information
ChrisRackauckas authored Mar 23, 2024
2 parents c08f2e9 + 8d0fd26 commit 96cefaf
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 3 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.27.0"
version = "3.0.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand All @@ -14,6 +14,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Expand Down Expand Up @@ -85,6 +86,7 @@ KLU = "0.6"
KernelAbstractions = "0.9.16"
Krylov = "0.9"
KrylovKit = "0.6"
LazyArrays = "1"
Libdl = "1.10"
LinearAlgebra = "1.10"
MPI = "0.20"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"

[compat]
Documenter = "1"
LinearSolve = "1, 2"
LinearSolve = "1, 2, 3"
1 change: 1 addition & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ PrecompileTools.@recompile_invalidations begin
using LinearAlgebra
using SparseArrays
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
using LazyArrays: @~, BroadcastArray
using SciMLBase: AbstractLinearAlgorithm
using SciMLOperators
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
Expand Down
3 changes: 2 additions & 1 deletion src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
end

∂A = -λ * transpose(sol.u)
tu = transpose(sol.u)
∂A = BroadcastArray(@~ .-.* tu))
∂b = λ
∂prob = LinearProblem(∂A, ∂b, ∂∅)

Expand Down
5 changes: 5 additions & 0 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Zygote, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using LazyArrays: BroadcastArray

n = 4
A = rand(n, n);
Expand All @@ -18,6 +19,7 @@ end
f(A, b1) # Uses BLAS

dA, db1 = Zygote.gradient(f, A, b1)
@test dA isa BroadcastArray

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
Expand All @@ -34,6 +36,7 @@ _ff = (x, y) -> f(x,
_ff(copy(A), copy(b1))

dA, db1 = Zygote.gradient(_ff, copy(A), copy(b1))
@test dA isa BroadcastArray

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
Expand All @@ -50,6 +53,7 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES())
end

dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)
@test dA isa BroadcastArray

dA2 = FiniteDiff.finite_difference_gradient(
x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
Expand All @@ -71,6 +75,7 @@ function f4(A, b1, b2; alg = LUFactorization())
end

dA, db1, db2 = Zygote.gradient(f4, A, b1, b1)
@test dA isa BroadcastArray

dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
Expand Down

0 comments on commit 96cefaf

Please sign in to comment.