Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
add stateless_apply for AbstractExplicitContainerLayer
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianM-C committed Feb 24, 2024
1 parent 97a404c commit 596daa6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
14 changes: 13 additions & 1 deletion src/LuxCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ Simply calls `model(x, ps, st)`
apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st)

"""
stateless_apply(model, x, ps, st)
stateless_apply(model, x, ps)
Calls `apply` and only returns the first argument.
"""
Expand Down Expand Up @@ -188,6 +188,18 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers}
return sum(statelength, getfield.((l,), layers))
end

function stateless_apply(
model::AbstractExplicitContainerLayer{layers}, x, ps) where {layers}
if length(layers) == 1
layer_names = keys(getfield(model, layers[1]))
else
layer_names = layers
end
st = NamedTuple{layer_names}(NamedTuple() for _ in layer_names)

return first(apply(model, x, ps, st))
end

# Make AbstractExplicit Layers Functor Compatible
function Functors.functor(::Type{<:AbstractExplicitContainerLayer{layers}},
x) where {layers}
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ end
@test LuxCore.apply(model, x, ps, st) == model(x, ps, st)

@test LuxCore.stateless_apply(model, x, ps) ==
first(LuxCore.apply(model, x, ps, NamedTuple()))
first(LuxCore.apply(model, x, ps, st))

@test_nowarn println(model)

Expand All @@ -110,7 +110,7 @@ end
@test LuxCore.apply(model, x, ps, st) == model(x, ps, st)

@test LuxCore.stateless_apply(model, x, ps) ==
first(LuxCore.apply(model, x, ps, NamedTuple()))
first(LuxCore.apply(model, x, ps, st))

@test_nowarn println(model)
end
Expand Down

0 comments on commit 596daa6

Please sign in to comment.