diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 9a54ff7c7..1254fb284 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -844,7 +844,6 @@ function Base.show(io::IO, l::ResGatedGraphConv) l.use_bias || print(io, ", use_bias=false") print(io, ")") end - @concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)} in_dims::NTuple{2, Int} out_dims::Int @@ -853,15 +852,15 @@ end concat::Bool skip_connection::Bool sqrt_out::Float32 - W1 - W2 - W3 - W4 - W5 - W6 - FF - BN1 - BN2 + W1 + W2 + W3 + W4 + W5 + W6 + FF + BN1 + BN2 end function TransformerConv(ch::Pair{Int, Int}, args...; kws...) @@ -912,18 +911,6 @@ function (l::TransformerConv)(g, x, ps, st) l(g, x, nothing, ps, st) end -function LuxCore.parameterlength(l::TransformerConv) - n = parameterlength(l.W2) + parameterlength(l.W3) + - parameterlength(l.W4) + (l.W6 === nothing ? 0 : parameterlength(l.W6)) - - n += l.W1 === nothing ? 0 : parameterlength(l.W1) - n += l.W5 === nothing ? 0 : parameterlength(l.W5) - n += l.FF === nothing ? 0 : parameterlength(l.FF) - n += l.BN1 === nothing ? 0 : parameterlength(l.BN1) - n += l.BN2 === nothing ? 0 : parameterlength(l.BN2) - return n -end - function (l::TransformerConv)(g, x, e, ps, st) W1 = l.W1 === nothing ? nothing : StatefulLuxLayer{true}(l.W1, ps.W1, _getstate(st, :W1)) @@ -941,10 +928,32 @@ function (l::TransformerConv)(g, x, e, ps, st) BN2 = l.BN2 === nothing ? nothing : StatefulLuxLayer{true}(l.BN2, ps.BN2, _getstate(st, :BN2)) m = (; W1, W2, W3, W4, W5, W6, FF, BN1, BN2, l.sqrt_out, - l.heads, l.concat, l.skip_connection, l.add_self_loops) + l.heads, l.concat, l.skip_connection, l.add_self_loops, l.in_dims) return GNNlib.transformer_conv(m, g, x, e), st end +function LuxCore.parameterlength(l::TransformerConv) + n = parameterlength(l.W1) + parameterlength(l.W2) + + parameterlength(l.W3) + parameterlength(l.W4) + + parameterlength(l.W5) + parameterlength(l.W6) + + n += l.FF === nothing ? 0 : parameterlength(l.FF) + n += l.BN1 === nothing ? 0 : parameterlength(l.BN1) + n += l.BN2 === nothing ? 0 : parameterlength(l.BN2) + return n +end + +function LuxCore.statelength(l::TransformerConv) + n = statelength(l.W1) + statelength(l.W2) + + statelength(l.W3) + statelength(l.W4) + + statelength(l.W5) + statelength(l.W6) + + n += l.FF === nothing ? 0 : statelength(l.FF) + n += l.BN1 === nothing ? 0 : statelength(l.BN1) + n += l.BN2 === nothing ? 0 : statelength(l.BN2) + return n +end + function Base.show(io::IO, l::TransformerConv) (in, ein), out = (l.in_dims, l.out_dims) print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")