-
-
Notifications
You must be signed in to change notification settings - Fork 612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add WeightNorm
reparametrization
#2550
Conversation
This way we have duplication of the weights. struct WeightNorm{which, dims, L, G}
layer::L
g::G
end
(w::WeightNorm)(x) = transform(w)(x)
function transform(wn::WeightNorm{which, dims}) where {which, dims}
ϵ = eps(eltype(wn.v))
v = getfield(wn.layer, which)
n2 = sum(abs2, v; dims)
w = @. wn.g * v / sqrt(n2 + ϵ)
fields, ctor = Functors.functor(wn.layer)
return ctor(merge(
fields, NamedTuple{(which,)}((w,)),
))
end |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2550 +/- ##
==========================================
+ Coverage 31.95% 32.78% +0.82%
==========================================
Files 34 34
Lines 1987 1998 +11
==========================================
+ Hits 635 655 +20
+ Misses 1352 1343 -9 ☔ View full report in Codecov by Sentry. |
Not yet ready. Needs some adjustments for GPUs |
Added GPU tests as well, making a small test suite which we can exapnd in future to avoid duplication. |
The doctest needs to be fixed. besides that looks good |
All tests now pass. |
Yet another attempt at adding WeightNorm.
Based on different bits found in #2053 [edited!] & #1005 with tests and documentation.
PR Checklist