Skip to content

Commit

Permalink
Update conv.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky authored Sep 30, 2024
1 parent 8b9751b commit a6a11bd
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,6 @@ function Base.show(io::IO, l::ResGatedGraphConv)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)}
in_dims::NTuple{2, Int}
out_dims::Int
Expand All @@ -853,15 +852,15 @@ end
concat::Bool
skip_connection::Bool
sqrt_out::Float32
W1
W2
W3
W4
W5
W6
FF
BN1
BN2
W1
W2
W3
W4
W5
W6
FF
BN1
BN2
end

function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
Expand Down Expand Up @@ -912,18 +911,6 @@ function (l::TransformerConv)(g, x, ps, st)
l(g, x, nothing, ps, st)
end

function LuxCore.parameterlength(l::TransformerConv)
n = parameterlength(l.W2) + parameterlength(l.W3) +
parameterlength(l.W4) + (l.W6 === nothing ? 0 : parameterlength(l.W6))

n += l.W1 === nothing ? 0 : parameterlength(l.W1)
n += l.W5 === nothing ? 0 : parameterlength(l.W5)
n += l.FF === nothing ? 0 : parameterlength(l.FF)
n += l.BN1 === nothing ? 0 : parameterlength(l.BN1)
n += l.BN2 === nothing ? 0 : parameterlength(l.BN2)
return n
end

function (l::TransformerConv)(g, x, e, ps, st)
W1 = l.W1 === nothing ? nothing :
StatefulLuxLayer{true}(l.W1, ps.W1, _getstate(st, :W1))
Expand All @@ -941,10 +928,32 @@ function (l::TransformerConv)(g, x, e, ps, st)
BN2 = l.BN2 === nothing ? nothing :
StatefulLuxLayer{true}(l.BN2, ps.BN2, _getstate(st, :BN2))
m = (; W1, W2, W3, W4, W5, W6, FF, BN1, BN2, l.sqrt_out,
l.heads, l.concat, l.skip_connection, l.add_self_loops)
l.heads, l.concat, l.skip_connection, l.add_self_loops, l.in_dims)
return GNNlib.transformer_conv(m, g, x, e), st
end

function LuxCore.parameterlength(l::TransformerConv)
n = parameterlength(l.W1) + parameterlength(l.W2) +
parameterlength(l.W3) + parameterlength(l.W4) +
parameterlength(l.W5) + parameterlength(l.W6)

n += l.FF === nothing ? 0 : parameterlength(l.FF)
n += l.BN1 === nothing ? 0 : parameterlength(l.BN1)
n += l.BN2 === nothing ? 0 : parameterlength(l.BN2)
return n
end

function LuxCore.statelength(l::TransformerConv)
n = statelength(l.W1) + statelength(l.W2) +
statelength(l.W3) + statelength(l.W4) +
statelength(l.W5) + statelength(l.W6)

n += l.FF === nothing ? 0 : statelength(l.FF)
n += l.BN1 === nothing ? 0 : statelength(l.BN1)
n += l.BN2 === nothing ? 0 : statelength(l.BN2)
return n
end

function Base.show(io::IO, l::TransformerConv)
(in, ein), out = (l.in_dims, l.out_dims)
print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
Expand Down

0 comments on commit a6a11bd

Please sign in to comment.