iteration and broadcast
CarloLucibello committed Dec 17, 2024
1 parent 909df5a commit 4e88944
Showing 6 changed files with 130 additions and 1,670 deletions.
23 changes: 23 additions & 0 deletions GNNGraphs/docs/src/guides/temporalgraph.md
@@ -85,6 +85,29 @@ GNNGraph:
num_nodes: 10
num_edges: 16
```
## Iteration and Broadcasting

Iteration and broadcasting over a temporal graph work the same way as they do over the underlying vector of snapshots:

```jldoctest temporal
julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)];

julia> tg = TemporalSnapshotsGNNGraph(snapshots);

julia> [g for g in tg] # iterate over snapshots
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
 GNNGraph(10, 20) with no data
 GNNGraph(10, 14) with no data
 GNNGraph(10, 22) with no data

julia> f(g) = g isa GNNGraph;

julia> f.(tg) # broadcast over snapshots
3-element BitVector:
 1
 1
 1
```

## Basic Queries

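As a side note (not part of this commit), the iteration and broadcasting support documented above can also be used to collect per-snapshot statistics; a minimal sketch, reusing only `rand_graph` and the `TemporalSnapshotsGNNGraph` constructor shown in the guide:

```julia
using GNNGraphs

snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)]
tg = TemporalSnapshotsGNNGraph(snapshots)

# Broadcasting applies a function to each snapshot, because `broadcastable`
# forwards to the underlying vector of snapshots.
edge_counts = (g -> g.num_edges).(tg)    # [20, 14, 22]

# Iteration yields the snapshots one by one, so generators work as well.
node_counts = [g.num_nodes for g in tg]  # [10, 10, 10]
```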
8 changes: 7 additions & 1 deletion GNNGraphs/src/temporalsnapshotsgnngraph.jl
@@ -91,7 +91,13 @@ end

function Base.length(tg::TemporalSnapshotsGNNGraph)
return tg.num_snapshots
-end
+end
+
+# Allow broadcasting over the temporal snapshots
+Base.broadcastable(tg::TemporalSnapshotsGNNGraph) = tg.snapshots
+
+Base.iterate(tg::TemporalSnapshotsGNNGraph) = Base.iterate(tg.snapshots)
+Base.iterate(tg::TemporalSnapshotsGNNGraph, i) = Base.iterate(tg.snapshots, i)

function Base.setindex!(tg::TemporalSnapshotsGNNGraph, g::GNNGraph, t::Int)
tg.snapshots[t] = g
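The change above wires iteration and broadcasting up by forwarding both to the wrapped `snapshots` vector. A self-contained sketch of the same delegation pattern for a hypothetical wrapper type (illustration only, not part of the package's API):

```julia
# Hypothetical wrapper used only to illustrate the delegation pattern.
struct SnapshotList{T}
    items::Vector{T}
end

# Broadcasting treats the wrapper as its underlying vector...
Base.broadcastable(s::SnapshotList) = s.items

# ...and iteration is forwarded to the vector, so `for x in s` works.
Base.iterate(s::SnapshotList) = iterate(s.items)
Base.iterate(s::SnapshotList, state) = iterate(s.items, state)

s = SnapshotList([1, 2, 3])
doubled = (x -> 2x).(s)    # [2, 4, 6], via broadcastable
total = sum(x for x in s)  # 6, via iterate
```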
30 changes: 23 additions & 7 deletions GNNGraphs/test/temporalsnapshotsgnngraph.jl
@@ -1,3 +1,5 @@
#TODO add graph_type = GRAPH_TYPE to all constructor calls

@testset "Constructor array TemporalSnapshotsGNNGraph" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tg = TemporalSnapshotsGNNGraph(snapshots)
@@ -12,6 +14,7 @@
@test tg.num_snapshots == 5
end


@testset "==" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg1 = TemporalSnapshotsGNNGraph(snapshots)
@@ -41,7 +44,7 @@ end
end

@testset "getproperty" begin
-x = rand(10)
+x = rand(Float32, 10)
snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
@test tsg.tgdata == DataStore()
@@ -111,18 +114,31 @@ end
@testset "show" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
-@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data"
-@test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data"
+@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)"
+@test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5)"
@test sprint(show, MIME("text/plain"), tsg; context=:compact => false) == "TemporalSnapshotsGNNGraph:\n num_nodes: [10, 10, 10, 10, 10]\n num_edges: [20, 20, 20, 20, 20]\n num_snapshots: 5"
-tsg.tgdata.x=rand(4)
-@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data"
+tsg.tgdata.x = rand(Float32, 4)
+@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)"
end

@testset "broadcastable" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
f(g) = g isa GNNGraph
@test f.(tsg) == trues(5)
end

@testset "iterate" begin
snapshots = [rand_graph(10, 20) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
@test [g for g in tsg] isa Vector{<:GNNGraph}
end

if TEST_GPU
@testset "gpu" begin
-snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5]
+snapshots = [rand_graph(10, 20; ndata = rand(Float32, 5,10)) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
-tsg.tgdata.x = rand(5)
+tsg.tgdata.x = rand(Float32, 5)
dev = CUDADevice() #TODO replace with `gpu_device()`
tsg = tsg |> dev
@test tsg.snapshots[1].ndata.x isa CuArray
2 changes: 1 addition & 1 deletion GNNLux/docs/src_tutorials/gnn_intro.jl
@@ -220,7 +220,7 @@ visualize_embeddings(emb_init, colors = labels)
# If you are not new to Lux, this scheme should appear familiar to you.

# Note that our semi-supervised learning scenario is achieved by the following line:
-# ```
+# ```julia
# logitcrossentropy(ŷ[:,train_mask], y[:,train_mask])
# ```
# While we compute node embeddings for all of our nodes, we **only make use of the training nodes for computing the loss**.
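To make the masking described above concrete, here is a self-contained sketch (not part of this commit) of a loss restricted to the training nodes; `ŷ` and `y` are assumed to be `num_classes × num_nodes` matrices, `train_mask` a Boolean vector, and `logitce` is a minimal stand-in for the tutorial's `logitcrossentropy` helper:

```julia
using Statistics: mean

# Minimal stand-in for a logit cross-entropy: mean over samples of
# -sum(y .* logsoftmax(ŷ)) along the class dimension.
logitce(ŷ, y) = mean(-sum(y .* (ŷ .- log.(sum(exp.(ŷ); dims = 1))); dims = 1))

# Logits are computed for all nodes, but only the training columns enter the loss.
masked_loss(ŷ, y, train_mask) = logitce(ŷ[:, train_mask], y[:, train_mask])
```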
File renamed without changes.
