diff --git a/Project.toml b/Project.toml index 538cd9594..b6f09dac0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "2.27.0" +version = "3.0.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -14,6 +14,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" KLU = "ef3ab10e-7fda-4108-b977-705223b18434" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" +LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" @@ -85,6 +86,7 @@ KLU = "0.6" KernelAbstractions = "0.9.16" Krylov = "0.9" KrylovKit = "0.6" +LazyArrays = "1" Libdl = "1.10" LinearAlgebra = "1.10" MPI = "0.20" diff --git a/docs/Project.toml b/docs/Project.toml index 05f63912d..5934dd13c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,4 +4,4 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" [compat] Documenter = "1" -LinearSolve = "1, 2" +LinearSolve = "1, 2, 3" diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index ce27a5abf..e274521c9 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -13,6 +13,7 @@ PrecompileTools.@recompile_invalidations begin using LinearAlgebra using SparseArrays using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr + using LazyArrays: @~, BroadcastArray using SciMLBase: AbstractLinearAlgorithm using SciMLOperators using SciMLOperators: AbstractSciMLOperator, IdentityOperator diff --git a/src/adjoint.jl b/src/adjoint.jl index 550bb2bd6..f5034a736 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -76,7 +76,8 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem, invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u end - ∂A = -λ * transpose(sol.u) + tu = transpose(sol.u) + ∂A = BroadcastArray(@~ .-(λ .* tu)) ∂b = λ ∂prob = LinearProblem(∂A, ∂b, ∂∅) diff --git a/test/adjoint.jl b/test/adjoint.jl index 4478daf98..673f81b62 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -1,6 +1,7 @@ using Zygote, ForwardDiff using LinearSolve, LinearAlgebra, Test using FiniteDiff +using LazyArrays: BroadcastArray n = 4 A = rand(n, n); @@ -18,6 +19,7 @@ end f(A, b1) # Uses BLAS dA, db1 = Zygote.gradient(f, A, b1) +@test dA isa BroadcastArray dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A)) db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1)) @@ -34,6 +36,7 @@ _ff = (x, y) -> f(x, _ff(copy(A), copy(b1)) dA, db1 = Zygote.gradient(_ff, copy(A), copy(b1)) +@test dA isa BroadcastArray dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A)) db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1)) @@ -50,6 +53,7 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES()) end dA, db1, db2 = Zygote.gradient(f3, A, b1, b1) +@test dA isa BroadcastArray dA2 = FiniteDiff.finite_difference_gradient( x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) @@ -71,6 +75,7 @@ function f4(A, b1, b2; alg = LUFactorization()) end dA, db1, db2 = Zygote.gradient(f4, A, b1, b1) +@test dA isa BroadcastArray dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))