diff --git a/src/LuxCore.jl b/src/LuxCore.jl index ccc8b18..40742f3 100644 --- a/src/LuxCore.jl +++ b/src/LuxCore.jl @@ -132,7 +132,7 @@ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st) Calls `apply` and only returns the first argument. """ function stateless_apply(model::AbstractExplicitLayer, x, ps, st) - first(apply(model, x, ps, st)) + return first(apply(model, x, ps, st)) end function stateless_apply(model, x, ps, st) diff --git a/test/runtests.jl b/test/runtests.jl index e686463..80979ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,9 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + @test LuxCore.stateless_apply(model, x, ps, st) == + first(LuxCore.apply(model, x, ps, st)) + @test_nowarn println(model) end @@ -88,6 +91,9 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + @test LuxCore.stateless_apply(model, x, ps, st) == + first(LuxCore.apply(model, x, ps, st)) + @test_nowarn println(model) model = Chain2(Dense(5, 5), Dense(5, 6)) @@ -103,6 +109,9 @@ end @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + @test LuxCore.stateless_apply(model, x, ps, st) == + first(LuxCore.apply(model, x, ps, st)) + @test_nowarn println(model) end