
Commit

Rename 'transform' to 'reparametrize' & other minor changes
pxl-th committed Dec 12, 2024
1 parent 780a138 commit 52d0a7a
Showing 3 changed files with 24 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/Flux.jl
@@ -94,6 +94,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
     siamese_contrastive_loss,
     squared_hinge_loss,
     tversky_loss,
+    remove_weight_norms,
   ))

include("gradient.jl")
Expand Down
12 changes: 8 additions & 4 deletions src/layers/normalise.jl
@@ -624,13 +624,17 @@ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L
     end
 
     g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x)))
+    x ./= g # Store `v` in the original weights.
     WeightNorm{which, dims, L, typeof(g)}(layer, g)
 end
 
-(w::WeightNorm)(x) = transform(w)(x)
+(w::WeightNorm)(x) = reparametrize(w)(x)
 
-function transform(wn::WeightNorm{which, dims}) where {which, dims}
+"""
+    reparametrize(wn::WeightNorm)
+
+Apply `WeightNorm` reparametrization and return underlying `layer`.
+"""
+function reparametrize(wn::WeightNorm{which, dims}) where {which, dims}
     ϵ = eps(eltype(wn.g))
     v = getfield(wn.layer, which)
     n2 = sum(abs2, v; dims)
@@ -681,4 +685,4 @@ Chain(
 ) # Total: 4 arrays, 22 parameters, 392 bytes.
 ```
 """
-remove_weight_norms(x) = fmap(transform, x; exclude=l -> l isa WeightNorm)
+remove_weight_norms(x) = fmap(reparametrize, x; exclude=l -> l isa WeightNorm)
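For context, `WeightNorm` implements Salimans & Kingma's weight normalization: the wrapped layer's weight field stores the direction `v`, the wrapper stores the magnitude `g`, and the effective weight is `g .* v ./ norm(v)` (with an `eps` for numerical stability). A minimal usage sketch of the renamed API, with made-up layer sizes (illustrative only, not part of the commit):

```julia
using Flux

# Wrap a Dense layer; its `weight` now holds the direction `v`
# and the wrapper holds the magnitude `g`.
wn = WeightNorm(Dense(3 => 2))
x = randn(Float32, 3, 5)

# `remove_weight_norms` calls `reparametrize` on every `WeightNorm`
# in the model, folding `g` back into a plain layer; the forward
# pass must be unchanged.
plain = Flux.remove_weight_norms(wn)
@assert plain isa Dense
@assert wn(x) ≈ plain(x)
```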
19 changes: 15 additions & 4 deletions test/layers/normalisation.jl
@@ -469,9 +469,10 @@ end
     v = mn.layer.weight
     ϵ = eps(eltype(v))
     n2 = sum(abs2, v; dims=2)
+    v = v ./ sqrt.(n2 .+ ϵ)
 
-    @test (og.weight .* v ./ sqrt.(n2 .+ ϵ)) ≈ g.g
-    @test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* v ./ n2.^2) ≈ g.layer.weight atol=1f-6
+    @test (og.weight .* v) ≈ g.g
+    @test (og.weight .* mn.g .- mn.g .* g.g .* v) ≈ g.layer.weight atol=1f-6
 
     # Test WeightNorm removal.

@@ -484,14 +485,24 @@ end

     c = Chain(
         WeightNorm(Conv((3,), 1 => 2)),
-        WeightNorm(Conv((3,), 2 => 2)),
+        Conv((3,), 2 => 2),
+        WeightNorm(Conv((3,), 2 => 3)),
+        x -> reshape(x, 18, :),
+        WeightNorm(Dense(18, 4)),
+        Dense(4, 1),
     )
     @test c[1] isa WeightNorm
-    @test c[2] isa WeightNorm
+    @test c[2] isa Conv
+    @test c[3] isa WeightNorm
+    @test c[5] isa WeightNorm
+    @test c[6] isa Dense
 
     oc = Flux.remove_weight_norms(c)
     @test oc[1] isa Conv
     @test oc[2] isa Conv
+    @test oc[3] isa Conv
+    @test oc[5] isa Dense
+    @test oc[6] isa Dense
 
     x = rand(Float32, 12, 1, 1)
     @test c(x) ≈ oc(x)
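The reworked assertions above lean on the same identity `w = g .* v ./ norm(v)`: `w` is linear in `g`, and only the direction of `v` matters. A standalone sanity sketch of those two facts (plain Julia with made-up values, not part of the commit):

```julia
using Test

# w(g, v) = g * v / ||v||, the weight-norm reparametrization.
w(g, v) = g .* v ./ sqrt(sum(abs2, v))

v = randn(Float32, 4)
v̂ = v ./ sqrt(sum(abs2, v))

# Linear in g: since w is linear in g, any step size recovers the
# unit direction exactly.
δ = 0.5f0
@test (w(2f0 + δ, v) .- w(2f0, v)) ./ δ ≈ v̂

# Scale-invariant in v: rescaling v leaves w unchanged.
@test w(2f0, 3f0 .* v) ≈ w(2f0, v)
```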
