Skip to content

Commit

Permalink
add sparse rrule (#579)
Browse files Browse the repository at this point in the history
* add sparse(I, J, V, m, n, +) rrule

* cleanup

* fix test

* sparse(A) and sparse(v)

* SparseMatrixCSC and SparseVector

* cleanup

Co-authored-by: Michael Abbott <[email protected]>
  • Loading branch information
CarloLucibello and mcabbott authored Jan 31, 2022
1 parent 970fce4 commit 8108a77
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Expand Down
3 changes: 3 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using LinearAlgebra
using LinearAlgebra.BLAS
using Random
using RealDot: realdot
using SparseArrays
using Statistics

# Basically everything this package does is overloading these, so we make an exception
Expand Down Expand Up @@ -43,6 +44,8 @@ include("rulesets/LinearAlgebra/symmetric.jl")
include("rulesets/LinearAlgebra/factorization.jl")
include("rulesets/LinearAlgebra/uniformscaling.jl")

include("rulesets/SparseArrays/sparsematrix.jl")

include("rulesets/Random/random.jl")

end # module
25 changes: 25 additions & 0 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
function rrule(::typeof(sparse), I::AbstractVector, J::AbstractVector, V::AbstractVector, m, n, combine::typeof(+))
project_V = ProjectTo(V)

function sparse_pullback(Ω̄)
ΔΩ = unthunk(Ω̄)
ΔV = project_V(ΔΩ[I .+ m .* (J .- 1)])
return NoTangent(), NoTangent(), NoTangent(), ΔV, NoTangent(), NoTangent(), NoTangent()
end

return sparse(I, J, V, m, n, combine), sparse_pullback
end

function rrule(::Type{T}, A::AbstractMatrix) where T <: SparseMatrixCSC
function sparse_pullback(Ω̄)
return NoTangent(), Ω̄
end
return T(A), sparse_pullback
end

function rrule(::Type{T}, v::AbstractVector) where T <: SparseVector
function sparse_pullback(Ω̄)
return NoTangent(), Ω̄
end
return T(v), sparse_pullback
end
19 changes: 19 additions & 0 deletions test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

@testset "sparse(I, J, V, m, n, +)" begin
m, n = 3, 5
s, t, w = [1,2], [2,3], [0.5,0.5]

test_rrule(sparse, s, t, w, m, n, +)
end

@testset "SparseMatrixCSC(A)" begin
A = rand(5, 3)
test_rrule(SparseMatrixCSC, A)
test_rrule(SparseMatrixCSC{Float32,Int}, A, rtol=1e-5)
end

@testset "SparseVector(v)" begin
v = rand(5)
test_rrule(SparseVector, v)
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-5)
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using LinearAlgebra
using LinearAlgebra.BLAS
using LinearAlgebra: dot
using Random
using SparseArrays
using StaticArrays
using Statistics
using Test
Expand Down Expand Up @@ -75,6 +76,10 @@ end

println()

include_test("rulesets/SparseArrays/sparsematrix.jl")

println()

include_test("rulesets/Random/random.jl")
println()
end

2 comments on commit 8108a77

@mcabbott
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/53561

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.26.0 -m "<description of version>" 8108a77a96af5d4b0c460aac393e44f8943f3c5e
git push origin v1.26.0

Please sign in to comment.