diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index 2a9cc5852..1aaa70b6b 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -38,7 +38,7 @@ export AGNNConv,
        # SAGEConv,
-       SGConv
+       SGConv,
        # TAGConv,
-       # TransformerConv
+       TransformerConv
 
 include("layers/temporalconv.jl")
 export TGCN,
diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl
index fbf7ad7c2..484c5d9a2 100644
--- a/GNNLux/src/layers/conv.jl
+++ b/GNNLux/src/layers/conv.jl
@@ -844,3 +844,137 @@ function Base.show(io::IO, l::ResGatedGraphConv)
     l.use_bias || print(io, ", use_bias=false")
     print(io, ")")
 end
+
+"""
+    TransformerConv(in => out; kws...)
+    TransformerConv((in, ein) => out; kws...)
+
+Lux container layer that builds the sublayers consumed by `GNNlib.transformer_conv`,
+which performs the actual forward computation.
+
+`in` and `ein` are the node and edge feature sizes (`ein = 0` when only `in` is given).
+The layer output has `out_mha = out * (concat ? heads : 1)` channels.
+
+# Keywords
+
+- `heads`: number of heads. Default `1`.
+- `concat`: selects `out_mha` as defined above. Default `true`.
+- `init_weight`, `init_bias`: initializers passed to every `Dense` sublayer.
+- `add_self_loops`: requires `ein == 0`. Default `false`.
+- `bias_qkv`: `use_bias` of the `W2`, `W3`, `W4` and `W6` projections. Default `true`.
+- `bias_root`: `use_bias` of the `W1` projection. Default `true`.
+- `root_weight`: create the `W1` projection `in => out_mha`. Default `true`.
+- `gating`: create the `W5` layer `Dense(3 * out_mha => 1, sigmoid)`. Default `false`.
+- `skip_connection`: flag forwarded to `GNNlib.transformer_conv`. Default `false`.
+- `batch_norm`: create `BatchNorm(out_mha)` layers `BN1` and, when `ff_channels > 0`,
+  `BN2`. Default `false`.
+- `ff_channels`: when positive, create the `FF` chain
+  `out_mha => ff_channels => out_mha`. Default `0`.
+"""
+@concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)}
+    in_dims::NTuple{2, Int}
+    out_dims::Int
+    heads::Int
+    add_self_loops::Bool
+    concat::Bool
+    skip_connection::Bool
+    sqrt_out::Float32
+    W1
+    W2
+    W3
+    W4
+    W5
+    W6
+    FF
+    BN1
+    BN2
+end
+
+# Node-features-only constructor: `in => out` is treated as `(in, 0) => out`.
+function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
+    return TransformerConv((ch[1], 0) => ch[2], args...; kws...)
+end
+
+function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
+        heads::Int = 1,
+        concat::Bool = true,
+        init_weight = glorot_uniform,
+        init_bias = zeros32,
+        add_self_loops::Bool = false,
+        bias_qkv = true,
+        bias_root::Bool = true,
+        root_weight::Bool = true,
+        gating::Bool = false,
+        skip_connection::Bool = false,
+        batch_norm::Bool = false,
+        ff_channels::Int = 0)
+
+    (in, ein), out = ch
+
+    if add_self_loops
+        @assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
+    end
+
+    # Channel count of the layer output: `out * heads` when `concat`, `out` otherwise.
+    out_mha = out * (concat ? heads : 1)
+
+    # Optional sublayers are stored as `nothing` when they are disabled.
+    W1 = root_weight ?
+         Dense(in => out_mha; use_bias = bias_root, init_weight, init_bias) : nothing
+    W2 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
+    W3 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
+    W4 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
+    W5 = gating ?
+         Dense(3 * out_mha => 1, sigmoid; use_bias = false, init_weight, init_bias) :
+         nothing
+    W6 = ein > 0 ?
+         Dense(ein => out * heads; use_bias = bias_qkv, init_weight, init_bias) : nothing
+    FF = ff_channels > 0 ?
+         Chain(Dense(out_mha => ff_channels, relu; init_weight, init_bias),
+               Dense(ff_channels => out_mha; init_weight, init_bias)) : nothing
+    BN1 = batch_norm ? BatchNorm(out_mha) : nothing
+    BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing
+
+    return TransformerConv((in, ein), out, heads, add_self_loops, concat,
+        skip_connection, Float32(√out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
+end
+
+# Mirrors `out_mha` in the constructor: the heads multiply the size only when `concat`.
+LuxCore.outputsize(l::TransformerConv) = (l.out_dims * (l.concat ? l.heads : 1),)
+
+# Forward pass without edge features.
+function (l::TransformerConv)(g, x, ps, st)
+    return l(g, x, nothing, ps, st)
+end
+
+function (l::TransformerConv)(g, x, e, ps, st)
+    # Pair every sublayer with its parameters and state so that it can be called like a
+    # plain function; sublayers disabled at construction stay `nothing`.
+    W1 = l.W1 === nothing ? nothing :
+         StatefulLuxLayer{true}(l.W1, ps.W1, _getstate(st, :W1))
+    W2 = StatefulLuxLayer{true}(l.W2, ps.W2, _getstate(st, :W2))
+    W3 = StatefulLuxLayer{true}(l.W3, ps.W3, _getstate(st, :W3))
+    W4 = StatefulLuxLayer{true}(l.W4, ps.W4, _getstate(st, :W4))
+    W5 = l.W5 === nothing ? nothing :
+         StatefulLuxLayer{true}(l.W5, ps.W5, _getstate(st, :W5))
+    W6 = l.W6 === nothing ? nothing :
+         StatefulLuxLayer{true}(l.W6, ps.W6, _getstate(st, :W6))
+    FF = l.FF === nothing ? nothing :
+         StatefulLuxLayer{true}(l.FF, ps.FF, _getstate(st, :FF))
+    BN1 = l.BN1 === nothing ? nothing :
+          StatefulLuxLayer{true}(l.BN1, ps.BN1, _getstate(st, :BN1))
+    BN2 = l.BN2 === nothing ? nothing :
+          StatefulLuxLayer{true}(l.BN2, ps.BN2, _getstate(st, :BN2))
+    # `add_self_loops` is forwarded so that the stored option is visible to GNNlib.
+    # NOTE(review): confirm which fields `GNNlib.transformer_conv` reads from `m`.
+    m = (; W1, W2, W3, W4, W5, W6, FF, BN1, BN2, l.sqrt_out,
+        l.heads, l.concat, l.skip_connection, l.add_self_loops)
+    # NOTE(review): `st` is returned as received; state changes made inside the wrapped
+    # sublayers (e.g. `BN1`/`BN2`) are presumably not propagated — confirm.
+    return GNNlib.transformer_conv(m, g, x, e), st
+end
+
+function Base.show(io::IO, l::TransformerConv)
+    (in, ein), out = (l.in_dims, l.out_dims)
+    print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
+end
diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl
index 6541dfe0c..f87b1b1f3 100644
--- a/GNNLux/test/layers/conv_tests.jl
+++ b/GNNLux/test/layers/conv_tests.jl
@@ -134,4 +134,18 @@
         l = ResGatedGraphConv(in_dims => out_dims, tanh)
         test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
     end
+
+    @testset "TransformerConv" begin
+        x = randn(rng, Float32, 6, 10)
+        ein = 2
+        e = randn(rng, Float32, ein, g.num_edges)
+
+        # concat = true (default): the layer is built for 2 * 8 = 16 output channels.
+        l = TransformerConv((6, ein) => 8, heads = 2, gating = true, bias_qkv = true)
+        test_lux_layer(rng, l, g, x, outputsize = (16,), e = e, container = true)
+
+        # concat = false: the layer is built for 8 output channels.
+        l = TransformerConv((6, ein) => 8, heads = 2, concat = false, skip_connection = true)
+        test_lux_layer(rng, l, g, x, outputsize = (8,), e = e, container = true)
+    end
 end
diff --git a/Project.toml b/Project.toml
new file mode 100644
index 000000000..a89e99a54
--- /dev/null
+++ b/Project.toml
@@ -0,0 +1,140 @@
+[deps]
+AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
+Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
+ArgTools = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
+ArnoldiMethod = "ec485272-7323-5ecc-a04f-4719b315124d"
+Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
+BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
+Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+Baselet = "9718e550-a3fa-408a-8086-8db961cd8217"
+CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+CompilerSupportLibraries_jll = "e66e0078-7015-5450-92f7-15fbd957f2ae"
+CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
+ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
+DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+DataValueInterfaces = "e2d170a0-9d28-54be-80f0-106bbe20a464"
+Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
+DefineSingletons = "244e2a9f-e319-4986-a169-4d1fe445cd52"
+DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
+DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
+Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
+Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
+FLoopsBase = "b9860ae5-e623-471e-878b-f6a53c775ea6"
+FileWatching = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
+GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
+GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
+Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
+IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
+Inflate = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
+InitialValues = "22cec73e-a1b8-11e9-2c92-598750a2cf9c"
+InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
+IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
+IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
+JLLWrappers = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
+JuliaVariables = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
+LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
+LLVMExtra_jll = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
+LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
+LibCURL = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
+LibCURL_jll = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
+LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
+LibGit2_jll = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
+LibSSH2_jll = "29816b5a-b9ab-546f-933c-edad1886dfa8"
+Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
+MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
+Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
+MbedTLS_jll = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
+MicroCollections = "128add7d-3638-4c79-886c-908ea0c25c34"
+Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
+Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
+MozillaCACerts_jll = "14a3606d-f60d-562e-9121-12d972cd8159"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
+NameResolution = "71a1bf82-56d0-4bbc-8a3c-48b961074391"
+NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
+NetworkOptions = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
+OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
+OpenBLAS_jll = "4536629a-c528-5b80-bd46-f80d51c5b363"
+OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112"
+OpenSpecFun_jll = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
+Preferences = "21216c6a-2e73-6563-6e65-726566657250"
+PrettyPrint = "8162dcfd-2161-5ef2-ae6c-7681170c5f98"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
+REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
+Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
+Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
+SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
+ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
+SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
+Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
+SortingAlgorithms = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
+SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
+SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
+SuiteSparse_jll = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
+TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
+TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
+Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
+UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
+UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
+UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
+UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
+VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
+Zlib_jll = "83775a58-1f1d-513f-b197-d71354ab007a"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
+libblastrampoline_jll = "8e850b90-86db-534c-a0d3-1478176c7d93"
+nghttp2_jll = "8e850ede-7688-5339-a07c-302acd2aaf8d"
+p7zip_jll = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"