Add DConv layer (#441)
* First `DConv` draft

* CUDA friendly

* Add GNNLayer

* Add test

* Export `DConv`

* Fix

* Fix test

* Add docs

* Add type

Co-authored-by: Carlo Lucibello <[email protected]>

* Add propagate but need to fix the transpose part

* Fix transpose

* Add spaces

Co-authored-by: Carlo Lucibello <[email protected]>

* Add spaces

Co-authored-by: Carlo Lucibello <[email protected]>

* Add spaces

Co-authored-by: Carlo Lucibello <[email protected]>

---------

Co-authored-by: Carlo Lucibello <[email protected]>
aurorarossi and CarloLucibello authored Jul 18, 2024
1 parent acf4b6a commit df56b7e
Showing 3 changed files with 88 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -67,6 +67,7 @@ export
SGConv,
TAGConv,
TransformerConv,
DConv,

# layers/heteroconv
HeteroGraphConv,
77 changes: 77 additions & 0 deletions src/layers/conv.jl
@@ -2078,3 +2078,80 @@ function Base.show(io::IO, l::TransformerConv)
    (in, ein), out = l.channels
    print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
end

"""
DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true)
Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Neural Networks: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
# Arguments
- `ch`: Pair of input and output dimensions.
- `K`: Number of diffusion steps.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `bias`: Add learnable bias. Default `true`.
# Examples
```
julia> g = GNNGraph(rand(10, 10), ndata = rand(Float32, 2, 10));
julia> dconv = DConv(2 => 4, 4)
DConv(2 => 4, K=4)
julia> y = dconv(g, g.ndata.x);
julia> size(y)
(4, 10)
```
"""
struct DConv <: GNNLayer
    in::Int
    out::Int
    weights::AbstractArray
    bias::Union{AbstractVector, Bool} # `false` when the layer has no bias
    K::Int
end
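# Note on the layout: `weights` has size (2, K, out, in). The first dimension
# indexes the two diffusion directions (slice 1 pairs with the in-direction,
# slice 2 with the out-direction), the second dimension the diffusion step.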

@functor DConv

function DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true)
    in, out = ch
    weights = init(2, K, out, in)
    b = Flux.create_bias(weights, bias, out) # zero vector, or `false` if bias == false
    DConv(in, out, weights, b, K)
end
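# For reference, the paper defines the diffusion convolution as
#     H = Σ_{k=0}^{K-1} (θ_{k,1} (D_O⁻¹ A)ᵏ + θ_{k,2} (D_I⁻¹ Aᵀ)ᵏ) X,
# i.e. K powers of the forward and reverse transition matrices, each with its
# own weight slice; the forward pass below evaluates the two directional terms
# with a recurrence instead of materializing the adjacency matrix A.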

function (l::DConv)(g::GNNGraph, x::AbstractMatrix)
    s, t = edge_index(g)
    gt = GNNGraph(t, s, get_edge_weight(g)) # g with every edge reversed
    deg_out = Diagonal(degree(g; dir = :out))
    deg_in = Diagonal(degree(g; dir = :in))

    # order-0 term: both direction slices act directly on the input features
    h = l.weights[1, 1, :, :] * x .+ l.weights[2, 1, :, :] * x

    T0 = x
    if l.K > 1
        # order-1 terms: one weighted propagation step in each direction
        T1_out = propagate(w_mul_xj, g, +; xj = T0 * deg_out')
        T1_in = propagate(w_mul_xj, gt, +; xj = T0 * deg_in)
        h = h .+ l.weights[1, 2, :, :] * T1_in .+ l.weights[2, 2, :, :] * T1_out
    end
    T0_in, T0_out = T0, T0
    # higher orders via the recurrence T_k = 2 A T_{k-1} - T_{k-2}, tracked
    # separately per direction; weight slice 2 is consumed above, so each step
    # i gets its own slice starting from i = 3
    for i in 3:l.K
        T2_in = 2 .* propagate(w_mul_xj, gt, +; xj = T1_in * deg_in) .- T0_in
        T2_out = 2 .* propagate(w_mul_xj, g, +; xj = T1_out * deg_out') .- T0_out
        h = h .+ l.weights[1, i, :, :] * T2_in .+ l.weights[2, i, :, :] * T2_out
        T0_in, T1_in = T1_in, T2_in
        T0_out, T1_out = T1_out, T2_out
    end
    return h .+ l.bias
end

function Base.show(io::IO, l::DConv)
    print(io, "DConv($(l.in) => $(l.out), K=$(l.K))")
end
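
For context beyond the docstring example, here is a minimal sketch of how the new layer composes with the rest of the library (hypothetical graph and dimensions; assumes `GraphNeuralNetworks` and `Flux` are loaded):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(10, 40, ndata = rand(Float32, 2, 10)) # 10 nodes, 40 random edges

model = GNNChain(DConv(2 => 8, 3), # 3 diffusion steps, 2 => 8 features
                 x -> relu.(x),
                 DConv(8 => 1, 3))

y = model(g, g.ndata.x) # 1 × 10 matrix, one value per node

# gradients flow through `propagate`, so the layer trains like any Flux layer
grads = Flux.gradient(m -> sum(abs2, m(g, g.ndata.x)), model)
```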
10 changes: 10 additions & 0 deletions test/layers/conv.jl
@@ -349,3 +349,13 @@ end
            outsize = (in_channel, g.num_nodes))
    end
end

@testset "DConv" begin
K = [1, 2, 3] # for different number of hops
for k in K
l = DConv(in_channel => out_channel, k)
for g in test_graphs
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
end
end
end
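
`test_layer`, `in_channel`, `out_channel`, `test_graphs`, and `RTOL_HIGH` come from the repository's test harness; a self-contained sanity check in the same spirit (hypothetical sizes) could be:

```julia
using GraphNeuralNetworks, Test

g = rand_graph(10, 30, ndata = rand(Float32, 3, 10))
for k in (1, 2, 3) # different numbers of diffusion steps
    l = DConv(3 => 5, k)
    @test size(l(g, g.ndata.x)) == (5, 10) # out_channel × num_nodes
end
```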
