diff --git a/Project.toml b/Project.toml index 2f70bb7ac..7b8b7204b 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index d7e0082d0..1c94580e7 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -9,6 +9,7 @@ using LinearAlgebra.BLAS using Random using Requires using Statistics +using SparseArrays # Basically everything this package does is overloading these, so we make an exception # to the normal rule of only overload via `ChainRulesCore.rrule`. @@ -37,6 +38,8 @@ include("rulesets/LinearAlgebra/dense.jl") include("rulesets/LinearAlgebra/structured.jl") include("rulesets/LinearAlgebra/factorization.jl") +include("rulesets/SparseArrays/sparsematrix.jl") + include("rulesets/Random/random.jl") # Note: The following is only required because package authors sometimes do not diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 9800363c1..5affa37e1 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -332,3 +332,13 @@ function rrule(::typeof(norm), x::Real, p::Real=2) end return norm(x, p), norm_pullback end + +function rrule(::typeof(diagm), x::AbstractVector) + diagm_pullback(∂x) = (NO_FIELDS, diag(∂x),) + return diagm(x), diagm_pullback +end + +function rrule(::typeof(issymmetric), x) + issymmetric_pullback(∂x) = (NO_FIELDS, ∂x) + return issymmetric(x), issymmetric_pullback +end diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl new file mode 100644 index 000000000..eb002ecfe --- /dev/null +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -0,0 +1,15 @@ +using SparseArrays + +function rrule(::Type{<:SparseMatrixCSC{T,N}}, arr) where {T,N} + function SparseMatrix_pullback(Δ) + return NO_FIELDS, collect(Δ) + end + return SparseMatrixCSC{T,N}(arr), SparseMatrix_pullback +end + +function rrule(::typeof(Matrix), x::SparseMatrixCSC) + function Matrix_pullback(Δ) + NO_FIELDS, Δ + end + return Matrix(x), Matrix_pullback +end diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl new file mode 100644 index 000000000..52161a511 --- /dev/null +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -0,0 +1,8 @@ +using SparseArrays + +@testset "Sparse" begin + r = sparse(rand(3,3)) + x, x̄ = rand(3,3), rand(3,3) + test_rrule(SparseMatrixCSC, r) + test_rrule(Matrix, r) +end diff --git a/test/runtests.jl b/test/runtests.jl index 740d87bca..0bffb7ef9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using LinearAlgebra.BLAS using LinearAlgebra: dot using Random using Statistics +using SparseArrays using Test Random.seed!(1) # Set seed that all testsets should reset to. @@ -48,6 +49,12 @@ println("Testing ChainRules.jl") print(" ") + @testset "SparseArrays" begin + include(joinpath("rulesets", "SparseArrays", "sparsematrix.jl")) + end + + print(" ") + @testset "packages" begin include(joinpath("rulesets", "packages", "NaNMath.jl")) include(joinpath("rulesets", "packages", "SpecialFunctions.jl"))