Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mul/ewise rules for basic arithmetic semiring #26

Merged
merged 20 commits into from
Jul 11, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ version = "0.4.0"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
rayegun marked this conversation as resolved.
Show resolved Hide resolved
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -15,8 +19,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
rayegun marked this conversation as resolved.
Show resolved Hide resolved
SSGraphBLAS_jll = "5.1.2"
julia = "1.6"
CEnum = "0.4.1"
ContextVariablesX = "0.1.1"
MacroTools = "0.5.6"
SSGraphBLAS_jll = "5.1.2"
julia = "1.6"
9 changes: 7 additions & 2 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ include("operations/kronecker.jl")
include("print.jl")
include("import.jl")
include("export.jl")

#EXPERIMENTAL
include("options.jl")
#EXPERIMENTAL
include("chainrules/chainruleutils.jl")
include("chainrules/mulrules.jl")
include("chainrules/ewiserules.jl")
include("chainrules/maprules.jl")
include("chainrules/reducerules.jl")
include("chainrules/selectrules.jl")
#include("random.jl")
include("misc.jl")
export libgb
Expand Down
46 changes: 46 additions & 0 deletions src/chainrules/chainruleutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import FiniteDifferences
import LinearAlgebra
import ChainRulesCore: frule, rrule
using ChainRulesCore
const RealOrComplex = Union{Real, Complex}

#Required for ChainRulesTestUtils
function FiniteDifferences.to_vec(M::GBMatrix)
rayegun marked this conversation as resolved.
Show resolved Hide resolved
I, J, X = findnz(M)
function backtomat(xvec)
return GBMatrix(I, J, xvec; nrows = size(M, 1), ncols = size(M, 2))
end
return X, backtomat
end
rayegun marked this conversation as resolved.
Show resolved Hide resolved

function FiniteDifferences.to_vec(v::GBVector)
i, x = findnz(v)
function backtovec(xvec)
return GBVector(i, xvec; nrows=size(v, 1))
end
return x, backtovec
end

function FiniteDifferences.rand_tangent(
rng::AbstractRNG,
x::GBMatrix{T}
) where {T <: Union{AbstractFloat, Complex}}
n = nnz(x)
v = rand(rng, -9:0.01:9, n)
I, J, _ = findnz(x)
return GBMatrix(I, J, v; nrows = size(x, 1), ncols = size(x, 2))
end

function FiniteDifferences.rand_tangent(
rng::AbstractRNG,
x::GBVector{T}
) where {T <: Union{AbstractFloat, Complex}}
n = nnz(x)
v = rand(rng, -9:0.01:9, n)
I, _ = findnz(x)
return GBVector(I, v; nrows = size(x, 1))
end

FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent()
# LinearAlgebra.norm freaks over the nothings.
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p)
64 changes: 64 additions & 0 deletions src/chainrules/ewiserules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#emul TIMES
function frule(
(_, ΔA, ΔB, _),
::typeof(emul),
A::GBArray,
B::GBArray,
::typeof(BinaryOps.TIMES)
)
Ω = emul(A, B, BinaryOps.TIMES)
∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray)
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES)
end

function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES))
function timespullback(ΔΩ)
∂A = emul(ΔΩ, B)
∂B = emul(ΔΩ, A)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return emul(A, B, BinaryOps.TIMES), timespullback
end

function rrule(::typeof(emul), A::GBArray, B::GBArray)
function timespullback(ΔΩ)
∂A = emul(ΔΩ, B)
∂B = emul(ΔΩ, A)
return NoTangent(), ∂A, ∂B
end
return emul(A, B, BinaryOps.TIMES), timespullback
end

#eadd PLUS
function frule(
(_, ΔA, ΔB, _),
::typeof(eadd),
A::GBArray,
B::GBArray,
::typeof(BinaryOps.PLUS)
)
Ω = eadd(A, B, BinaryOps.PLUS)
∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray)
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS)
end

function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS))
function pluspullback(ΔΩ)
return NoTangent(), ΔΩ, ΔΩ, NoTangent()
end
return eadd(A, B, BinaryOps.PLUS), pluspullback
end

# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule.
rayegun marked this conversation as resolved.
Show resolved Hide resolved
function rrule(::typeof(eadd), A::GBArray, B::GBArray)
function pluspullback(ΔΩ)
return NoTangent(), ΔΩ, ΔΩ
end
return eadd(A, B, BinaryOps.PLUS), pluspullback
end
rayegun marked this conversation as resolved.
Show resolved Hide resolved
17 changes: 17 additions & 0 deletions src/chainrules/maprules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather
# than AbstractOp.
#function rrule(map, f, xs)
# # Rather than 3 maps really want 1 multimap
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x)
# ys = map(first, ys_and_pullbacks)
# pullbacks = map(last, ys_and_pullbacks)
# function map_pullback(dys)
# _call(f, x) = f(x)
# dfs_and_dxs = map(_call, pullbacks, dys)
# # but in your case you know it will be NoTangent() so can skip
# df = sum(first, dfs_and_dxs)
# dxs = map(last, dfs_and_dxs)
# return NoTangent(), df, dxs
# end
# return ys, map_pullback
#end
54 changes: 54 additions & 0 deletions src/chainrules/mulrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Standard arithmetic mul:
function frule(
(_, ΔA, ΔB),
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose
)
frule((nothing, ΔA, ΔB, nothing), mul, A, B, Semirings.PLUS_TIMES)
end
function frule(
(_, ΔA, ΔB, _),
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
Ω = mul(A, B)
∂Ω = mul(ΔA, B) + mul(A, ΔB)
return Ω, ∂Ω
end
# Tests will not pass for this. For two reasons.
# First is #25, the output inference is not type stable.
# That's it's own issue.

# Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings.
# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof.

function rrule(
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
function mulpullback(ΔΩ)
∂A = mul(ΔΩ, B')
∂B = mul(A', ΔΩ)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return mul(A, B), mulpullback
end

# Do I have to duplicate this? :/
function rrule(
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose
)
function mulpullback(ΔΩ)
∂A = mul(ΔΩ, B')
∂B = mul(A', ΔΩ)
return NoTangent(), ∂A, ∂B
end
return mul(A, B), mulpullback
end
Empty file added src/chainrules/reducerules.jl
Empty file.
Empty file added src/chainrules/selectrules.jl
Empty file.
22 changes: 11 additions & 11 deletions src/lib/LibGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ macro wraperror(code)
elseif info == GrB_NO_VALUE
return nothing
else
if info == GrB_UNINITIALIZED_OBJECT
if info == GrB_UNINITIALIZED_OBJECT
throw(UninitializedObjectError)
elseif info == GrB_INVALID_OBJECT
elseif info == GrB_INVALID_OBJECT
throw(InvalidObjectError)
elseif info == GrB_NULL_POINTER
elseif info == GrB_NULL_POINTER
throw(NullPointerError)
elseif info == GrB_INVALID_VALUE
elseif info == GrB_INVALID_VALUE
throw(InvalidValueError)
elseif info == GrB_INVALID_INDEX
elseif info == GrB_INVALID_INDEX
throw(InvalidIndexError)
elseif info == GrB_DOMAIN_MISMATCH
elseif info == GrB_DOMAIN_MISMATCH
throw(DomainError(nothing, "GraphBLAS Domain Mismatch"))
elseif info == GrB_DIMENSION_MISMATCH
throw(DimensionMismatch())
elseif info == GrB_OUTPUT_NOT_EMPTY
elseif info == GrB_OUTPUT_NOT_EMPTY
throw(OutputNotEmptyError)
elseif info == GrB_OUT_OF_MEMORY
elseif info == GrB_OUT_OF_MEMORY
throw(OutOfMemoryError())
elseif info == GrB_INSUFFICIENT_SPACE
elseif info == GrB_INSUFFICIENT_SPACE
throw(InsufficientSpaceError)
elseif info == GrB_INDEX_OUT_OF_BOUNDS
elseif info == GrB_INDEX_OUT_OF_BOUNDS
throw(BoundsError())
elseif info == GrB_PANIC
throw(PANIC)
Expand Down Expand Up @@ -843,7 +843,7 @@ for T ∈ valid_vec
nvals = GrB_Vector_nvals(v)
I = Vector{GrB_Index}(undef, nvals)
X = Vector{$type}(undef, nvals)
nvals = Ref{GrB_Index}()
nvals = Ref{GrB_Index}(nvals)
$func(I, X, nvals, v)
nvals[] == length(I) == length(X) || throw(DimensionMismatch())
return I .+ 1, X
Expand Down
3 changes: 2 additions & 1 deletion src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix)
gxbprint(io, A)
end

SparseArrays.nonzeros(A::GBArray) = findnz(A)[3]
SparseArrays.nonzeros(A::GBArray) = findnz(A)[end]


# Indexing functions
####################
Expand Down
8 changes: 7 additions & 1 deletion src/operations/ewise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ function emul!(
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)

size(w) == size(u) == size(v) || throw(DimensionMismatch())
op = getoperator(op, optype(u, v))
accum = getoperator(accum, eltype(w))
Expand Down Expand Up @@ -291,6 +290,13 @@ function eadd(
return eadd!(C, A, B, op; mask, accum, desc)
end

function Base.:+(A::GBArray, B::GBArray)
eadd(A, B, nothing)
end

function Base.:-(A::GBArray, B::GBArray)
eadd(A, B, BinaryOps.MINUS)
end
#Elementwise Broadcasts
#######################

Expand Down
1 change: 0 additions & 1 deletion src/operations/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ function LinearAlgebra.mul!(
return w
end


"""
mul(A::GBArray, B::GBArray; kwargs...)::GBArray

Expand Down
4 changes: 2 additions & 2 deletions src/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...)

Create a GBVector from a vector of indices `I` and a vector of values `X`.
"""
function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS) where {T}
x = GBVector{T}(maximum(I))
function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T}
x = GBVector{T}(nrows)
build(x, I, X, dup = dup)
return x
end
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using SuiteSparseGraphBLAS
using SparseArrays
using Test
using Random

using ChainRulesTestUtils
Random.seed!(1)

function include_test(path)
Expand All @@ -14,4 +14,5 @@ println("Testing SuiteSparseGraphBLAS.jl")
@testset "SuiteSparseGraphBLAS" begin
include_test("gbarray.jl")
include_test("operations.jl")
include_test("testrules.jl")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually the structure of test folder mirrors the src/ folder, which makes it easier to find things when the package grows.

end
40 changes: 40 additions & 0 deletions test/testrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
@testset "Dense" begin
@testset "arithmetic semiring" begin
#dense first
M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10))
Y = GBMatrix(rand(-10.0:0.05:10.0, 10))
test_frule(mul, M, Y; check_inferred=false)
test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
test_rrule(mul, M, Y; check_inferred=false)
test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
X = GBMatrix(rand(-10.0:0.05:10.0, 10))
test_frule(eadd, X, Y; check_inferred=false)
test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_rrule(eadd, X, Y; check_inferred=false)
test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_frule(emul, X, Y; check_inferred=false)
test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
test_rrule(emul, X, Y; check_inferred=false)
test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
end
end

@testset "Sparse" begin
@testset "arithmetic semiring" begin
M = GBMatrix(sprand(10, 10, 0.5))
Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector)
test_frule(mul, M, Y; check_inferred=false)
test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
test_rrule(mul, M, Y; check_inferred=false)
test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
X = GBMatrix(sprand(10, 0.5))
test_frule(eadd, X, Y; check_inferred=false)
test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_rrule(eadd, X, Y; check_inferred=false)
test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_frule(emul, X, Y; check_inferred=false)
test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
test_rrule(emul, X, Y; check_inferred=false)
test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
end
end