Skip to content

Commit

Permalink
trying transformerconv
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Sep 30, 2024
1 parent a034753 commit 89fccea
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 1 deletion.
2 changes: 1 addition & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export AGNNConv,
# SAGEConv,
SGConv
# TAGConv,
# TransformerConv
TransformerConv

include("layers/temporalconv.jl")
export TGCN,
Expand Down
93 changes: 93 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,96 @@ function Base.show(io::IO, l::ResGatedGraphConv)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)}
in_dims::NTuple{2, Int}
out_dims::Int
heads::Int
add_self_loops::Bool
concat::Bool
skip_connection::Bool
sqrt_out::Float32
W1
W2
W3
W4
W5
W6
FF
BN1
BN2
end

function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
TransformerConv((ch[1], 0) => ch[2], args...; kws...)
end

function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
heads::Int = 1,
concat::Bool = true,
init_weight = glorot_uniform,
init_bias = zeros32,
add_self_loops::Bool = false,
bias_qkv = true,
bias_root::Bool = true,
root_weight::Bool = true,
gating::Bool = false,
skip_connection::Bool = false,
batch_norm::Bool = false,
ff_channels::Int = 0)

(in, ein), out = ch

if add_self_loops
@assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
end

W1 = root_weight ?
Dense(in => out * (concat ? heads : 1); use_bias = bias_root, init_weight, init_bias) : nothing
W2 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
W3 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
W4 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
out_mha = out * (concat ? heads : 1)
W5 = gating ? Dense(3 * out_mha => 1, sigmoid; use_bias = false, init_weight, init_bias) : nothing
W6 = ein > 0 ? Dense(ein => out * heads; use_bias = bias_qkv, init_weight, init_bias) : nothing
FF = ff_channels > 0 ?
Chain(Dense(out_mha => ff_channels, relu; init_weight, init_bias),
Dense(ff_channels => out_mha; init_weight, init_bias)) : nothing
BN1 = batch_norm ? BatchNorm(out_mha) : nothing
BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing

return TransformerConv((in, ein), out, heads, add_self_loops, concat,
skip_connection, Float32(out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
end

LuxCore.outputsize(l::TransformerConv) = (l.out_dims,)

function (l::TransformerConv)(g, x, ps, st)
l(g, x, nothing, ps, st)
end

function (l::TransformerConv)(g, x, e, ps, st)
W1 = l.W1 === nothing ? nothing :
StatefulLuxLayer{true}(l.W1, ps.W1, _getstate(st, :W1))
W2 = StatefulLuxLayer{true}(l.W2, ps.W2, _getstate(st, :W2))
W3 = StatefulLuxLayer{true}(l.W3, ps.W3, _getstate(st, :W3))
W4 = StatefulLuxLayer{true}(l.W4, ps.W4, _getstate(st, :W4))
W5 = l.W5 === nothing ? nothing :
StatefulLuxLayer{true}(l.W5, ps.W5, _getstate(st, :W5))
W6 = l.W6 === nothing ? nothing :
StatefulLuxLayer{true}(l.W6, ps.W6, _getstate(st, :W6))
FF = l.FF === nothing ? nothing :
StatefulLuxLayer{true}(l.FF, ps.FF, _getstate(st, :FF))
BN1 = l.BN1 === nothing ? nothing :
StatefulLuxLayer{true}(l.BN1, ps.BN1, _getstate(st, :BN1))
BN2 = l.BN2 === nothing ? nothing :
StatefulLuxLayer{true}(l.BN2, ps.BN2, _getstate(st, :BN2))
m = (; W1, W2, W3, W4, W5, W6, FF, BN1, BN2, l.sqrt_out,
l.heads, l.concat, l.skip_connection)
return GNNlib.transformer_conv(m, g, x, e), st
end

function Base.show(io::IO, l::TransformerConv)
(in, ein), out = (l.in_dims, l.out_dims)
print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
end
12 changes: 12 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,16 @@
l = ResGatedGraphConv(in_dims => out_dims, tanh)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "TransformerConv" begin
x = randn(rng, Float32, 6, 10)
ein = 2
e = randn(rng, Float32, ein, g.num_edges)

l = TransformerConv((6, ein) => 8, heads = 2, gating = true, bias_qkv = true)
test_lux_layer(rng, l, g, x, outputsize = (8,), e = e, container = true)

l = TransformerConv((6, ein) => 8, heads = 2, concat = false, skip_connection = true)
test_lux_layer(rng, l, g, x, outputsize = (8,), e = e, container = true)
end
end
140 changes: 140 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ArgTools = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
ArnoldiMethod = "ec485272-7323-5ecc-a04f-4719b315124d"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Baselet = "9718e550-a3fa-408a-8086-8db961cd8217"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CompilerSupportLibraries_jll = "e66e0078-7015-5450-92f7-15fbd957f2ae"
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DataValueInterfaces = "e2d170a0-9d28-54be-80f0-106bbe20a464"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DefineSingletons = "244e2a9f-e319-4986-a169-4d1fe445cd52"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
FLoopsBase = "b9860ae5-e623-471e-878b-f6a53c775ea6"
FileWatching = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
Inflate = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
InitialValues = "22cec73e-a1b8-11e9-2c92-598750a2cf9c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
JLLWrappers = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
JuliaVariables = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVMExtra_jll = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
LibCURL = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
LibCURL_jll = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
LibGit2_jll = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
LibSSH2_jll = "29816b5a-b9ab-546f-933c-edad1886dfa8"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MbedTLS_jll = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
MicroCollections = "128add7d-3638-4c79-886c-908ea0c25c34"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
MozillaCACerts_jll = "14a3606d-f60d-562e-9121-12d972cd8159"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NameResolution = "71a1bf82-56d0-4bbc-8a3c-48b961074391"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
NetworkOptions = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
OpenBLAS_jll = "4536629a-c528-5b80-bd46-f80d51c5b363"
OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112"
OpenSpecFun_jll = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
PrettyPrint = "8162dcfd-2161-5ef2-ae6c-7681170c5f98"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
SortingAlgorithms = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
SuiteSparse_jll = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
Zlib_jll = "83775a58-1f1d-513f-b197-d71354ab007a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
libblastrampoline_jll = "8e850b90-86db-534c-a0d3-1478176c7d93"
nghttp2_jll = "8e850ede-7688-5339-a07c-302acd2aaf8d"
p7zip_jll = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

0 comments on commit 89fccea

Please sign in to comment.