Skip to content

Commit

Permalink
Update conv.jl: parameter length fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky authored Sep 30, 2024
1 parent be04df5 commit 8b9751b
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,18 @@ function (l::TransformerConv)(g, x, ps, st)
l(g, x, nothing, ps, st)
end

function LuxCore.parameterlength(l::TransformerConv)
n = parameterlength(l.W2) + parameterlength(l.W3) +
parameterlength(l.W4) + (l.W6 === nothing ? 0 : parameterlength(l.W6))

n += l.W1 === nothing ? 0 : parameterlength(l.W1)
n += l.W5 === nothing ? 0 : parameterlength(l.W5)
n += l.FF === nothing ? 0 : parameterlength(l.FF)
n += l.BN1 === nothing ? 0 : parameterlength(l.BN1)
n += l.BN2 === nothing ? 0 : parameterlength(l.BN2)
return n
end

function (l::TransformerConv)(g, x, e, ps, st)
W1 = l.W1 === nothing ? nothing :
StatefulLuxLayer{true}(l.W1, ps.W1, _getstate(st, :W1))
Expand Down

0 comments on commit 8b9751b

Please sign in to comment.