diff --git a/src/LuxCore.jl b/src/LuxCore.jl index f4cd971..ae88919 100644 --- a/src/LuxCore.jl +++ b/src/LuxCore.jl @@ -127,7 +127,7 @@ Simply calls `model(x, ps, st)` apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) """ - stateless_apply(model, x, ps, st) + stateless_apply(model, x, ps) Calls `apply` and only returns the first argument. """ @@ -188,6 +188,18 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} return sum(statelength, getfield.((l,), layers)) end +function stateless_apply( + model::AbstractExplicitContainerLayer{layers}, x, ps) where {layers} + if length(layers) == 1 + layer_names = keys(getfield(model, layers[1])) + else + layer_names = layers + end + st = NamedTuple{layer_names}(NamedTuple() for _ in layer_names) + + return first(apply(model, x, ps, st)) +end + # Make AbstractExplicit Layers Functor Compatible function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}}, x) where {layers} diff --git a/test/runtests.jl b/test/runtests.jl index 65e309a..6a80691 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -92,7 +92,7 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) @test LuxCore.stateless_apply(model, x, ps) == - first(LuxCore.apply(model, x, ps, NamedTuple())) + first(LuxCore.apply(model, x, ps, st)) @test_nowarn println(model) @@ -110,7 +110,7 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) @test LuxCore.stateless_apply(model, x, ps) == - first(LuxCore.apply(model, x, ps, NamedTuple())) + first(LuxCore.apply(model, x, ps, st)) @test_nowarn println(model) end