Skip to content

Commit

Permalink
Adjust for GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Dec 12, 2024
1 parent 52d0a7a commit ed7f51a
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,12 @@ See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`Laye
"""
hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine

struct WeightNorm{which, dims, L, G}
struct WeightNorm{L, G, D}
layer::L
g::G

which::Symbol
dims::D
end
@layer WeightNorm

Expand Down Expand Up @@ -624,7 +627,7 @@ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L
end

g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x)))
WeightNorm{which, dims, L, typeof(g)}(layer, g)
WeightNorm(layer, g, which, dims)

Check warning on line 630 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L629-L630

Added lines #L629 - L630 were not covered by tests
end

(w::WeightNorm)(x) = reparametrize(w)(x)

Check warning on line 633 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L633

Added line #L633 was not covered by tests
Expand All @@ -634,22 +637,22 @@ end
Apply `WeightNorm` reparametrization and return underlying `layer`.
"""
function reparametrize(wn::WeightNorm{which, dims}) where {which, dims}
function reparametrize(wn::WeightNorm)
ϵ = eps(eltype(wn.g))
v = getfield(wn.layer, which)
n2 = sum(abs2, v; dims)
v = getfield(wn.layer, wn.which)
n2 = sum(abs2, v; wn.dims)
w = @. wn.g * v / sqrt(n2 + ϵ)

Check warning on line 644 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L640-L644

Added lines #L640 - L644 were not covered by tests

fields, ctor = Functors.functor(wn.layer)
return ctor(merge(

Check warning on line 647 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L646-L647

Added lines #L646 - L647 were not covered by tests
fields, NamedTuple{(which,)}((w,)),
fields, NamedTuple{(wn.which,)}((w,)),
))
end

function Base.show(io::IO, w::WeightNorm{which, dims}) where {which, dims}
function Base.show(io::IO, w::WeightNorm)
print(io, "WeightNorm(")
Base.show(io, w.layer)
print(io, ", :", which, "; dims=", dims)
print(io, ", :", w.which, "; dims=", w.dims)
print(io, ")")

Check warning on line 656 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L652-L656

Added lines #L652 - L656 were not covered by tests
end

Expand Down

0 comments on commit ed7f51a

Please sign in to comment.