diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b06c6a146a..33350a904c 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -569,9 +569,12 @@ See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`Laye """ hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine -struct WeightNorm{which, dims, L, G} +struct WeightNorm{L, G, D} layer::L g::G + + which::Symbol + dims::D end @layer WeightNorm @@ -624,7 +627,7 @@ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L end g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x))) - WeightNorm{which, dims, L, typeof(g)}(layer, g) + WeightNorm(layer, g, which, dims) end (w::WeightNorm)(x) = reparametrize(w)(x) @@ -634,22 +637,22 @@ end Apply `WeightNorm` reparametrization and return underlying `layer`. """ -function reparametrize(wn::WeightNorm{which, dims}) where {which, dims} +function reparametrize(wn::WeightNorm) ϵ = eps(eltype(wn.g)) - v = getfield(wn.layer, which) - n2 = sum(abs2, v; dims) + v = getfield(wn.layer, wn.which) + n2 = sum(abs2, v; wn.dims) w = @. wn.g * v / sqrt(n2 + ϵ) fields, ctor = Functors.functor(wn.layer) return ctor(merge( - fields, NamedTuple{(which,)}((w,)), + fields, NamedTuple{(wn.which,)}((w,)), )) end -function Base.show(io::IO, w::WeightNorm{which, dims}) where {which, dims} +function Base.show(io::IO, w::WeightNorm) print(io, "WeightNorm(") Base.show(io, w.layer) - print(io, ", :", which, "; dims=", dims) + print(io, ", :", w.which, "; dims=", w.dims) print(io, ")") end