diff --git a/Project.toml b/Project.toml index 0b05d02a..5bc3e979 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,10 @@ version = "0.4.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -15,8 +18,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -SSGraphBLAS_jll = "5.1.2" +CEnum = "0.4" +ContextVariablesX = "0.1" +MacroTools = "0.5" +SSGraphBLAS_jll = "5.1" julia = "1.6" -CEnum = "0.4.1" -ContextVariablesX = "0.1.1" -MacroTools = "0.5.6" +ChainRulesCore = "0.10" +ChainRulesTestUtils = "0.7" +FiniteDifferences = "0.12" diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index f8727ac8..c3800558 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -87,9 +87,14 @@ include("operations/kronecker.jl") include("print.jl") include("import.jl") include("export.jl") - -#EXPERIMENTAL include("options.jl") +#EXPERIMENTAL +include("chainrules/chainruleutils.jl") +include("chainrules/mulrules.jl") +include("chainrules/ewiserules.jl") +include("chainrules/maprules.jl") +include("chainrules/reducerules.jl") +include("chainrules/selectrules.jl") #include("random.jl") include("misc.jl") export libgb diff --git a/src/chainrules/chainruleutils.jl b/src/chainrules/chainruleutils.jl new file mode 100644 index 00000000..e09af671 --- /dev/null +++ b/src/chainrules/chainruleutils.jl @@ -0,0 +1,46 @@ +import FiniteDifferences +import LinearAlgebra +import ChainRulesCore: frule, rrule +using ChainRulesCore +const RealOrComplex = Union{Real, Complex} + +#Required for ChainRulesTestUtils +function FiniteDifferences.to_vec(M::GBMatrix) + I, J, X = findnz(M) + function backtomat(xvec) + return GBMatrix(I, J, xvec; nrows = size(M, 1), ncols = size(M, 2)) + end + return X, backtomat +end + +function FiniteDifferences.to_vec(v::GBVector) + i, x = findnz(v) + function backtovec(xvec) + return GBVector(i, xvec; nrows=size(v, 1)) + end + return x, backtovec +end + +function FiniteDifferences.rand_tangent( + rng::AbstractRNG, + x::GBMatrix{T} +) where {T <: Union{AbstractFloat, Complex}} + n = nnz(x) + v = rand(rng, -9:0.01:9, n) + I, J, _ = findnz(x) + return GBMatrix(I, J, v; nrows = size(x, 1), ncols = size(x, 2)) +end + +function FiniteDifferences.rand_tangent( + rng::AbstractRNG, + x::GBVector{T} +) where {T <: Union{AbstractFloat, Complex}} + n = nnz(x) + v = rand(rng, -9:0.01:9, n) + I, _ = findnz(x) + return GBVector(I, v; nrows = size(x, 1)) +end + +FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent() +# LinearAlgebra.norm freaks over the nothings. +LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p) diff --git a/src/chainrules/ewiserules.jl b/src/chainrules/ewiserules.jl new file mode 100644 index 00000000..59877379 --- /dev/null +++ b/src/chainrules/ewiserules.jl @@ -0,0 +1,71 @@ +#emul TIMES +function frule( + (_, ΔA, ΔB, _), + ::typeof(emul), + A::GBArray, + B::GBArray, + ::typeof(BinaryOps.TIMES) +) + Ω = emul(A, B, BinaryOps.TIMES) + ∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES) + return Ω, ∂Ω +end +function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray) + return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES) +end + +function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES)) + function timespullback(ΔΩ) + ∂A = emul(ΔΩ, B) + ∂B = emul(ΔΩ, A) + return NoTangent(), ∂A, ∂B, NoTangent() + end + return emul(A, B, BinaryOps.TIMES), timespullback +end + +function rrule(::typeof(emul), A::GBArray, B::GBArray) + Ω, fullpb = rrule(emul, A, B, BinaryOps.TIMES) + emulpb(ΔΩ) = fullpb(ΔΩ)[1:3] + return Ω, emulpb +end + +############ +# eadd rules +############ + +# PLUS +###### + +function frule( + (_, ΔA, ΔB, _), + ::typeof(eadd), + A::GBArray, + B::GBArray, + ::typeof(BinaryOps.PLUS) +) + Ω = eadd(A, B, BinaryOps.PLUS) + ∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS) + return Ω, ∂Ω +end +function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray) + return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS) +end + +function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)) + function pluspullback(ΔΩ) + return ( + NoTangent(), + mask(ΔΩ, A; structural = true), + mask(ΔΩ, B; structural = true), + NoTangent() + ) + end + return eadd(A, B, BinaryOps.PLUS), pluspullback +end + +# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule. +function rrule(::typeof(eadd), A::GBArray, B::GBArray) + Ω, fullpb = rrule(eadd, A, B, BinaryOps.PLUS) + eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3] + return Ω, eaddpb +end diff --git a/src/chainrules/maprules.jl b/src/chainrules/maprules.jl new file mode 100644 index 00000000..d4c30e3e --- /dev/null +++ b/src/chainrules/maprules.jl @@ -0,0 +1,17 @@ +# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather +# than AbstractOp. +#function rrule(map, f, xs) +# # Rather than 3 maps really want 1 multimap +# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x) +# ys = map(first, ys_and_pullbacks) +# pullbacks = map(last, ys_and_pullbacks) +# function map_pullback(dys) +# _call(f, x) = f(x) +# dfs_and_dxs = map(_call, pullbacks, dys) +# # but in your case you know it will be NoTangent() so can skip +# df = sum(first, dfs_and_dxs) +# dxs = map(last, dfs_and_dxs) +# return NoTangent(), df, dxs +# end +# return ys, map_pullback +#end diff --git a/src/chainrules/mulrules.jl b/src/chainrules/mulrules.jl new file mode 100644 index 00000000..3344f05d --- /dev/null +++ b/src/chainrules/mulrules.jl @@ -0,0 +1,51 @@ +# Standard arithmetic mul: +function frule( + (_, ΔA, ΔB), + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose +) + frule((nothing, ΔA, ΔB, nothing), mul, A, B, Semirings.PLUS_TIMES) +end +function frule( + (_, ΔA, ΔB, _), + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose, + ::typeof(Semirings.PLUS_TIMES) +) + Ω = mul(A, B) + ∂Ω = mul(ΔA, B) + mul(A, ΔB) + return Ω, ∂Ω +end +# Tests will not pass for this. For two reasons. +# First is #25, the output inference is not type stable. +# That's it's own issue. + +# Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings. +# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof. + +function rrule( + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose, + ::typeof(Semirings.PLUS_TIMES) +) + function mulpullback(ΔΩ) + ∂A = mul(ΔΩ, B'; mask=A) + ∂B = mul(A', ΔΩ; mask=B) + return NoTangent(), ∂A, ∂B, NoTangent() + end + return mul(A, B), mulpullback +end + + +function rrule( + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose +) + Ω, mulpullback = rrule(mul, A, B, Semirings.PLUS_TIMES) + pullback(ΔΩ) = mulpullback(ΔΩ)[1:3] +return Ω, pullback +end diff --git a/src/chainrules/reducerules.jl b/src/chainrules/reducerules.jl new file mode 100644 index 00000000..e69de29b diff --git a/src/chainrules/selectrules.jl b/src/chainrules/selectrules.jl new file mode 100644 index 00000000..e69de29b diff --git a/src/lib/LibGraphBLAS.jl b/src/lib/LibGraphBLAS.jl index 29615a77..0ccf0d22 100644 --- a/src/lib/LibGraphBLAS.jl +++ b/src/lib/LibGraphBLAS.jl @@ -27,27 +27,27 @@ macro wraperror(code) elseif info == GrB_NO_VALUE return nothing else - if info == GrB_UNINITIALIZED_OBJECT + if info == GrB_UNINITIALIZED_OBJECT throw(UninitializedObjectError) - elseif info == GrB_INVALID_OBJECT + elseif info == GrB_INVALID_OBJECT throw(InvalidObjectError) - elseif info == GrB_NULL_POINTER + elseif info == GrB_NULL_POINTER throw(NullPointerError) - elseif info == GrB_INVALID_VALUE + elseif info == GrB_INVALID_VALUE throw(InvalidValueError) - elseif info == GrB_INVALID_INDEX + elseif info == GrB_INVALID_INDEX throw(InvalidIndexError) - elseif info == GrB_DOMAIN_MISMATCH + elseif info == GrB_DOMAIN_MISMATCH throw(DomainError(nothing, "GraphBLAS Domain Mismatch")) elseif info == GrB_DIMENSION_MISMATCH throw(DimensionMismatch()) - elseif info == GrB_OUTPUT_NOT_EMPTY + elseif info == GrB_OUTPUT_NOT_EMPTY throw(OutputNotEmptyError) - elseif info == GrB_OUT_OF_MEMORY + elseif info == GrB_OUT_OF_MEMORY throw(OutOfMemoryError()) - elseif info == GrB_INSUFFICIENT_SPACE + elseif info == GrB_INSUFFICIENT_SPACE throw(InsufficientSpaceError) - elseif info == GrB_INDEX_OUT_OF_BOUNDS + elseif info == GrB_INDEX_OUT_OF_BOUNDS throw(BoundsError()) elseif info == GrB_PANIC throw(PANIC) @@ -843,7 +843,7 @@ for T ∈ valid_vec nvals = GrB_Vector_nvals(v) I = Vector{GrB_Index}(undef, nvals) X = Vector{$type}(undef, nvals) - nvals = Ref{GrB_Index}() + nvals = Ref{GrB_Index}(nvals) $func(I, X, nvals, v) nvals[] == length(I) == length(X) || throw(DimensionMismatch()) return I .+ 1, X diff --git a/src/matrix.jl b/src/matrix.jl index 490a67cd..4a809e3e 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -18,7 +18,7 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = X[k]. The dup funct to `|` for booleans and `+` for nonbooleans. """ function GBMatrix( - I::Vector, J::Vector, X::Vector{T}; + I::AbstractVector, J::AbstractVector, X::AbstractVector{T}; dup = BinaryOps.PLUS, nrows = maximum(I), ncols = maximum(J) ) where {T} A = GBMatrix{T}(nrows, ncols) @@ -33,14 +33,14 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = x. The resulting matrix is "iso-valued" such that it only stores `x` once rather than once for each index. """ -function GBMatrix(I::Vector, J::Vector, x::T; +function GBMatrix(I::AbstractVector, J::AbstractVector, x::T; nrows = maximum(I), ncols = maximum(J)) where {T} A = GBMatrix{T}(nrows, ncols) build(A, I, J, x) return A end -function build(A::GBMatrix{T}, I::Vector, J::Vector, x::T) where {T} +function build(A::GBMatrix{T}, I::AbstractVector, J::AbstractVector, x::T) where {T} nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build matrix with existing elements")) length(I) == length(J) || DimensionMismatch("I, J and X must have the same length") x = GBScalar(x) @@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix) gxbprint(io, A) end -SparseArrays.nonzeros(A::GBArray) = findnz(A)[3] +SparseArrays.nonzeros(A::GBArray) = findnz(A)[end] + # Indexing functions #################### diff --git a/src/operations/ewise.jl b/src/operations/ewise.jl index 9a8399b0..b30b2807 100644 --- a/src/operations/ewise.jl +++ b/src/operations/ewise.jl @@ -61,7 +61,6 @@ function emul!( desc = nothing ) op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES) - size(w) == size(u) == size(v) || throw(DimensionMismatch()) op = getoperator(op, optype(u, v)) accum = getoperator(accum, eltype(w)) @@ -275,6 +274,13 @@ function eadd( return eadd!(C, A, B, op; mask, accum, desc) end +function Base.:+(A::GBArray, B::GBArray) + eadd(A, B, nothing) +end + +function Base.:-(A::GBArray, B::GBArray) + eadd(A, B, BinaryOps.MINUS) +end #Elementwise Broadcasts ####################### diff --git a/src/operations/mul.jl b/src/operations/mul.jl index 319bd59b..cd258cd9 100644 --- a/src/operations/mul.jl +++ b/src/operations/mul.jl @@ -59,7 +59,6 @@ function LinearAlgebra.mul!( return w end - """ mul(A::GBArray, B::GBArray; kwargs...)::GBArray diff --git a/src/operations/transpose.jl b/src/operations/transpose.jl index eaa66713..f376c995 100644 --- a/src/operations/transpose.jl +++ b/src/operations/transpose.jl @@ -64,6 +64,29 @@ function Base.copy!( return gbtranspose!(C, A.parent; mask, accum, desc) end +""" + mask!(C::GBArray, A::GBArray, mask::GBArray) + +Apply a mask to matrix `A`, storing the results in C. + +""" +function mask!(C::GBArray, A::GBArray, mask::GBArray; structural = false, complement = false) + desc = Descriptors.T0 + structural && (desc = desc + Descriptors.S) + complement && (desc = desc + Descriptors.C) + gbtranspose!(C, A; mask, desc) + return C +end + +""" + mask(A::GBArray, mask::GBArray) + +Apply a mask to matrix `A`. +""" +function mask(A::GBArray, mask::GBArray; structural = false, complement = false) + return mask!(similar(A), A, mask; structural, complement) +end + function Base.copy( A::LinearAlgebra.Transpose{<:Any, <:GBMatrix}; mask = C_NULL, accum = C_NULL, desc::Descriptor = Descriptors.NULL @@ -71,6 +94,17 @@ function Base.copy( return gbtranspose(A.parent; mask, accum, desc) end +function Base.copy(v::LinearAlgebra.Transpose{<:Any, <:GBVector}) + A = GBMatrix{eltype(v)}(size(v, 1), size(v, 2)) + nz = findnz(v.parent) + for i ∈ 1:length(nz[1]) + println(i) + println(nz[1][i], ": ", nz[2][i]) + A[1, nz[1][i]] = nz[2][i] + end + return A +end + function _handletranspose( A::GBArray, desc::Union{Descriptor, Nothing} = nothing, diff --git a/src/vector.jl b/src/vector.jl index b1210cba..46552e01 100644 --- a/src/vector.jl +++ b/src/vector.jl @@ -14,8 +14,8 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...) Create a GBVector from a vector of indices `I` and a vector of values `X`. """ -function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS) where {T} - x = GBVector{T}(maximum(I)) +function GBVector(I::AbstractVector, X::AbstractVector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T} + x = GBVector{T}(nrows) build(x, I, X, dup = dup) return x end @@ -27,14 +27,14 @@ Create an nrows length GBVector v such that M[I[k]] = x. The resulting vector is "iso-valued" such that it only stores `x` once rather than once for each index. """ -function GBVector(I::Vector, x::T; +function GBVector(I::AbstractVector, x::T; nrows = maximum(I)) where {T} A = GBVector{T}(nrows) build(A, I, x) return A end -function build(A::GBVector{T}, I::Vector, x::T) where {T} +function build(A::GBVector{T}, I::AbstractVector, x::T) where {T} nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build vector with existing elements")) x = GBScalar(x) diff --git a/test/chainrules/chainrulesutils.jl b/test/chainrules/chainrulesutils.jl new file mode 100644 index 00000000..e6f9a8bd --- /dev/null +++ b/test/chainrules/chainrulesutils.jl @@ -0,0 +1,17 @@ +using FiniteDifferences +function test_to_vec(x::T; check_inferred=true) where {T} + check_inferred && @inferred FiniteDifferences.to_vec(x) + x_vec, back = FiniteDifferences.to_vec(x) + @test x_vec isa Vector + @test all(s -> s isa Real, x_vec) + check_inferred && @inferred back(x_vec) + @test x == back(x_vec) + return nothing +end + +@testset "chainrulesutils" begin + y = GBMatrix(sprand(10, 10, 0.5)) + test_to_vec(y) + v = GBVector(sprand(10, 0.5)) + test_to_vec(v) +end diff --git a/test/chainrules/ewiserules.jl b/test/chainrules/ewiserules.jl new file mode 100644 index 00000000..f9d60905 --- /dev/null +++ b/test/chainrules/ewiserules.jl @@ -0,0 +1,32 @@ +@testset "Elementwise" begin + @testset "Dense" begin + @testset "Arithmetic Semiring" begin + #dense first + Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) + X = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end + end + + @testset "Sparse" begin + @testset "Arithmetic Semiring" begin + Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector) + X = GBMatrix(sprand(10, 0.5)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end + end +end diff --git a/test/chainrules/mulrules.jl b/test/chainrules/mulrules.jl new file mode 100644 index 00000000..ebd7e165 --- /dev/null +++ b/test/chainrules/mulrules.jl @@ -0,0 +1,21 @@ +@testset "mul" begin + @testset "Dense" begin + @testset "Arithmetic Semiring" begin + M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10)) + Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + end + end + + @testset "Sparse" begin + M = GBMatrix(sprand(100, 10, 0.25)) + Y = GBMatrix(sprand(10, 0.1)) #using matrix for now until I work out transpose(v::GBVector) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 0d7f65b0..6b6d2cb4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using SuiteSparseGraphBLAS using SparseArrays using Test using Random - +using ChainRulesTestUtils Random.seed!(1) function include_test(path) @@ -12,6 +12,10 @@ end println("Testing SuiteSparseGraphBLAS.jl") @testset "SuiteSparseGraphBLAS" begin + include_test("gbarray.jl") include_test("operations.jl") + include_test("chainrules/chainrulesutils.jl") + include_test("chainrules/mulrules.jl") + include_test("chainrules/mulrules.jl") end