From ab60da7171e2279cb3315896ff25b2cff2dff762 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 13:58:08 -0400 Subject: [PATCH] feat: update to support Lux 1.0 --- Project.toml | 8 +++----- src/GraphNetCore.jl | 6 ++++-- src/graph_net_blocks.jl | 16 +++++++--------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index bdcb0e7..a1ef77e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GraphNetCore" uuid = "7809f980-de1b-4f9a-8451-85f041491431" authors = ["JT "] -version = "0.3.1" +version = "0.3.2" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -11,7 +11,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -27,15 +26,14 @@ DataFrames = "1.6" ForwardDiff = "0.10" JLD2 = "0.4" KernelAbstractions = "0.9" -Lux = "0.5" -LuxCUDA = "0.3" +Lux = "1" NNlib = "0.9" Random = "1" Statistics = "1" +Test = "1" Tullio = "0.3.7" Zygote = "0.6" cuDNN = "1.3" -Test = "1" julia = "1.10" [extras] diff --git a/src/GraphNetCore.jl b/src/GraphNetCore.jl index a2a7737..146c6cf 100644 --- a/src/GraphNetCore.jl +++ b/src/GraphNetCore.jl @@ -5,11 +5,13 @@ module GraphNetCore -using CUDA -using Lux, LuxCUDA +using CUDA, cuDNN +using Lux using Tullio using Random +const NAME_TYPE = Union{Nothing, String, Symbol} + include("utils.jl") include("normaliser.jl") include("graph_network.jl") diff --git a/src/graph_net_blocks.jl b/src/graph_net_blocks.jl index cbf3f05..64f7b48 100644 --- a/src/graph_net_blocks.jl +++ b/src/graph_net_blocks.jl @@ -3,13 +3,12 @@ # Licensed under the MIT license. See LICENSE file in the project root for details. # -struct Encoder{T <: NamedTuple, N <: Lux.NAME_TYPE} <: - Lux.AbstractExplicitContainerLayer{(:layers,)} +struct Encoder{T <: NamedTuple, N <: NAME_TYPE} <: Lux.AbstractLuxWrapperLayer{:layers} layers::T name::N end -function Encoder(node_model, edge_model; name::Lux.NAME_TYPE = nothing) +function Encoder(node_model, edge_model; name::NAME_TYPE = nothing) fields = (Symbol("node_model_fn"), Symbol("edge_model_fn")) return Encoder(NamedTuple{fields}((node_model, edge_model)), name) @@ -28,13 +27,13 @@ function encode!( return update_features!(graph; nf = nf, ef = ef), new_st end -struct Processor{T <: NamedTuple, N <: Lux.NAME_TYPE} <: - Lux.AbstractExplicitContainerLayer{(:layers,)} +struct Processor{T <: NamedTuple, N <: NAME_TYPE} <: + Lux.AbstractLuxWrapperLayer{:layers} layers::T name::N end -function Processor(node_model, edge_model; name::Lux.NAME_TYPE = nothing) +function Processor(node_model, edge_model; name::NAME_TYPE = nothing) fields = (Symbol("node_model_fn"), Symbol("edge_model_fn")) return Processor(NamedTuple{fields}((node_model, edge_model)), name) @@ -68,13 +67,12 @@ end return nl(features, ps, st) end -struct Decoder{T <: NamedTuple, N <: Lux.NAME_TYPE} <: - Lux.AbstractExplicitContainerLayer{(:layers,)} +struct Decoder{T <: NamedTuple, N <: NAME_TYPE} <: Lux.AbstractLuxWrapperLayer{:layers} layers::T name::N end -function Decoder(model; name::Lux.NAME_TYPE = nothing) +function Decoder(model; name::NAME_TYPE = nothing) fields = (Symbol("model"),) return Decoder(NamedTuple{fields}((model,)), name)