Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
rename getstate to _getstate
Browse files Browse the repository at this point in the history
Apply suggestions from code review

Co-authored-by: Avik Pal <[email protected]>
  • Loading branch information
SebastianM-C and avik-pal committed Feb 24, 2024
1 parent 171ece8 commit a777715
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions src/LuxCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ function initialstates(rng::AbstractRNG, l)
throw(MethodError(initialstates, (rng, l)))
end

getstate(::AbstractExplicitLayer) = NamedTuple()
getstate(l::NamedTuple) = NamedTuple{keys(l)}(map(getstate, l))
_getstate(::AbstractExplicitLayer) = NamedTuple()
function _getstate(l::NamedTuple{fields}) where {fields}
return NamedTuple{fields}(map(_getstate, values(l)))
end

"""
parameterlength(layer)
Expand Down Expand Up @@ -135,7 +137,7 @@ apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st)
Calls `apply` and only returns the first argument.
"""
function stateless_apply(model::AbstractExplicitLayer, x, ps)
return first(apply(model, x, ps, NamedTuple()))
return first(apply(model, x, ps, _getstate(model)))
end

"""
Expand Down Expand Up @@ -191,16 +193,9 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers}
return sum(statelength, getfield.((l,), layers))
end

function getstate(l::AbstractExplicitContainerLayer{layers}) where {layers}
length(layers) == 1 && return getstate(getfield(l, layers[1]))
return NamedTuple{layers}(getstate.(getfield.((l,), layers)))
end

function stateless_apply(
model::AbstractExplicitContainerLayer{layers}, x, ps) where {layers}
st = getstate(model)

return first(apply(model, x, ps, st))
function _getstate(l::AbstractExplicitContainerLayer{layers}) where {layers}
length(layers) == 1 && return _getstate(getfield(l, length(layers)))
return NamedTuple{layers}(_getstate.(getfield.((l,), layers)))
end

# Make AbstractExplicit Layers Functor Compatible
Expand Down

0 comments on commit a777715

Please sign in to comment.