From 11bdec4669b7b5a0242c63fa090055040c237ca4 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Fri, 2 Jul 2021 23:28:25 -0400 Subject: [PATCH 01/18] arithmetic groundwork --- Project.toml | 7 ++++-- src/SuiteSparseGraphBLAS.jl | 5 ++-- src/chainrules/arithmeticrules.jl | 26 +++++++++++++++++++ src/chainrules/chainruleutils.jl | 42 +++++++++++++++++++++++++++++++ src/operations/ewise.jl | 7 ++++++ 5 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 src/chainrules/arithmeticrules.jl create mode 100644 src/chainrules/chainruleutils.jl diff --git a/Project.toml b/Project.toml index 0b05d02a..7deeab49 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,10 @@ version = "0.4.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -15,8 +18,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -SSGraphBLAS_jll = "5.1.2" -julia = "1.6" CEnum = "0.4.1" ContextVariablesX = "0.1.1" MacroTools = "0.5.6" +SSGraphBLAS_jll = "5.1.2" +julia = "1.6" diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index 72a1c340..b749c32f 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -87,9 +87,10 @@ include("operations/kronecker.jl") include("print.jl") include("import.jl") include("export.jl") - -#EXPERIMENTAL include("options.jl") +#EXPERIMENTAL +include("chainrules/chainruleutils.jl") +include("chainrules/arithmeticrules.jl") #include("random.jl") include("misc.jl") export libgb diff --git a/src/chainrules/arithmeticrules.jl b/src/chainrules/arithmeticrules.jl new file mode 100644 index 00000000..7e547f08 --- /dev/null +++ b/src/chainrules/arithmeticrules.jl @@ -0,0 +1,26 @@ +import ChainRulesCore: frule, rrule +using ChainRulesCore +const RealOrComplex = Union{Real, Complex} +function frule( + (_, ΔA, ΔB), + ::typeof(mul), + A::GBArray, + B::GBArray +) + Ω = mul(A, B) + ∂Ω = mul(ΔA, B) + mul(A, ΔB) + return Ω, ∂Ω +end + +function rrule( + ::typeof(mul), + A::GBVecOrMat{T}, + B::GBVecOrMat{T} +) where {T <: RealOrComplex} + function mulpullback(ΔΩ) + ∂A = mul(ΔΩ, B') + ∂B = mul(A', ΔΩ) + return (NoTangent(), ∂A, ∂B) + end + return mul(A, B), mulpullback +end diff --git a/src/chainrules/chainruleutils.jl b/src/chainrules/chainruleutils.jl new file mode 100644 index 00000000..472fd51a --- /dev/null +++ b/src/chainrules/chainruleutils.jl @@ -0,0 +1,42 @@ +import FiniteDifferences: to_vec, rand_tangent +import LinearAlgebra: norm +#Required for ChainRulesTestUtils +function to_vec(M::GBMatrix) + I, J, X = findnz(M) + function backtomat(xvec) + return GBMatrix(I, J, xvec) + end + return X, backtomat +end + +function to_vec(v::GBVector) + i, x = findnz(v) + function backtovec(xvec) + return GBVector(i, xvec) + end + return x, backtovec +end + +function rand_tangent( + rng::AbstractRNG, + x::GBMatrix{T} +) where {T <: Union{AbstractFloat, Complex}} + n = nnz(x) + v = rand(rng, -9:0.01:9, n) + I, J, _ = findnz(x) + return GBMatrix(I, J, v) +end + +function rand_tangent( + rng::AbstractRNG, + x::GBVector{T} +) where {T <: Union{AbstractFloat, Complex}} + n = nnz(x) + v = rand(rng, -9:0.01:9, n) + I, _ = findnz(x) + return GBVector(I, v) +end + + +# LinearAlgebra.norm freaks over the nothings. +norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p) diff --git a/src/operations/ewise.jl b/src/operations/ewise.jl index 8be5e6aa..6bbc089b 100644 --- a/src/operations/ewise.jl +++ b/src/operations/ewise.jl @@ -291,6 +291,13 @@ function eadd( return eadd!(C, A, B, op; mask, accum, desc) end +function Base.:+(A::GBArray, B::GBArray) + eadd(A, B, nothing) +end + +function Base.:-(A::GBArray, B::GBArray) + eadd(A, B, BinaryOps.MINUS) +end #Elementwise Broadcasts ####################### From 42c8670389a9f8491e43268502186a65d2762256 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Mon, 5 Jul 2021 23:54:24 -0400 Subject: [PATCH 02/18] arithmetic rules for mul and elwise 1st pass --- Project.toml | 1 + src/SuiteSparseGraphBLAS.jl | 6 +++- src/chainrules/arithmeticrules.jl | 26 ---------------- src/chainrules/chainruleutils.jl | 28 +++++++++-------- src/chainrules/ewiserules.jl | 51 +++++++++++++++++++++++++++++++ src/chainrules/maprules.jl | 17 +++++++++++ src/chainrules/mulrules.jl | 30 ++++++++++++++++++ src/chainrules/reducerules.jl | 0 src/chainrules/selectrules.jl | 0 src/matrix.jl | 3 +- src/operations/ewise.jl | 1 - src/operations/mul.jl | 1 - src/vector.jl | 4 +-- 13 files changed, 124 insertions(+), 44 deletions(-) delete mode 100644 src/chainrules/arithmeticrules.jl create mode 100644 src/chainrules/ewiserules.jl create mode 100644 src/chainrules/maprules.jl create mode 100644 src/chainrules/mulrules.jl create mode 100644 src/chainrules/reducerules.jl create mode 100644 src/chainrules/selectrules.jl diff --git a/Project.toml b/Project.toml index 7deeab49..ec8014d1 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.4.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5" diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index b749c32f..9ac93eb5 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -90,7 +90,11 @@ include("export.jl") include("options.jl") #EXPERIMENTAL include("chainrules/chainruleutils.jl") -include("chainrules/arithmeticrules.jl") +include("chainrules/mulrules.jl") +include("chainrules/ewiserules.jl") +include("chainrules/maprules.jl") +include("chainrules/reducerules.jl") +include("chainrules/selectrules.jl") #include("random.jl") include("misc.jl") export libgb diff --git a/src/chainrules/arithmeticrules.jl b/src/chainrules/arithmeticrules.jl deleted file mode 100644 index 7e547f08..00000000 --- a/src/chainrules/arithmeticrules.jl +++ /dev/null @@ -1,26 +0,0 @@ -import ChainRulesCore: frule, rrule -using ChainRulesCore -const RealOrComplex = Union{Real, Complex} -function frule( - (_, ΔA, ΔB), - ::typeof(mul), - A::GBArray, - B::GBArray -) - Ω = mul(A, B) - ∂Ω = mul(ΔA, B) + mul(A, ΔB) - return Ω, ∂Ω -end - -function rrule( - ::typeof(mul), - A::GBVecOrMat{T}, - B::GBVecOrMat{T} -) where {T <: RealOrComplex} - function mulpullback(ΔΩ) - ∂A = mul(ΔΩ, B') - ∂B = mul(A', ΔΩ) - return (NoTangent(), ∂A, ∂B) - end - return mul(A, B), mulpullback -end diff --git a/src/chainrules/chainruleutils.jl b/src/chainrules/chainruleutils.jl index 472fd51a..e09af671 100644 --- a/src/chainrules/chainruleutils.jl +++ b/src/chainrules/chainruleutils.jl @@ -1,42 +1,46 @@ -import FiniteDifferences: to_vec, rand_tangent -import LinearAlgebra: norm +import FiniteDifferences +import LinearAlgebra +import ChainRulesCore: frule, rrule +using ChainRulesCore +const RealOrComplex = Union{Real, Complex} + #Required for ChainRulesTestUtils -function to_vec(M::GBMatrix) +function FiniteDifferences.to_vec(M::GBMatrix) I, J, X = findnz(M) function backtomat(xvec) - return GBMatrix(I, J, xvec) + return GBMatrix(I, J, xvec; nrows = size(M, 1), ncols = size(M, 2)) end return X, backtomat end -function to_vec(v::GBVector) +function FiniteDifferences.to_vec(v::GBVector) i, x = findnz(v) function backtovec(xvec) - return GBVector(i, xvec) + return GBVector(i, xvec; nrows=size(v, 1)) end return x, backtovec end -function rand_tangent( +function FiniteDifferences.rand_tangent( rng::AbstractRNG, x::GBMatrix{T} ) where {T <: Union{AbstractFloat, Complex}} n = nnz(x) v = rand(rng, -9:0.01:9, n) I, J, _ = findnz(x) - return GBMatrix(I, J, v) + return GBMatrix(I, J, v; nrows = size(x, 1), ncols = size(x, 2)) end -function rand_tangent( +function FiniteDifferences.rand_tangent( rng::AbstractRNG, x::GBVector{T} ) where {T <: Union{AbstractFloat, Complex}} n = nnz(x) v = rand(rng, -9:0.01:9, n) I, _ = findnz(x) - return GBVector(I, v) + return GBVector(I, v; nrows = size(x, 1)) end - +FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent() # LinearAlgebra.norm freaks over the nothings. -norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p) +LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p) diff --git a/src/chainrules/ewiserules.jl b/src/chainrules/ewiserules.jl new file mode 100644 index 00000000..470dd894 --- /dev/null +++ b/src/chainrules/ewiserules.jl @@ -0,0 +1,51 @@ +#emul TIMES +function frule( + (_, ΔA, ΔB, _), + ::typeof(emul), + A::GBArray, + B::GBArray, + ::typeof(BinaryOps.TIMES) +) + Ω = emul(A, B, BinaryOps.TIMES) + ∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES) + return Ω, ∂Ω +end +function frule((_, ΔA, ΔB, _), ::typeof(emul), A::GBArray, B::GBArray) + return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES) +end + +function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES)) + function timespullback(ΔΩ) + ∂A = emul(ΔΩ, B) + ∂B = emul(ΔΩ, A) + return NoTangent(), ∂A, ∂B, NoTangent() + end + return emul(A, B, BinaryOps.TIMES), timespullback +end + +function rrule(::typeof(emul), A::GBArray, B::GBArray) + return rrule(emul, A, B, BinaryOps.TIMES) +end + +#eadd PLUS +function frule( + (_, ΔA, ΔB, _), + ::typeof(eadd), + A::GBArray, + B::GBArray, + ::typeof(BinaryOps.PLUS) +) + Ω = eadd(A, B, BinaryOps.PLUS) + ∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS) + return Ω, ∂Ω +end +function frule((_, ΔA, ΔB, _), ::typeof(eadd), A::GBArray, B::GBArray) + return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS) +end + +function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)) + function pluspullback(ΔΩ) + return NoTangent(), ΔΩ, ΔΩ, NoTangent() + end + return eadd(A, B, BinaryOps.PLUS), pluspullback +end diff --git a/src/chainrules/maprules.jl b/src/chainrules/maprules.jl new file mode 100644 index 00000000..d4c30e3e --- /dev/null +++ b/src/chainrules/maprules.jl @@ -0,0 +1,17 @@ +# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather +# than AbstractOp. +#function rrule(map, f, xs) +# # Rather than 3 maps really want 1 multimap +# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x) +# ys = map(first, ys_and_pullbacks) +# pullbacks = map(last, ys_and_pullbacks) +# function map_pullback(dys) +# _call(f, x) = f(x) +# dfs_and_dxs = map(_call, pullbacks, dys) +# # but in your case you know it will be NoTangent() so can skip +# df = sum(first, dfs_and_dxs) +# dxs = map(last, dfs_and_dxs) +# return NoTangent(), df, dxs +# end +# return ys, map_pullback +#end diff --git a/src/chainrules/mulrules.jl b/src/chainrules/mulrules.jl new file mode 100644 index 00000000..4c358949 --- /dev/null +++ b/src/chainrules/mulrules.jl @@ -0,0 +1,30 @@ +# Standard arithmetic mul: +function frule( + (_, ΔA, ΔB), + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose +) + Ω = mul(A, B) + ∂Ω = mul(ΔA, B) + mul(A, ΔB) + return Ω, ∂Ω +end + +# Tests will not pass for this. For two reasons. +# First is #25, the output inference is not type stable. +# That's it's own issue. + +# Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings. +# This means they give different answers. FiniteDifferences is probably "incorrect". +function rrule( + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose +) + function mulpullback(ΔΩ) + ∂A = mul(ΔΩ, B') + ∂B = mul(A', ΔΩ) + return NoTangent(), ∂A, ∂B + end + return mul(A, B), mulpullback +end diff --git a/src/chainrules/reducerules.jl b/src/chainrules/reducerules.jl new file mode 100644 index 00000000..e69de29b diff --git a/src/chainrules/selectrules.jl b/src/chainrules/selectrules.jl new file mode 100644 index 00000000..e69de29b diff --git a/src/matrix.jl b/src/matrix.jl index 490a67cd..648992d0 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix) gxbprint(io, A) end -SparseArrays.nonzeros(A::GBArray) = findnz(A)[3] +SparseArrays.nonzeros(A::GBArray) = findnz(A)[end] + # Indexing functions #################### diff --git a/src/operations/ewise.jl b/src/operations/ewise.jl index 6bbc089b..2af071a9 100644 --- a/src/operations/ewise.jl +++ b/src/operations/ewise.jl @@ -61,7 +61,6 @@ function emul!( desc = nothing ) op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES) - size(w) == size(u) == size(v) || throw(DimensionMismatch()) op = getoperator(op, optype(u, v)) accum = getoperator(accum, eltype(w)) diff --git a/src/operations/mul.jl b/src/operations/mul.jl index 30815f84..04c74ee6 100644 --- a/src/operations/mul.jl +++ b/src/operations/mul.jl @@ -59,7 +59,6 @@ function LinearAlgebra.mul!( return w end - """ mul(A::GBArray, B::GBArray; kwargs...)::GBArray diff --git a/src/vector.jl b/src/vector.jl index b1210cba..e2b9fbcb 100644 --- a/src/vector.jl +++ b/src/vector.jl @@ -14,8 +14,8 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...) Create a GBVector from a vector of indices `I` and a vector of values `X`. """ -function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS) where {T} - x = GBVector{T}(maximum(I)) +function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T} + x = GBVector{T}(nrows) build(x, I, X, dup = dup) return x end From 8f595b9ffa36dca2c643d218b9fc4dfb106b679a Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Tue, 6 Jul 2021 02:00:58 -0400 Subject: [PATCH 03/18] tests and a few fixes --- src/chainrules/ewiserules.jl | 19 ++++++++++++++--- src/chainrules/mulrules.jl | 30 ++++++++++++++++++++++++--- src/lib/LibGraphBLAS.jl | 22 ++++++++++---------- test/runtests.jl | 3 ++- test/testrules.jl | 40 ++++++++++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 18 deletions(-) create mode 100644 test/testrules.jl diff --git a/src/chainrules/ewiserules.jl b/src/chainrules/ewiserules.jl index 470dd894..e52bb4f3 100644 --- a/src/chainrules/ewiserules.jl +++ b/src/chainrules/ewiserules.jl @@ -10,7 +10,7 @@ function frule( ∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES) return Ω, ∂Ω end -function frule((_, ΔA, ΔB, _), ::typeof(emul), A::GBArray, B::GBArray) +function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray) return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES) end @@ -24,7 +24,12 @@ function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES) end function rrule(::typeof(emul), A::GBArray, B::GBArray) - return rrule(emul, A, B, BinaryOps.TIMES) + function timespullback(ΔΩ) + ∂A = emul(ΔΩ, B) + ∂B = emul(ΔΩ, A) + return NoTangent(), ∂A, ∂B + end + return emul(A, B, BinaryOps.TIMES), timespullback end #eadd PLUS @@ -39,7 +44,7 @@ function frule( ∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS) return Ω, ∂Ω end -function frule((_, ΔA, ΔB, _), ::typeof(eadd), A::GBArray, B::GBArray) +function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray) return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS) end @@ -49,3 +54,11 @@ function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)) end return eadd(A, B, BinaryOps.PLUS), pluspullback end + +# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule. +function rrule(::typeof(eadd), A::GBArray, B::GBArray) + function pluspullback(ΔΩ) + return NoTangent(), ΔΩ, ΔΩ + end + return eadd(A, B, BinaryOps.PLUS), pluspullback +end diff --git a/src/chainrules/mulrules.jl b/src/chainrules/mulrules.jl index 4c358949..a7b4bc3e 100644 --- a/src/chainrules/mulrules.jl +++ b/src/chainrules/mulrules.jl @@ -4,18 +4,42 @@ function frule( ::typeof(mul), A::GBMatOrTranspose, B::GBMatOrTranspose +) + frule((nothing, ΔA, ΔB, nothing), mul, A, B, Semirings.PLUS_TIMES) +end +function frule( + (_, ΔA, ΔB, _), + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose, + ::typeof(Semirings.PLUS_TIMES) ) Ω = mul(A, B) ∂Ω = mul(ΔA, B) + mul(A, ΔB) return Ω, ∂Ω end - # Tests will not pass for this. For two reasons. # First is #25, the output inference is not type stable. # That's it's own issue. # Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings. -# This means they give different answers. FiniteDifferences is probably "incorrect". +# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof. + +function rrule( + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose, + ::typeof(Semirings.PLUS_TIMES) +) + function mulpullback(ΔΩ) + ∂A = mul(ΔΩ, B') + ∂B = mul(A', ΔΩ) + return NoTangent(), ∂A, ∂B, NoTangent() + end + return mul(A, B), mulpullback +end + +# Do I have to duplicate this? :/ function rrule( ::typeof(mul), A::GBMatOrTranspose, @@ -26,5 +50,5 @@ function rrule( ∂B = mul(A', ΔΩ) return NoTangent(), ∂A, ∂B end - return mul(A, B), mulpullback +return mul(A, B), mulpullback end diff --git a/src/lib/LibGraphBLAS.jl b/src/lib/LibGraphBLAS.jl index 29615a77..0ccf0d22 100644 --- a/src/lib/LibGraphBLAS.jl +++ b/src/lib/LibGraphBLAS.jl @@ -27,27 +27,27 @@ macro wraperror(code) elseif info == GrB_NO_VALUE return nothing else - if info == GrB_UNINITIALIZED_OBJECT + if info == GrB_UNINITIALIZED_OBJECT throw(UninitializedObjectError) - elseif info == GrB_INVALID_OBJECT + elseif info == GrB_INVALID_OBJECT throw(InvalidObjectError) - elseif info == GrB_NULL_POINTER + elseif info == GrB_NULL_POINTER throw(NullPointerError) - elseif info == GrB_INVALID_VALUE + elseif info == GrB_INVALID_VALUE throw(InvalidValueError) - elseif info == GrB_INVALID_INDEX + elseif info == GrB_INVALID_INDEX throw(InvalidIndexError) - elseif info == GrB_DOMAIN_MISMATCH + elseif info == GrB_DOMAIN_MISMATCH throw(DomainError(nothing, "GraphBLAS Domain Mismatch")) elseif info == GrB_DIMENSION_MISMATCH throw(DimensionMismatch()) - elseif info == GrB_OUTPUT_NOT_EMPTY + elseif info == GrB_OUTPUT_NOT_EMPTY throw(OutputNotEmptyError) - elseif info == GrB_OUT_OF_MEMORY + elseif info == GrB_OUT_OF_MEMORY throw(OutOfMemoryError()) - elseif info == GrB_INSUFFICIENT_SPACE + elseif info == GrB_INSUFFICIENT_SPACE throw(InsufficientSpaceError) - elseif info == GrB_INDEX_OUT_OF_BOUNDS + elseif info == GrB_INDEX_OUT_OF_BOUNDS throw(BoundsError()) elseif info == GrB_PANIC throw(PANIC) @@ -843,7 +843,7 @@ for T ∈ valid_vec nvals = GrB_Vector_nvals(v) I = Vector{GrB_Index}(undef, nvals) X = Vector{$type}(undef, nvals) - nvals = Ref{GrB_Index}() + nvals = Ref{GrB_Index}(nvals) $func(I, X, nvals, v) nvals[] == length(I) == length(X) || throw(DimensionMismatch()) return I .+ 1, X diff --git a/test/runtests.jl b/test/runtests.jl index 0d7f65b0..dae1a95d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using SuiteSparseGraphBLAS using SparseArrays using Test using Random - +using ChainRulesTestUtils Random.seed!(1) function include_test(path) @@ -14,4 +14,5 @@ println("Testing SuiteSparseGraphBLAS.jl") @testset "SuiteSparseGraphBLAS" begin include_test("gbarray.jl") include_test("operations.jl") + include_test("testrules.jl") end diff --git a/test/testrules.jl b/test/testrules.jl new file mode 100644 index 00000000..a638f628 --- /dev/null +++ b/test/testrules.jl @@ -0,0 +1,40 @@ +@testset "Dense" begin + @testset "arithmetic semiring" begin + #dense first + M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10)) + Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + X = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end +end + +@testset "Sparse" begin + @testset "arithmetic semiring" begin + M = GBMatrix(sprand(10, 10, 0.5)) + Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + X = GBMatrix(sprand(10, 0.5)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end +end From c58a4a835919aa8f52da7e1a5c8f20f8fb6829f1 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Fri, 9 Jul 2021 11:28:36 -0400 Subject: [PATCH 04/18] Add mask function, fix eadd(PLUS) --- src/chainrules/ewiserules.jl | 31 +++++++++++++++++++------------ src/operations/transpose.jl | 16 ++++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/chainrules/ewiserules.jl b/src/chainrules/ewiserules.jl index e52bb4f3..59877379 100644 --- a/src/chainrules/ewiserules.jl +++ b/src/chainrules/ewiserules.jl @@ -24,15 +24,18 @@ function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES) end function rrule(::typeof(emul), A::GBArray, B::GBArray) - function timespullback(ΔΩ) - ∂A = emul(ΔΩ, B) - ∂B = emul(ΔΩ, A) - return NoTangent(), ∂A, ∂B - end - return emul(A, B, BinaryOps.TIMES), timespullback + Ω, fullpb = rrule(emul, A, B, BinaryOps.TIMES) + emulpb(ΔΩ) = fullpb(ΔΩ)[1:3] + return Ω, emulpb end -#eadd PLUS +############ +# eadd rules +############ + +# PLUS +###### + function frule( (_, ΔA, ΔB, _), ::typeof(eadd), @@ -50,15 +53,19 @@ end function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)) function pluspullback(ΔΩ) - return NoTangent(), ΔΩ, ΔΩ, NoTangent() + return ( + NoTangent(), + mask(ΔΩ, A; structural = true), + mask(ΔΩ, B; structural = true), + NoTangent() + ) end return eadd(A, B, BinaryOps.PLUS), pluspullback end # Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule. function rrule(::typeof(eadd), A::GBArray, B::GBArray) - function pluspullback(ΔΩ) - return NoTangent(), ΔΩ, ΔΩ - end - return eadd(A, B, BinaryOps.PLUS), pluspullback + Ω, fullpb = rrule(eadd, A, B, BinaryOps.PLUS) + eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3] + return Ω, eaddpb end diff --git a/src/operations/transpose.jl b/src/operations/transpose.jl index eaa66713..ef39449a 100644 --- a/src/operations/transpose.jl +++ b/src/operations/transpose.jl @@ -64,6 +64,22 @@ function Base.copy!( return gbtranspose!(C, A.parent; mask, accum, desc) end +""" + mask!(C::GBArray, A::GBArray, mask::GBArray) + mask(A::GBArray, mask::GBArray) + + +""" +function mask!(C::GBArray, A::GBArray, mask::GBArray; structural = false, complement = false) + desc = Descriptors.T0 + structural && (desc = desc + Descriptors.S) + complement && (desc = desc + Descriptors.C) + gbtranspose!(C, A; mask, desc) + return C +end +function mask(A::GBArray, mask::GBArray; structural = false, complement = false) + return mask!(similar(A), A, mask; structural, complement) +end function Base.copy( A::LinearAlgebra.Transpose{<:Any, <:GBMatrix}; mask = C_NULL, accum = C_NULL, desc::Descriptor = Descriptors.NULL From 980d7d5e49329c7a20da94861e67bbca8b2f0a9a Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Fri, 9 Jul 2021 11:52:06 -0400 Subject: [PATCH 05/18] correct mul rrules --- src/chainrules/mulrules.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/chainrules/mulrules.jl b/src/chainrules/mulrules.jl index a7b4bc3e..3344f05d 100644 --- a/src/chainrules/mulrules.jl +++ b/src/chainrules/mulrules.jl @@ -32,23 +32,20 @@ function rrule( ::typeof(Semirings.PLUS_TIMES) ) function mulpullback(ΔΩ) - ∂A = mul(ΔΩ, B') - ∂B = mul(A', ΔΩ) + ∂A = mul(ΔΩ, B'; mask=A) + ∂B = mul(A', ΔΩ; mask=B) return NoTangent(), ∂A, ∂B, NoTangent() end return mul(A, B), mulpullback end -# Do I have to duplicate this? :/ + function rrule( ::typeof(mul), A::GBMatOrTranspose, B::GBMatOrTranspose ) - function mulpullback(ΔΩ) - ∂A = mul(ΔΩ, B') - ∂B = mul(A', ΔΩ) - return NoTangent(), ∂A, ∂B - end -return mul(A, B), mulpullback + Ω, mulpullback = rrule(mul, A, B, Semirings.PLUS_TIMES) + pullback(ΔΩ) = mulpullback(ΔΩ)[1:3] +return Ω, pullback end From 4b2e00c38d4fd7eff33964090c793df51a93830c Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Fri, 9 Jul 2021 20:27:26 -0400 Subject: [PATCH 06/18] test folder structure --- test/chainrules/ewiserules.jl | 32 ++++++++++++++++++++++++++++ test/chainrules/mulrules.jl | 21 ++++++++++++++++++ test/runtests.jl | 3 ++- test/testrules.jl | 40 ----------------------------------- 4 files changed, 55 insertions(+), 41 deletions(-) create mode 100644 test/chainrules/ewiserules.jl create mode 100644 test/chainrules/mulrules.jl delete mode 100644 test/testrules.jl diff --git a/test/chainrules/ewiserules.jl b/test/chainrules/ewiserules.jl new file mode 100644 index 00000000..f9d60905 --- /dev/null +++ b/test/chainrules/ewiserules.jl @@ -0,0 +1,32 @@ +@testset "Elementwise" begin + @testset "Dense" begin + @testset "Arithmetic Semiring" begin + #dense first + Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) + X = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end + end + + @testset "Sparse" begin + @testset "Arithmetic Semiring" begin + Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector) + X = GBMatrix(sprand(10, 0.5)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end + end +end diff --git a/test/chainrules/mulrules.jl b/test/chainrules/mulrules.jl new file mode 100644 index 00000000..ebd7e165 --- /dev/null +++ b/test/chainrules/mulrules.jl @@ -0,0 +1,21 @@ +@testset "mul" begin + @testset "Dense" begin + @testset "Arithmetic Semiring" begin + M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10)) + Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + end + end + + @testset "Sparse" begin + M = GBMatrix(sprand(100, 10, 0.25)) + Y = GBMatrix(sprand(10, 0.1)) #using matrix for now until I work out transpose(v::GBVector) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index dae1a95d..55726335 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,5 +14,6 @@ println("Testing SuiteSparseGraphBLAS.jl") @testset "SuiteSparseGraphBLAS" begin include_test("gbarray.jl") include_test("operations.jl") - include_test("testrules.jl") + include_test("chainrules/mulrules.jl") + include_test("chainrules/mulrules.jl") end diff --git a/test/testrules.jl b/test/testrules.jl deleted file mode 100644 index a638f628..00000000 --- a/test/testrules.jl +++ /dev/null @@ -1,40 +0,0 @@ -@testset "Dense" begin - @testset "arithmetic semiring" begin - #dense first - M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10)) - Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) - test_frule(mul, M, Y; check_inferred=false) - test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - test_rrule(mul, M, Y; check_inferred=false) - test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - X = GBMatrix(rand(-10.0:0.05:10.0, 10)) - test_frule(eadd, X, Y; check_inferred=false) - test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_rrule(eadd, X, Y; check_inferred=false) - test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_frule(emul, X, Y; check_inferred=false) - test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - test_rrule(emul, X, Y; check_inferred=false) - test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - end -end - -@testset "Sparse" begin - @testset "arithmetic semiring" begin - M = GBMatrix(sprand(10, 10, 0.5)) - Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector) - test_frule(mul, M, Y; check_inferred=false) - test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - test_rrule(mul, M, Y; check_inferred=false) - test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - X = GBMatrix(sprand(10, 0.5)) - test_frule(eadd, X, Y; check_inferred=false) - test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_rrule(eadd, X, Y; check_inferred=false) - test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_frule(emul, X, Y; check_inferred=false) - test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - test_rrule(emul, X, Y; check_inferred=false) - test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - end -end From 952e7a044beeb5be95b4dea495f66e06da88fea9 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sat, 10 Jul 2021 08:30:23 -0400 Subject: [PATCH 07/18] mask and vector transpose v1 --- src/operations/transpose.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/operations/transpose.jl b/src/operations/transpose.jl index eaa66713..041c3fb2 100644 --- a/src/operations/transpose.jl +++ b/src/operations/transpose.jl @@ -64,6 +64,28 @@ function Base.copy!( return gbtranspose!(C, A.parent; mask, accum, desc) end +""" + mask!(C::GBArray, A::GBArray, mask::GBArray) + +Apply a mask to matrix `A`, storing the results in C. +""" +function mask!(C::GBArray, A::GBArray, mask::GBArray; structural = false, complement = false) + desc = Descriptors.T0 + structural && (desc = desc + Descriptors.S) + complement && (desc = desc + Descriptors.C) + gbtranspose!(C, A; mask, desc) + return C +end + +""" + mask(A::GBArray, mask::GBArray) + +Apply a mask to matrix `A`. +""" +function mask(A::GBArray, mask::GBArray; structural = false, complement = false) + return mask!(similar(A), A, mask; structural, complement) +end + function Base.copy( A::LinearAlgebra.Transpose{<:Any, <:GBMatrix}; mask = C_NULL, accum = C_NULL, desc::Descriptor = Descriptors.NULL @@ -71,6 +93,17 @@ function Base.copy( return gbtranspose(A.parent; mask, accum, desc) end +function Base.copy(v::LinearAlgebra.Transpose{<:Any, <:GBVector}) + A = GBMatrix{eltype(v)}(size(v, 1), size(v, 2)) + nz = findnz(v.parent) + for i ∈ 1:length(nz[1]) + println(i) + println(nz[1][i], ": ", nz[2][i]) + A[1, nz[1][i]] = nz[2][i] + end + return A +end + function _handletranspose( A::GBArray, desc::Union{Descriptor, Nothing} = nothing, From b2289bf76ef0435c7dabbfa596ac3c72ed1fec5b Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sat, 10 Jul 2021 11:30:28 -0400 Subject: [PATCH 08/18] Broken constructor rules --- src/SuiteSparseGraphBLAS.jl | 1 + src/chainrules/constructorrules.jl | 18 ++++++++++++++++++ src/matrix.jl | 6 +++--- src/vector.jl | 6 +++--- test/chainrules/chainrulesutils.jl | 17 +++++++++++++++++ test/chainrules/constructorrules.jl | 10 ++++++++++ test/runtests.jl | 3 +++ 7 files changed, 55 insertions(+), 6 deletions(-) create mode 100644 src/chainrules/constructorrules.jl create mode 100644 test/chainrules/chainrulesutils.jl create mode 100644 test/chainrules/constructorrules.jl diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index c3800558..f6c023dc 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -95,6 +95,7 @@ include("chainrules/ewiserules.jl") include("chainrules/maprules.jl") include("chainrules/reducerules.jl") include("chainrules/selectrules.jl") +include("chainrules/constructorrules.jl") #include("random.jl") include("misc.jl") export libgb diff --git a/src/chainrules/constructorrules.jl b/src/chainrules/constructorrules.jl new file mode 100644 index 00000000..7adce77c --- /dev/null +++ b/src/chainrules/constructorrules.jl @@ -0,0 +1,18 @@ +function frule( + (_, _, _, ΔX), + ::Type{GBMatrix}, + I::AbstractVector{<:Integer}, + J::AbstractVector{<:Integer}, + X::AbstractVector +) + return GBMatrix(I, J, X), GBMatrix(I, J, ΔX) +end +function rrule(::Type{GBMatrix}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector) + gbmatrix_pullback(Δgbmatrix) = NoTangent(), NoTangent(), NoTangent(), nonzeros(Δgbmatrix) + return GBMatrix(I, J, X), gbmatrix_pullback +end + +function rrule(::Type{GBVector}, I::AbstractVector{<:Integer}, X::AbstractVector) + gbvector_pullback(Δgbvector) = NoTangent(), NoTangent(), nonzeros(Δgbvector) + return GBVector(I, X), gbvector_pullback +end diff --git a/src/matrix.jl b/src/matrix.jl index 648992d0..4a809e3e 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -18,7 +18,7 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = X[k]. The dup funct to `|` for booleans and `+` for nonbooleans. """ function GBMatrix( - I::Vector, J::Vector, X::Vector{T}; + I::AbstractVector, J::AbstractVector, X::AbstractVector{T}; dup = BinaryOps.PLUS, nrows = maximum(I), ncols = maximum(J) ) where {T} A = GBMatrix{T}(nrows, ncols) @@ -33,14 +33,14 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = x. The resulting matrix is "iso-valued" such that it only stores `x` once rather than once for each index. """ -function GBMatrix(I::Vector, J::Vector, x::T; +function GBMatrix(I::AbstractVector, J::AbstractVector, x::T; nrows = maximum(I), ncols = maximum(J)) where {T} A = GBMatrix{T}(nrows, ncols) build(A, I, J, x) return A end -function build(A::GBMatrix{T}, I::Vector, J::Vector, x::T) where {T} +function build(A::GBMatrix{T}, I::AbstractVector, J::AbstractVector, x::T) where {T} nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build matrix with existing elements")) length(I) == length(J) || DimensionMismatch("I, J and X must have the same length") x = GBScalar(x) diff --git a/src/vector.jl b/src/vector.jl index e2b9fbcb..46552e01 100644 --- a/src/vector.jl +++ b/src/vector.jl @@ -14,7 +14,7 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...) Create a GBVector from a vector of indices `I` and a vector of values `X`. """ -function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T} +function GBVector(I::AbstractVector, X::AbstractVector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T} x = GBVector{T}(nrows) build(x, I, X, dup = dup) return x @@ -27,14 +27,14 @@ Create an nrows length GBVector v such that M[I[k]] = x. The resulting vector is "iso-valued" such that it only stores `x` once rather than once for each index. """ -function GBVector(I::Vector, x::T; +function GBVector(I::AbstractVector, x::T; nrows = maximum(I)) where {T} A = GBVector{T}(nrows) build(A, I, x) return A end -function build(A::GBVector{T}, I::Vector, x::T) where {T} +function build(A::GBVector{T}, I::AbstractVector, x::T) where {T} nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build vector with existing elements")) x = GBScalar(x) diff --git a/test/chainrules/chainrulesutils.jl b/test/chainrules/chainrulesutils.jl new file mode 100644 index 00000000..e6f9a8bd --- /dev/null +++ b/test/chainrules/chainrulesutils.jl @@ -0,0 +1,17 @@ +using FiniteDifferences +function test_to_vec(x::T; check_inferred=true) where {T} + check_inferred && @inferred FiniteDifferences.to_vec(x) + x_vec, back = FiniteDifferences.to_vec(x) + @test x_vec isa Vector + @test all(s -> s isa Real, x_vec) + check_inferred && @inferred back(x_vec) + @test x == back(x_vec) + return nothing +end + +@testset "chainrulesutils" begin + y = GBMatrix(sprand(10, 10, 0.5)) + test_to_vec(y) + v = GBVector(sprand(10, 0.5)) + test_to_vec(v) +end diff --git a/test/chainrules/constructorrules.jl b/test/chainrules/constructorrules.jl new file mode 100644 index 00000000..83c9cd59 --- /dev/null +++ b/test/chainrules/constructorrules.jl @@ -0,0 +1,10 @@ +@testset "Construction Rules" begin + I = [1,1,2,2,3,4,4,5,6,7,7,7] + J = [2,4,5,7,6,1,3,6,3,3,4,5] + X = rand(12) + + test_frule(GBMatrix, I, J, X) + test_rrule(GBMatrix, I, J, X) + test_rrule(GBVector, I, X) + +end diff --git a/test/runtests.jl b/test/runtests.jl index 55726335..1aae6141 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,8 +12,11 @@ end println("Testing SuiteSparseGraphBLAS.jl") @testset "SuiteSparseGraphBLAS" begin + include_test("gbarray.jl") include_test("operations.jl") + include_test("chainrules/chainrulesutils.jl") include_test("chainrules/mulrules.jl") include_test("chainrules/mulrules.jl") + include_test("chainrules/constructorrules.jl") end From b665fa7e8e0202a127969b8300e5a62685c44c0c Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Fri, 2 Jul 2021 23:28:25 -0400 Subject: [PATCH 09/18] arithmetic groundwork --- Project.toml | 7 ++++-- src/SuiteSparseGraphBLAS.jl | 5 ++-- src/chainrules/arithmeticrules.jl | 26 +++++++++++++++++++ src/chainrules/chainruleutils.jl | 42 +++++++++++++++++++++++++++++++ src/operations/ewise.jl | 7 ++++++ 5 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 src/chainrules/arithmeticrules.jl create mode 100644 src/chainrules/chainruleutils.jl diff --git a/Project.toml b/Project.toml index 0b05d02a..7deeab49 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,10 @@ version = "0.4.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -15,8 +18,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -SSGraphBLAS_jll = "5.1.2" -julia = "1.6" CEnum = "0.4.1" ContextVariablesX = "0.1.1" MacroTools = "0.5.6" +SSGraphBLAS_jll = "5.1.2" +julia = "1.6" diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index f8727ac8..83d4e962 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -87,9 +87,10 @@ include("operations/kronecker.jl") include("print.jl") include("import.jl") include("export.jl") - -#EXPERIMENTAL include("options.jl") +#EXPERIMENTAL +include("chainrules/chainruleutils.jl") +include("chainrules/arithmeticrules.jl") #include("random.jl") include("misc.jl") export libgb diff --git a/src/chainrules/arithmeticrules.jl b/src/chainrules/arithmeticrules.jl new file mode 100644 index 00000000..7e547f08 --- /dev/null +++ b/src/chainrules/arithmeticrules.jl @@ -0,0 +1,26 @@ +import ChainRulesCore: frule, rrule +using ChainRulesCore +const RealOrComplex = Union{Real, Complex} +function frule( + (_, ΔA, ΔB), + ::typeof(mul), + A::GBArray, + B::GBArray +) + Ω = mul(A, B) + ∂Ω = mul(ΔA, B) + mul(A, ΔB) + return Ω, ∂Ω +end + +function rrule( + ::typeof(mul), + A::GBVecOrMat{T}, + B::GBVecOrMat{T} +) where {T <: RealOrComplex} + function mulpullback(ΔΩ) + ∂A = mul(ΔΩ, B') + ∂B = mul(A', ΔΩ) + return (NoTangent(), ∂A, ∂B) + end + return mul(A, B), mulpullback +end diff --git a/src/chainrules/chainruleutils.jl b/src/chainrules/chainruleutils.jl new file mode 100644 index 00000000..472fd51a --- /dev/null +++ b/src/chainrules/chainruleutils.jl @@ -0,0 +1,42 @@ +import FiniteDifferences: to_vec, rand_tangent +import LinearAlgebra: norm +#Required for ChainRulesTestUtils +function to_vec(M::GBMatrix) + I, J, X = findnz(M) + function backtomat(xvec) + return GBMatrix(I, J, xvec) + end + return X, backtomat +end + +function to_vec(v::GBVector) + i, x = findnz(v) + function backtovec(xvec) + return GBVector(i, xvec) + end + return x, backtovec +end + +function rand_tangent( + rng::AbstractRNG, + x::GBMatrix{T} +) where {T <: Union{AbstractFloat, Complex}} + n = nnz(x) + v = rand(rng, -9:0.01:9, n) + I, J, _ = findnz(x) + return GBMatrix(I, J, v) +end + +function rand_tangent( + rng::AbstractRNG, + x::GBVector{T} +) where {T <: Union{AbstractFloat, Complex}} + n = nnz(x) + v = rand(rng, -9:0.01:9, n) + I, _ = findnz(x) + return GBVector(I, v) +end + + +# LinearAlgebra.norm freaks over the nothings. +norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p) diff --git a/src/operations/ewise.jl b/src/operations/ewise.jl index 9a8399b0..bfbc49a1 100644 --- a/src/operations/ewise.jl +++ b/src/operations/ewise.jl @@ -275,6 +275,13 @@ function eadd( return eadd!(C, A, B, op; mask, accum, desc) end +function Base.:+(A::GBArray, B::GBArray) + eadd(A, B, nothing) +end + +function Base.:-(A::GBArray, B::GBArray) + eadd(A, B, BinaryOps.MINUS) +end #Elementwise Broadcasts ####################### From 0f4509e501213c685b5fafc1c11e92f616aac065 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Mon, 5 Jul 2021 23:54:24 -0400 Subject: [PATCH 10/18] arithmetic rules for mul and elwise 1st pass --- Project.toml | 1 + src/SuiteSparseGraphBLAS.jl | 6 +++- src/chainrules/arithmeticrules.jl | 26 ---------------- src/chainrules/chainruleutils.jl | 28 +++++++++-------- src/chainrules/ewiserules.jl | 51 +++++++++++++++++++++++++++++++ src/chainrules/maprules.jl | 17 +++++++++++ src/chainrules/mulrules.jl | 30 ++++++++++++++++++ src/chainrules/reducerules.jl | 0 src/chainrules/selectrules.jl | 0 src/matrix.jl | 3 +- src/operations/ewise.jl | 1 - src/operations/mul.jl | 1 - src/vector.jl | 4 +-- 13 files changed, 124 insertions(+), 44 deletions(-) delete mode 100644 src/chainrules/arithmeticrules.jl create mode 100644 src/chainrules/ewiserules.jl create mode 100644 src/chainrules/maprules.jl create mode 100644 src/chainrules/mulrules.jl create mode 100644 src/chainrules/reducerules.jl create mode 100644 src/chainrules/selectrules.jl diff --git a/Project.toml b/Project.toml index 7deeab49..ec8014d1 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.4.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5" diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index 83d4e962..c3800558 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -90,7 +90,11 @@ include("export.jl") include("options.jl") #EXPERIMENTAL include("chainrules/chainruleutils.jl") -include("chainrules/arithmeticrules.jl") +include("chainrules/mulrules.jl") +include("chainrules/ewiserules.jl") +include("chainrules/maprules.jl") +include("chainrules/reducerules.jl") +include("chainrules/selectrules.jl") #include("random.jl") include("misc.jl") export libgb diff --git a/src/chainrules/arithmeticrules.jl b/src/chainrules/arithmeticrules.jl deleted file mode 100644 index 7e547f08..00000000 --- a/src/chainrules/arithmeticrules.jl +++ /dev/null @@ -1,26 +0,0 @@ -import ChainRulesCore: frule, rrule -using ChainRulesCore -const RealOrComplex = Union{Real, Complex} -function frule( - (_, ΔA, ΔB), - ::typeof(mul), - A::GBArray, - B::GBArray -) - Ω = mul(A, B) - ∂Ω = mul(ΔA, B) + mul(A, ΔB) - return Ω, ∂Ω -end - -function rrule( - ::typeof(mul), - A::GBVecOrMat{T}, - B::GBVecOrMat{T} -) where {T <: RealOrComplex} - function mulpullback(ΔΩ) - ∂A = mul(ΔΩ, B') - ∂B = mul(A', ΔΩ) - return (NoTangent(), ∂A, ∂B) - end - return mul(A, B), mulpullback -end diff --git a/src/chainrules/chainruleutils.jl b/src/chainrules/chainruleutils.jl index 472fd51a..e09af671 100644 --- a/src/chainrules/chainruleutils.jl +++ b/src/chainrules/chainruleutils.jl @@ -1,42 +1,46 @@ -import FiniteDifferences: to_vec, rand_tangent -import LinearAlgebra: norm +import FiniteDifferences +import LinearAlgebra +import ChainRulesCore: frule, rrule +using ChainRulesCore +const RealOrComplex = Union{Real, Complex} + #Required for ChainRulesTestUtils -function to_vec(M::GBMatrix) +function FiniteDifferences.to_vec(M::GBMatrix) I, J, X = findnz(M) function backtomat(xvec) - return GBMatrix(I, J, xvec) + return GBMatrix(I, J, xvec; nrows = size(M, 1), ncols = size(M, 2)) end return X, backtomat end -function to_vec(v::GBVector) +function FiniteDifferences.to_vec(v::GBVector) i, x = findnz(v) function backtovec(xvec) - return GBVector(i, xvec) + return GBVector(i, xvec; nrows=size(v, 1)) end return x, backtovec end -function rand_tangent( +function FiniteDifferences.rand_tangent( rng::AbstractRNG, x::GBMatrix{T} ) where {T <: Union{AbstractFloat, Complex}} n = nnz(x) v = rand(rng, -9:0.01:9, n) I, J, _ = findnz(x) - return GBMatrix(I, J, v) + return GBMatrix(I, J, v; nrows = size(x, 1), ncols = size(x, 2)) end -function rand_tangent( +function FiniteDifferences.rand_tangent( rng::AbstractRNG, x::GBVector{T} ) where {T <: Union{AbstractFloat, Complex}} n = nnz(x) v = rand(rng, -9:0.01:9, n) I, _ = findnz(x) - return GBVector(I, v) + return GBVector(I, v; nrows = size(x, 1)) end - +FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent() # LinearAlgebra.norm freaks over the nothings. -norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p) +LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p) diff --git a/src/chainrules/ewiserules.jl b/src/chainrules/ewiserules.jl new file mode 100644 index 00000000..470dd894 --- /dev/null +++ b/src/chainrules/ewiserules.jl @@ -0,0 +1,51 @@ +#emul TIMES +function frule( + (_, ΔA, ΔB, _), + ::typeof(emul), + A::GBArray, + B::GBArray, + ::typeof(BinaryOps.TIMES) +) + Ω = emul(A, B, BinaryOps.TIMES) + ∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES) + return Ω, ∂Ω +end +function frule((_, ΔA, ΔB, _), ::typeof(emul), A::GBArray, B::GBArray) + return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES) +end + +function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES)) + function timespullback(ΔΩ) + ∂A = emul(ΔΩ, B) + ∂B = emul(ΔΩ, A) + return NoTangent(), ∂A, ∂B, NoTangent() + end + return emul(A, B, BinaryOps.TIMES), timespullback +end + +function rrule(::typeof(emul), A::GBArray, B::GBArray) + return rrule(emul, A, B, BinaryOps.TIMES) +end + +#eadd PLUS +function frule( + (_, ΔA, ΔB, _), + ::typeof(eadd), + A::GBArray, + B::GBArray, + ::typeof(BinaryOps.PLUS) +) + Ω = eadd(A, B, BinaryOps.PLUS) + ∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS) + return Ω, ∂Ω +end +function frule((_, ΔA, ΔB, _), ::typeof(eadd), A::GBArray, B::GBArray) + return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS) +end + +function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)) + function pluspullback(ΔΩ) + return NoTangent(), ΔΩ, ΔΩ, NoTangent() + end + return eadd(A, B, BinaryOps.PLUS), pluspullback +end diff --git a/src/chainrules/maprules.jl b/src/chainrules/maprules.jl new file mode 100644 index 00000000..d4c30e3e --- /dev/null +++ b/src/chainrules/maprules.jl @@ -0,0 +1,17 @@ +# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather +# than AbstractOp. +#function rrule(map, f, xs) +# # Rather than 3 maps really want 1 multimap +# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x) +# ys = map(first, ys_and_pullbacks) +# pullbacks = map(last, ys_and_pullbacks) +# function map_pullback(dys) +# _call(f, x) = f(x) +# dfs_and_dxs = map(_call, pullbacks, dys) +# # but in your case you know it will be NoTangent() so can skip +# df = sum(first, dfs_and_dxs) +# dxs = map(last, dfs_and_dxs) +# return NoTangent(), df, dxs +# end +# return ys, map_pullback +#end diff --git a/src/chainrules/mulrules.jl b/src/chainrules/mulrules.jl new file mode 100644 index 00000000..4c358949 --- /dev/null +++ b/src/chainrules/mulrules.jl @@ -0,0 +1,30 @@ +# Standard arithmetic mul: +function frule( + (_, ΔA, ΔB), + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose +) + Ω = mul(A, B) + ∂Ω = mul(ΔA, B) + mul(A, ΔB) + return Ω, ∂Ω +end + +# Tests will not pass for this. For two reasons. +# First is #25, the output inference is not type stable. +# That's it's own issue. + +# Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings. +# This means they give different answers. FiniteDifferences is probably "incorrect". +function rrule( + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose +) + function mulpullback(ΔΩ) + ∂A = mul(ΔΩ, B') + ∂B = mul(A', ΔΩ) + return NoTangent(), ∂A, ∂B + end + return mul(A, B), mulpullback +end diff --git a/src/chainrules/reducerules.jl b/src/chainrules/reducerules.jl new file mode 100644 index 00000000..e69de29b diff --git a/src/chainrules/selectrules.jl b/src/chainrules/selectrules.jl new file mode 100644 index 00000000..e69de29b diff --git a/src/matrix.jl b/src/matrix.jl index 490a67cd..648992d0 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix) gxbprint(io, A) end -SparseArrays.nonzeros(A::GBArray) = findnz(A)[3] +SparseArrays.nonzeros(A::GBArray) = findnz(A)[end] + # Indexing functions #################### diff --git a/src/operations/ewise.jl b/src/operations/ewise.jl index bfbc49a1..b30b2807 100644 --- a/src/operations/ewise.jl +++ b/src/operations/ewise.jl @@ -61,7 +61,6 @@ function emul!( desc = nothing ) op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES) - size(w) == size(u) == size(v) || throw(DimensionMismatch()) op = getoperator(op, optype(u, v)) accum = getoperator(accum, eltype(w)) diff --git a/src/operations/mul.jl b/src/operations/mul.jl index 319bd59b..cd258cd9 100644 --- a/src/operations/mul.jl +++ b/src/operations/mul.jl @@ -59,7 +59,6 @@ function LinearAlgebra.mul!( return w end - """ mul(A::GBArray, B::GBArray; kwargs...)::GBArray diff --git a/src/vector.jl b/src/vector.jl index b1210cba..e2b9fbcb 100644 --- a/src/vector.jl +++ b/src/vector.jl @@ -14,8 +14,8 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...) Create a GBVector from a vector of indices `I` and a vector of values `X`. """ -function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS) where {T} - x = GBVector{T}(maximum(I)) +function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T} + x = GBVector{T}(nrows) build(x, I, X, dup = dup) return x end From b4ec8c540bd6d51a8660c56c8db705d1a3674e80 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Tue, 6 Jul 2021 02:00:58 -0400 Subject: [PATCH 11/18] tests and a few fixes --- src/chainrules/ewiserules.jl | 19 ++++++++++++++--- src/chainrules/mulrules.jl | 30 ++++++++++++++++++++++++--- src/lib/LibGraphBLAS.jl | 22 ++++++++++---------- test/runtests.jl | 3 ++- test/testrules.jl | 40 ++++++++++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 18 deletions(-) create mode 100644 test/testrules.jl diff --git a/src/chainrules/ewiserules.jl b/src/chainrules/ewiserules.jl index 470dd894..e52bb4f3 100644 --- a/src/chainrules/ewiserules.jl +++ b/src/chainrules/ewiserules.jl @@ -10,7 +10,7 @@ function frule( ∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES) return Ω, ∂Ω end -function frule((_, ΔA, ΔB, _), ::typeof(emul), A::GBArray, B::GBArray) +function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray) return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES) end @@ -24,7 +24,12 @@ function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES) end function rrule(::typeof(emul), A::GBArray, B::GBArray) - return rrule(emul, A, B, BinaryOps.TIMES) + function timespullback(ΔΩ) + ∂A = emul(ΔΩ, B) + ∂B = emul(ΔΩ, A) + return NoTangent(), ∂A, ∂B + end + return emul(A, B, BinaryOps.TIMES), timespullback end #eadd PLUS @@ -39,7 +44,7 @@ function frule( ∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS) return Ω, ∂Ω end -function frule((_, ΔA, ΔB, _), ::typeof(eadd), A::GBArray, B::GBArray) +function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray) return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS) end @@ -49,3 +54,11 @@ function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)) end return eadd(A, B, BinaryOps.PLUS), pluspullback end + +# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule. +function rrule(::typeof(eadd), A::GBArray, B::GBArray) + function pluspullback(ΔΩ) + return NoTangent(), ΔΩ, ΔΩ + end + return eadd(A, B, BinaryOps.PLUS), pluspullback +end diff --git a/src/chainrules/mulrules.jl b/src/chainrules/mulrules.jl index 4c358949..a7b4bc3e 100644 --- a/src/chainrules/mulrules.jl +++ b/src/chainrules/mulrules.jl @@ -4,18 +4,42 @@ function frule( ::typeof(mul), A::GBMatOrTranspose, B::GBMatOrTranspose +) + frule((nothing, ΔA, ΔB, nothing), mul, A, B, Semirings.PLUS_TIMES) +end +function frule( + (_, ΔA, ΔB, _), + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose, + ::typeof(Semirings.PLUS_TIMES) ) Ω = mul(A, B) ∂Ω = mul(ΔA, B) + mul(A, ΔB) return Ω, ∂Ω end - # Tests will not pass for this. For two reasons. # First is #25, the output inference is not type stable. # That's it's own issue. # Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings. -# This means they give different answers. FiniteDifferences is probably "incorrect". +# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof. + +function rrule( + ::typeof(mul), + A::GBMatOrTranspose, + B::GBMatOrTranspose, + ::typeof(Semirings.PLUS_TIMES) +) + function mulpullback(ΔΩ) + ∂A = mul(ΔΩ, B') + ∂B = mul(A', ΔΩ) + return NoTangent(), ∂A, ∂B, NoTangent() + end + return mul(A, B), mulpullback +end + +# Do I have to duplicate this? :/ function rrule( ::typeof(mul), A::GBMatOrTranspose, @@ -26,5 +50,5 @@ function rrule( ∂B = mul(A', ΔΩ) return NoTangent(), ∂A, ∂B end - return mul(A, B), mulpullback +return mul(A, B), mulpullback end diff --git a/src/lib/LibGraphBLAS.jl b/src/lib/LibGraphBLAS.jl index 29615a77..0ccf0d22 100644 --- a/src/lib/LibGraphBLAS.jl +++ b/src/lib/LibGraphBLAS.jl @@ -27,27 +27,27 @@ macro wraperror(code) elseif info == GrB_NO_VALUE return nothing else - if info == GrB_UNINITIALIZED_OBJECT + if info == GrB_UNINITIALIZED_OBJECT throw(UninitializedObjectError) - elseif info == GrB_INVALID_OBJECT + elseif info == GrB_INVALID_OBJECT throw(InvalidObjectError) - elseif info == GrB_NULL_POINTER + elseif info == GrB_NULL_POINTER throw(NullPointerError) - elseif info == GrB_INVALID_VALUE + elseif info == GrB_INVALID_VALUE throw(InvalidValueError) - elseif info == GrB_INVALID_INDEX + elseif info == GrB_INVALID_INDEX throw(InvalidIndexError) - elseif info == GrB_DOMAIN_MISMATCH + elseif info == GrB_DOMAIN_MISMATCH throw(DomainError(nothing, "GraphBLAS Domain Mismatch")) elseif info == GrB_DIMENSION_MISMATCH throw(DimensionMismatch()) - elseif info == GrB_OUTPUT_NOT_EMPTY + elseif info == GrB_OUTPUT_NOT_EMPTY throw(OutputNotEmptyError) - elseif info == GrB_OUT_OF_MEMORY + elseif info == GrB_OUT_OF_MEMORY throw(OutOfMemoryError()) - elseif info == GrB_INSUFFICIENT_SPACE + elseif info == GrB_INSUFFICIENT_SPACE throw(InsufficientSpaceError) - elseif info == GrB_INDEX_OUT_OF_BOUNDS + elseif info == GrB_INDEX_OUT_OF_BOUNDS throw(BoundsError()) elseif info == GrB_PANIC throw(PANIC) @@ -843,7 +843,7 @@ for T ∈ valid_vec nvals = GrB_Vector_nvals(v) I = Vector{GrB_Index}(undef, nvals) X = Vector{$type}(undef, nvals) - nvals = Ref{GrB_Index}() + nvals = Ref{GrB_Index}(nvals) $func(I, X, nvals, v) nvals[] == length(I) == length(X) || throw(DimensionMismatch()) return I .+ 1, X diff --git a/test/runtests.jl b/test/runtests.jl index 0d7f65b0..dae1a95d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using SuiteSparseGraphBLAS using SparseArrays using Test using Random - +using ChainRulesTestUtils Random.seed!(1) function include_test(path) @@ -14,4 +14,5 @@ println("Testing SuiteSparseGraphBLAS.jl") @testset "SuiteSparseGraphBLAS" begin include_test("gbarray.jl") include_test("operations.jl") + include_test("testrules.jl") end diff --git a/test/testrules.jl b/test/testrules.jl new file mode 100644 index 00000000..a638f628 --- /dev/null +++ b/test/testrules.jl @@ -0,0 +1,40 @@ +@testset "Dense" begin + @testset "arithmetic semiring" begin + #dense first + M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10)) + Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + X = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end +end + +@testset "Sparse" begin + @testset "arithmetic semiring" begin + M = GBMatrix(sprand(10, 10, 0.5)) + Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + X = GBMatrix(sprand(10, 0.5)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end +end From 7991aa50e6f6cd3ca6c6cecef1997fb76424ffe9 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Fri, 9 Jul 2021 11:28:36 -0400 Subject: [PATCH 12/18] Add mask function, fix eadd(PLUS) --- src/chainrules/ewiserules.jl | 31 +++++++++++++++++++------------ src/operations/transpose.jl | 1 + 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/chainrules/ewiserules.jl b/src/chainrules/ewiserules.jl index e52bb4f3..59877379 100644 --- a/src/chainrules/ewiserules.jl +++ b/src/chainrules/ewiserules.jl @@ -24,15 +24,18 @@ function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES) end function rrule(::typeof(emul), A::GBArray, B::GBArray) - function timespullback(ΔΩ) - ∂A = emul(ΔΩ, B) - ∂B = emul(ΔΩ, A) - return NoTangent(), ∂A, ∂B - end - return emul(A, B, BinaryOps.TIMES), timespullback + Ω, fullpb = rrule(emul, A, B, BinaryOps.TIMES) + emulpb(ΔΩ) = fullpb(ΔΩ)[1:3] + return Ω, emulpb end -#eadd PLUS +############ +# eadd rules +############ + +# PLUS +###### + function frule( (_, ΔA, ΔB, _), ::typeof(eadd), @@ -50,15 +53,19 @@ end function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)) function pluspullback(ΔΩ) - return NoTangent(), ΔΩ, ΔΩ, NoTangent() + return ( + NoTangent(), + mask(ΔΩ, A; structural = true), + mask(ΔΩ, B; structural = true), + NoTangent() + ) end return eadd(A, B, BinaryOps.PLUS), pluspullback end # Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule. function rrule(::typeof(eadd), A::GBArray, B::GBArray) - function pluspullback(ΔΩ) - return NoTangent(), ΔΩ, ΔΩ - end - return eadd(A, B, BinaryOps.PLUS), pluspullback + Ω, fullpb = rrule(eadd, A, B, BinaryOps.PLUS) + eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3] + return Ω, eaddpb end diff --git a/src/operations/transpose.jl b/src/operations/transpose.jl index 041c3fb2..f376c995 100644 --- a/src/operations/transpose.jl +++ b/src/operations/transpose.jl @@ -68,6 +68,7 @@ end mask!(C::GBArray, A::GBArray, mask::GBArray) Apply a mask to matrix `A`, storing the results in C. + """ function mask!(C::GBArray, A::GBArray, mask::GBArray; structural = false, complement = false) desc = Descriptors.T0 From bb4dc6e14e40e059a719e531016a20014234c4ce Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Fri, 9 Jul 2021 11:52:06 -0400 Subject: [PATCH 13/18] correct mul rrules --- src/chainrules/mulrules.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/chainrules/mulrules.jl b/src/chainrules/mulrules.jl index a7b4bc3e..3344f05d 100644 --- a/src/chainrules/mulrules.jl +++ b/src/chainrules/mulrules.jl @@ -32,23 +32,20 @@ function rrule( ::typeof(Semirings.PLUS_TIMES) ) function mulpullback(ΔΩ) - ∂A = mul(ΔΩ, B') - ∂B = mul(A', ΔΩ) + ∂A = mul(ΔΩ, B'; mask=A) + ∂B = mul(A', ΔΩ; mask=B) return NoTangent(), ∂A, ∂B, NoTangent() end return mul(A, B), mulpullback end -# Do I have to duplicate this? :/ + function rrule( ::typeof(mul), A::GBMatOrTranspose, B::GBMatOrTranspose ) - function mulpullback(ΔΩ) - ∂A = mul(ΔΩ, B') - ∂B = mul(A', ΔΩ) - return NoTangent(), ∂A, ∂B - end -return mul(A, B), mulpullback + Ω, mulpullback = rrule(mul, A, B, Semirings.PLUS_TIMES) + pullback(ΔΩ) = mulpullback(ΔΩ)[1:3] +return Ω, pullback end From fd8433bf04c2a574f88010b054f8faf4625ba876 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Fri, 9 Jul 2021 20:27:26 -0400 Subject: [PATCH 14/18] test folder structure --- test/chainrules/ewiserules.jl | 32 ++++++++++++++++++++++++++++ test/chainrules/mulrules.jl | 21 ++++++++++++++++++ test/runtests.jl | 3 ++- test/testrules.jl | 40 ----------------------------------- 4 files changed, 55 insertions(+), 41 deletions(-) create mode 100644 test/chainrules/ewiserules.jl create mode 100644 test/chainrules/mulrules.jl delete mode 100644 test/testrules.jl diff --git a/test/chainrules/ewiserules.jl b/test/chainrules/ewiserules.jl new file mode 100644 index 00000000..f9d60905 --- /dev/null +++ b/test/chainrules/ewiserules.jl @@ -0,0 +1,32 @@ +@testset "Elementwise" begin + @testset "Dense" begin + @testset "Arithmetic Semiring" begin + #dense first + Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) + X = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end + end + + @testset "Sparse" begin + @testset "Arithmetic Semiring" begin + Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector) + X = GBMatrix(sprand(10, 0.5)) + test_frule(eadd, X, Y; check_inferred=false) + test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_rrule(eadd, X, Y; check_inferred=false) + test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) + test_frule(emul, X, Y; check_inferred=false) + test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + test_rrule(emul, X, Y; check_inferred=false) + test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) + end + end +end diff --git a/test/chainrules/mulrules.jl b/test/chainrules/mulrules.jl new file mode 100644 index 00000000..ebd7e165 --- /dev/null +++ b/test/chainrules/mulrules.jl @@ -0,0 +1,21 @@ +@testset "mul" begin + @testset "Dense" begin + @testset "Arithmetic Semiring" begin + M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10)) + Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + end + end + + @testset "Sparse" begin + M = GBMatrix(sprand(100, 10, 0.25)) + Y = GBMatrix(sprand(10, 0.1)) #using matrix for now until I work out transpose(v::GBVector) + test_frule(mul, M, Y; check_inferred=false) + test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + test_rrule(mul, M, Y; check_inferred=false) + test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index dae1a95d..55726335 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,5 +14,6 @@ println("Testing SuiteSparseGraphBLAS.jl") @testset "SuiteSparseGraphBLAS" begin include_test("gbarray.jl") include_test("operations.jl") - include_test("testrules.jl") + include_test("chainrules/mulrules.jl") + include_test("chainrules/mulrules.jl") end diff --git a/test/testrules.jl b/test/testrules.jl deleted file mode 100644 index a638f628..00000000 --- a/test/testrules.jl +++ /dev/null @@ -1,40 +0,0 @@ -@testset "Dense" begin - @testset "arithmetic semiring" begin - #dense first - M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10)) - Y = GBMatrix(rand(-10.0:0.05:10.0, 10)) - test_frule(mul, M, Y; check_inferred=false) - test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - test_rrule(mul, M, Y; check_inferred=false) - test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - X = GBMatrix(rand(-10.0:0.05:10.0, 10)) - test_frule(eadd, X, Y; check_inferred=false) - test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_rrule(eadd, X, Y; check_inferred=false) - test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_frule(emul, X, Y; check_inferred=false) - test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - test_rrule(emul, X, Y; check_inferred=false) - test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - end -end - -@testset "Sparse" begin - @testset "arithmetic semiring" begin - M = GBMatrix(sprand(10, 10, 0.5)) - Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector) - test_frule(mul, M, Y; check_inferred=false) - test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - test_rrule(mul, M, Y; check_inferred=false) - test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false) - X = GBMatrix(sprand(10, 0.5)) - test_frule(eadd, X, Y; check_inferred=false) - test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_rrule(eadd, X, Y; check_inferred=false) - test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false) - test_frule(emul, X, Y; check_inferred=false) - test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - test_rrule(emul, X, Y; check_inferred=false) - test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false) - end -end From 965a98399650c287d98d68ae926516568062c15a Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sat, 10 Jul 2021 11:30:28 -0400 Subject: [PATCH 15/18] Broken constructor rules --- src/SuiteSparseGraphBLAS.jl | 1 + src/chainrules/constructorrules.jl | 18 ++++++++++++++++++ src/matrix.jl | 6 +++--- src/vector.jl | 6 +++--- test/chainrules/chainrulesutils.jl | 17 +++++++++++++++++ test/chainrules/constructorrules.jl | 10 ++++++++++ test/runtests.jl | 3 +++ 7 files changed, 55 insertions(+), 6 deletions(-) create mode 100644 src/chainrules/constructorrules.jl create mode 100644 test/chainrules/chainrulesutils.jl create mode 100644 test/chainrules/constructorrules.jl diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index c3800558..f6c023dc 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -95,6 +95,7 @@ include("chainrules/ewiserules.jl") include("chainrules/maprules.jl") include("chainrules/reducerules.jl") include("chainrules/selectrules.jl") +include("chainrules/constructorrules.jl") #include("random.jl") include("misc.jl") export libgb diff --git a/src/chainrules/constructorrules.jl b/src/chainrules/constructorrules.jl new file mode 100644 index 00000000..7adce77c --- /dev/null +++ b/src/chainrules/constructorrules.jl @@ -0,0 +1,18 @@ +function frule( + (_, _, _, ΔX), + ::Type{GBMatrix}, + I::AbstractVector{<:Integer}, + J::AbstractVector{<:Integer}, + X::AbstractVector +) + return GBMatrix(I, J, X), GBMatrix(I, J, ΔX) +end +function rrule(::Type{GBMatrix}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector) + gbmatrix_pullback(Δgbmatrix) = NoTangent(), NoTangent(), NoTangent(), nonzeros(Δgbmatrix) + return GBMatrix(I, J, X), gbmatrix_pullback +end + +function rrule(::Type{GBVector}, I::AbstractVector{<:Integer}, X::AbstractVector) + gbvector_pullback(Δgbvector) = NoTangent(), NoTangent(), nonzeros(Δgbvector) + return GBVector(I, X), gbvector_pullback +end diff --git a/src/matrix.jl b/src/matrix.jl index 648992d0..4a809e3e 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -18,7 +18,7 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = X[k]. The dup funct to `|` for booleans and `+` for nonbooleans. """ function GBMatrix( - I::Vector, J::Vector, X::Vector{T}; + I::AbstractVector, J::AbstractVector, X::AbstractVector{T}; dup = BinaryOps.PLUS, nrows = maximum(I), ncols = maximum(J) ) where {T} A = GBMatrix{T}(nrows, ncols) @@ -33,14 +33,14 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = x. The resulting matrix is "iso-valued" such that it only stores `x` once rather than once for each index. """ -function GBMatrix(I::Vector, J::Vector, x::T; +function GBMatrix(I::AbstractVector, J::AbstractVector, x::T; nrows = maximum(I), ncols = maximum(J)) where {T} A = GBMatrix{T}(nrows, ncols) build(A, I, J, x) return A end -function build(A::GBMatrix{T}, I::Vector, J::Vector, x::T) where {T} +function build(A::GBMatrix{T}, I::AbstractVector, J::AbstractVector, x::T) where {T} nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build matrix with existing elements")) length(I) == length(J) || DimensionMismatch("I, J and X must have the same length") x = GBScalar(x) diff --git a/src/vector.jl b/src/vector.jl index e2b9fbcb..46552e01 100644 --- a/src/vector.jl +++ b/src/vector.jl @@ -14,7 +14,7 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...) Create a GBVector from a vector of indices `I` and a vector of values `X`. """ -function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T} +function GBVector(I::AbstractVector, X::AbstractVector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T} x = GBVector{T}(nrows) build(x, I, X, dup = dup) return x @@ -27,14 +27,14 @@ Create an nrows length GBVector v such that M[I[k]] = x. The resulting vector is "iso-valued" such that it only stores `x` once rather than once for each index. """ -function GBVector(I::Vector, x::T; +function GBVector(I::AbstractVector, x::T; nrows = maximum(I)) where {T} A = GBVector{T}(nrows) build(A, I, x) return A end -function build(A::GBVector{T}, I::Vector, x::T) where {T} +function build(A::GBVector{T}, I::AbstractVector, x::T) where {T} nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build vector with existing elements")) x = GBScalar(x) diff --git a/test/chainrules/chainrulesutils.jl b/test/chainrules/chainrulesutils.jl new file mode 100644 index 00000000..e6f9a8bd --- /dev/null +++ b/test/chainrules/chainrulesutils.jl @@ -0,0 +1,17 @@ +using FiniteDifferences +function test_to_vec(x::T; check_inferred=true) where {T} + check_inferred && @inferred FiniteDifferences.to_vec(x) + x_vec, back = FiniteDifferences.to_vec(x) + @test x_vec isa Vector + @test all(s -> s isa Real, x_vec) + check_inferred && @inferred back(x_vec) + @test x == back(x_vec) + return nothing +end + +@testset "chainrulesutils" begin + y = GBMatrix(sprand(10, 10, 0.5)) + test_to_vec(y) + v = GBVector(sprand(10, 0.5)) + test_to_vec(v) +end diff --git a/test/chainrules/constructorrules.jl b/test/chainrules/constructorrules.jl new file mode 100644 index 00000000..83c9cd59 --- /dev/null +++ b/test/chainrules/constructorrules.jl @@ -0,0 +1,10 @@ +@testset "Construction Rules" begin + I = [1,1,2,2,3,4,4,5,6,7,7,7] + J = [2,4,5,7,6,1,3,6,3,3,4,5] + X = rand(12) + + test_frule(GBMatrix, I, J, X) + test_rrule(GBMatrix, I, J, X) + test_rrule(GBVector, I, X) + +end diff --git a/test/runtests.jl b/test/runtests.jl index 55726335..1aae6141 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,8 +12,11 @@ end println("Testing SuiteSparseGraphBLAS.jl") @testset "SuiteSparseGraphBLAS" begin + include_test("gbarray.jl") include_test("operations.jl") + include_test("chainrules/chainrulesutils.jl") include_test("chainrules/mulrules.jl") include_test("chainrules/mulrules.jl") + include_test("chainrules/constructorrules.jl") end From 9369b60014a406cf784848875060d748df2576a2 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sat, 10 Jul 2021 22:02:50 -0400 Subject: [PATCH 16/18] Move out constructor rules for now --- src/chainrules/constructorrules.jl | 18 ------------------ test/chainrules/constructorrules.jl | 10 ---------- 2 files changed, 28 deletions(-) delete mode 100644 src/chainrules/constructorrules.jl delete mode 100644 test/chainrules/constructorrules.jl diff --git a/src/chainrules/constructorrules.jl b/src/chainrules/constructorrules.jl deleted file mode 100644 index 7adce77c..00000000 --- a/src/chainrules/constructorrules.jl +++ /dev/null @@ -1,18 +0,0 @@ -function frule( - (_, _, _, ΔX), - ::Type{GBMatrix}, - I::AbstractVector{<:Integer}, - J::AbstractVector{<:Integer}, - X::AbstractVector -) - return GBMatrix(I, J, X), GBMatrix(I, J, ΔX) -end -function rrule(::Type{GBMatrix}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector) - gbmatrix_pullback(Δgbmatrix) = NoTangent(), NoTangent(), NoTangent(), nonzeros(Δgbmatrix) - return GBMatrix(I, J, X), gbmatrix_pullback -end - -function rrule(::Type{GBVector}, I::AbstractVector{<:Integer}, X::AbstractVector) - gbvector_pullback(Δgbvector) = NoTangent(), NoTangent(), nonzeros(Δgbvector) - return GBVector(I, X), gbvector_pullback -end diff --git a/test/chainrules/constructorrules.jl b/test/chainrules/constructorrules.jl deleted file mode 100644 index 83c9cd59..00000000 --- a/test/chainrules/constructorrules.jl +++ /dev/null @@ -1,10 +0,0 @@ -@testset "Construction Rules" begin - I = [1,1,2,2,3,4,4,5,6,7,7,7] - J = [2,4,5,7,6,1,3,6,3,3,4,5] - X = rand(12) - - test_frule(GBMatrix, I, J, X) - test_rrule(GBMatrix, I, J, X) - test_rrule(GBVector, I, X) - -end From 9c6f4788f6254b104eae7875f1e86f439d22fe23 Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sat, 10 Jul 2021 22:10:50 -0400 Subject: [PATCH 17/18] compat --- Project.toml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index ec8014d1..5bc3e979 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.4.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5" @@ -19,8 +18,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -CEnum = "0.4.1" -ContextVariablesX = "0.1.1" -MacroTools = "0.5.6" -SSGraphBLAS_jll = "5.1.2" +CEnum = "0.4" +ContextVariablesX = "0.1" +MacroTools = "0.5" +SSGraphBLAS_jll = "5.1" julia = "1.6" +ChainRulesCore = "0.10" +ChainRulesTestUtils = "0.7" +FiniteDifferences = "0.12" From c769833abdf9b2d66f55c09ffdef3fc7eab3b55f Mon Sep 17 00:00:00 2001 From: Wimmerer Date: Sat, 10 Jul 2021 22:11:42 -0400 Subject: [PATCH 18/18] rm constructorrule includes --- src/SuiteSparseGraphBLAS.jl | 1 - test/runtests.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/src/SuiteSparseGraphBLAS.jl b/src/SuiteSparseGraphBLAS.jl index f6c023dc..c3800558 100644 --- a/src/SuiteSparseGraphBLAS.jl +++ b/src/SuiteSparseGraphBLAS.jl @@ -95,7 +95,6 @@ include("chainrules/ewiserules.jl") include("chainrules/maprules.jl") include("chainrules/reducerules.jl") include("chainrules/selectrules.jl") -include("chainrules/constructorrules.jl") #include("random.jl") include("misc.jl") export libgb diff --git a/test/runtests.jl b/test/runtests.jl index 1aae6141..6b6d2cb4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,5 +18,4 @@ println("Testing SuiteSparseGraphBLAS.jl") include_test("chainrules/chainrulesutils.jl") include_test("chainrules/mulrules.jl") include_test("chainrules/mulrules.jl") - include_test("chainrules/constructorrules.jl") end