From 04494b5184f02ef4986334ca8987cfbfe014f4ab Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 15 Nov 2024 21:47:45 -0500 Subject: [PATCH] fix: mark kwargs in functor as leaf (#1085) --- Project.toml | 2 +- docs/src/index.md | 7 ++++--- src/helpers/compact.jl | 22 +++++++++++++++++++--- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 8134038ca..91fe24f93 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.3.0" +version = "1.3.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/index.md b/docs/src/index.md index 74b19619d..68dfe71f9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -125,7 +125,7 @@ const dev = gpu_device() ::: -## Want XLA Support? +## Want Reactant (XLA) Support? Install the following package: @@ -134,13 +134,14 @@ using Pkg; Pkg.add("Reactant") ``` -Run the following to access a device: +Run the following to access a device (Reactant automatically selects the best backend by +default): :::code-group ```julia [CPU Backend] using Reactant, Lux -Reactant.set_default_backend("cpu") # default +Reactant.set_default_backend("cpu") const dev = reactant_device() ``` diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl index 9e287ff2a..43b4b009e 100644 --- a/src/helpers/compact.jl +++ b/src/helpers/compact.jl @@ -32,7 +32,8 @@ is useful when using it with SciML tools which require passing in the parameters If you are passing in kwargs by splatting them, they will be passed as is to the function body. This means if your splatted kwargs contain a lux layer that won't be registered -in the CompactLuxLayer. +in the CompactLuxLayer. Additionally all of the device functions treat these kwargs as +leaves. ## Special Syntax @@ -314,7 +315,13 @@ function initialstates(rng::AbstractRNG, m::CompactLuxLayer) initialstates(rng, m.layers)..., initialstates(rng, m.value_storage)...) length(first(m.stored_kwargs)) == 0 && return base_states return merge( - base_states, (; ₋₋₋kwargs₋₋₋=NamedTuple{m.stored_kwargs[1]}(m.stored_kwargs[2]))) + base_states, + (; + ₋₋₋kwargs₋₋₋=CompactMacroImpl.KwargsStorage( + NamedTuple{m.stored_kwargs[1]}(m.stored_kwargs[2]) + ) + ) + ) end function CompactLuxLayer(dispatch::StaticSymbol, f::F, name::NAME_TYPE, @@ -419,6 +426,7 @@ module CompactMacroImpl using ChainRulesCore: @non_differentiable using ConcreteStructs: @concrete using MacroTools: MacroTools, @capture, combinedef, splitdef +using Functors: Functors using Random: AbstractRNG using Static: static @@ -517,7 +525,9 @@ function supportself(fex::Expr, vars, splatted_kwargs) end for var in splatted_kwargs push!(calls, - :($var = $(safe_getproperty)(getproperty($st, :₋₋₋kwargs₋₋₋), $(Val(var))))) + :($var = $(safe_getproperty)( + getproperty(getproperty($st, :₋₋₋kwargs₋₋₋), :kws), $(Val(var)) + ))) end custom_param && push!(calls, :($(sdef[:args][2]) = $ps)) @@ -631,6 +641,12 @@ function LuxCore.initialstates(rng::AbstractRNG, v::ValueStorage) for (n, fn) in pairs(v.st_init_fns)]) end +@concrete struct KwargsStorage + kws <: NamedTuple +end + +Functors.@leaf KwargsStorage + function kwarg_descriptor(val) val isa NonTrainable && return "@non_trainable($(kwarg_descriptor(val.value)))" val isa Number && return string(val)