Creating dropout functionality in the GATConv and GATv2Conv Layers (#411)

* Adding the dropout functionality to GAT and GATv2

Signed-off-by: achiverram28 <[email protected]>

* Correcting dropout keyword

Signed-off-by: achiverram28 <[email protected]>

* Adding the test for dropout for GATConv and GATv2Conv

Signed-off-by: achiverram28 <[email protected]>

* Fix

Signed-off-by: achiverram28 <[email protected]>

* Fix in test

Signed-off-by: achiverram28 <[email protected]>

---------

Signed-off-by: achiverram28 <[email protected]>
achiverram28 authored Mar 16, 2024
1 parent 6f07d5b commit 18c4606
Showing 2 changed files with 21 additions and 24 deletions.
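For orientation before the diff, here is a minimal usage sketch of the new `dropout` keyword. The graph, feature sizes, and data below are invented for illustration and are not part of the commit:

```julia
using GraphNeuralNetworks

# Toy graph: 4 nodes with bidirectional edges (made-up example data).
g = GNNGraph([1, 2, 2, 3, 3, 4], [2, 1, 3, 2, 4, 3])
x = randn(Float32, 8, g.num_nodes)   # 8 input features per node

# `dropout` is the probability of zeroing each normalized attention
# coefficient during the forward pass; it defaults to 0.0.
l = GATConv(8 => 16; heads = 2, dropout = 0.2)
y = l(g, x)                          # (16 * 2) × num_nodes, since concat defaults to true

l2 = GATv2Conv(8 => 16; heads = 2, concat = false, dropout = 0.2)
y2 = l2(g, x)                        # 16 × num_nodes (heads averaged)
```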
20 changes: 13 additions & 7 deletions src/layers/conv.jl
@@ -364,7 +364,7 @@ and the attention coefficients will be calculated as
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`.
- `negative_slope`: The parameter of LeakyReLU. Default `0.2`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
+- `dropout`: Dropout probability on the normalized attention coefficient. Default `0.0`.
# Examples
@@ -384,7 +384,7 @@ l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; hea
y = l(g, x)
```
"""
-struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <:
+struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMatrix, F, B} <:
GNNLayer
dense_x::DX
dense_e::DE
@@ -396,6 +396,7 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix,
heads::Int
concat::Bool
add_self_loops::Bool
+dropout::DV
end

@functor GATConv
@@ -405,7 +406,7 @@ GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args

function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
-init = glorot_uniform, bias::Bool = true, add_self_loops = true)
+init = glorot_uniform, bias::Bool = true, add_self_loops = true, dropout=0.0)
(in, ein), out = ch
if add_self_loops
@assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
@@ -416,7 +417,7 @@ function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false
a = init(ein > 0 ? 3out : 2out, heads)
negative_slope = convert(Float32, negative_slope)
-GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops)
+GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
end

(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
@@ -448,6 +449,7 @@ function (l::GATConv)(g::AbstractGNNGraph, x,
# a hand-written message passing
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
α = softmax_edge_neighbors(g, m.logα)
+α = dropout(α, l.dropout)
β = α .* m.Wxj
x = aggregate_neighbors(g, +, β)

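The inserted `α = dropout(α, l.dropout)` runs after the per-neighbor softmax, so each attention coefficient is dropped independently and every forward pass aggregates over a randomly thinned neighborhood. A hand-rolled sketch of the operation, assuming standard inverted dropout (the layer itself calls the `dropout` provided by Flux/NNlib):

```julia
# Illustration only: inverted dropout applied to normalized attention
# coefficients. Survivors are rescaled by 1/(1 - p) so the expected
# value of each coefficient is unchanged.
function attention_dropout(α::AbstractMatrix, p::Real)
    p == 0 && return α                        # no-op for p = 0, as in the tests
    mask = rand(Float32, size(α)...) .>= p    # keep each entry with prob. 1 - p
    return α .* mask ./ (1 - p)
end

α = Float32[0.5 0.3; 0.5 0.7]                 # made-up attention coefficients
attention_dropout(α, 0.2)
```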
@@ -518,6 +520,7 @@ and the attention coefficients will be calculated as
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`.
- `negative_slope`: The parameter of LeakyReLU. Default `0.2`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
+- `dropout`: Dropout probability on the normalized attention coefficient. Default `0.0`.
# Examples
```julia
@@ -540,7 +543,7 @@ e = randn(Float32, ein, length(s))
y = l(g, x, e)
```
"""
-struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
+struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix, F} <: GNNLayer
dense_i::A1
dense_j::A2
dense_e::A3
@@ -552,6 +555,7 @@ struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
heads::Int
concat::Bool
add_self_loops::Bool
+dropout::DV
end

@functor GATv2Conv
@@ -568,7 +572,8 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
negative_slope = 0.2,
init = glorot_uniform,
bias::Bool = true,
-add_self_loops = true)
+add_self_loops = true,
+dropout=0.0)
(in, ein), out = ch

if add_self_loops
@@ -586,7 +591,7 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
a = init(out, heads)
negative_slope = convert(eltype(dense_i.weight), negative_slope)
GATv2Conv(dense_i, dense_j, dense_e, b, a, σ, negative_slope, ch, heads, concat,
-add_self_loops)
+add_self_loops, dropout)
end

(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
@@ -611,6 +616,7 @@ function (l::GATv2Conv)(g::AbstractGNNGraph, x,

m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
α = softmax_edge_neighbors(g, m.logα)
+α = dropout(α, l.dropout)
β = α .* m.Wxj
x = aggregate_neighbors(g, +, β)

25 changes: 8 additions & 17 deletions test/layers/conv.jl
@@ -107,21 +107,21 @@ end

@testset "GATConv" begin
for heads in (1, 2), concat in (true, false)
-l = GATConv(in_channel => out_channel; heads, concat)
+l = GATConv(in_channel => out_channel; heads, concat, dropout=0)
for g in test_graphs
test_layer(l, g, rtol = RTOL_LOW,
-exclude_grad_fields = [:negative_slope],
+exclude_grad_fields = [:negative_slope, :dropout],
outsize = (concat ? heads * out_channel : out_channel,
g.num_nodes))
end
end

@testset "edge features" begin
ein = 3
-l = GATConv((in_channel, ein) => out_channel, add_self_loops = false)
+l = GATConv((in_channel, ein) => out_channel, add_self_loops = false, dropout=0)
g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges))
test_layer(l, g, rtol = RTOL_LOW,
-exclude_grad_fields = [:negative_slope],
+exclude_grad_fields = [:negative_slope, :dropout],
outsize = (out_channel, g.num_nodes))
end

@@ -137,21 +137,21 @@ end

@testset "GATv2Conv" begin
for heads in (1, 2), concat in (true, false)
-l = GATv2Conv(in_channel => out_channel, tanh; heads, concat)
+l = GATv2Conv(in_channel => out_channel, tanh; heads, concat, dropout=0)
for g in test_graphs
test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW,
-exclude_grad_fields = [:negative_slope],
+exclude_grad_fields = [:negative_slope, :dropout],
outsize = (concat ? heads * out_channel : out_channel,
g.num_nodes))
end
end

@testset "edge features" begin
ein = 3
-l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false)
+l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false, dropout=0)
g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges))
test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW,
-exclude_grad_fields = [:negative_slope],
+exclude_grad_fields = [:negative_slope, :dropout],
outsize = (out_channel, g.num_nodes))
end

@@ -163,15 +163,6 @@ end
l = GATv2Conv((2, 4) => 3, add_self_loops = false, bias = false)
@test length(Flux.params(l)) == 4
end

@testset "edge features" begin
ein = 3
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false)
g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges))
test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW,
exclude_grad_fields = [:negative_slope],
outsize = (out_channel, g.num_nodes))
end
end

@testset "GatedGraphConv" begin

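A note on the test changes above: the layers are constructed with `dropout=0` so the forward pass stays deterministic while `test_layer` compares analytic and numerical gradients, and `:dropout` joins `exclude_grad_fields` because the stored probability is a hyperparameter rather than a trainable parameter. A hypothetical snippet, not from the test suite, showing the determinism this buys:

```julia
using GraphNeuralNetworks

g = rand_graph(5, 10)                  # random graph: 5 nodes, 10 edges
x = randn(Float32, 4, 5)

l0 = GATConv(4 => 6; dropout = 0.0)
l0(g, x) ≈ l0(g, x)                    # true: no randomness, safe to grad-check

l = GATConv(4 => 6; dropout = 0.5)
l(g, x) ≈ l(g, x)                      # generally false: a fresh mask per call
```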