Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Norm Layers, Again #1509

Closed
wants to merge 49 commits into from
Closed
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
2cd59b4
clean up the history
Feb 16, 2021
0ef033e
bn on cudnn
Feb 16, 2021
1e67700
rm extra utils
Feb 16, 2021
0dbeb39
use simpler test suite first
Feb 16, 2021
04ca0a8
git fixes
Feb 16, 2021
31f7000
refactor norm_forward
Feb 18, 2021
262112a
simplify tests
Feb 20, 2021
c61457c
typose
Feb 20, 2021
0432b82
check reduce for batch
Feb 20, 2021
77a8e87
use normconfig
Feb 21, 2021
d0961a7
use normconfig for other layers
Feb 21, 2021
b091a80
typo
Feb 21, 2021
680e64f
backwards
Feb 21, 2021
4a26559
first pass
Mar 1, 2021
1c38029
Update src/layers/normalise.jl
DhairyaLGandhi Mar 7, 2021
137acf9
don't use single instance in bn
Mar 8, 2021
db559c5
use prev stats to track
Mar 8, 2021
f5c3641
Merge branch 'dg/instance2' of https://github.com/FluxML/Flux.jl into…
Mar 8, 2021
fcea841
track stats tests
Mar 8, 2021
647dcd9
fix BN(3)
Mar 9, 2021
561c94f
check instance norm tests
Mar 10, 2021
332b13b
clean up instance norm
Mar 12, 2021
30d6542
dont reshape eagerly
Mar 12, 2021
1d99fd6
use mean instead of channel
Mar 15, 2021
e3ae11d
unbreak a couple tests
Mar 15, 2021
c66202e
use non corrected variance
Mar 16, 2021
d6fac56
typo
Mar 16, 2021
16d0b96
use train time eval
Mar 16, 2021
e7fe00b
check for dims in getaffine
Mar 16, 2021
9c01dd2
use correct group dims
Mar 22, 2021
9abfe0c
typo
Mar 22, 2021
a621ef6
use trainmode groupnorm test
Mar 22, 2021
e174605
cleanup
Mar 24, 2021
99901a7
use bias and gamma for trainable
Mar 24, 2021
9f481e4
trainable
Mar 24, 2021
e9d89ab
test fixes
Mar 26, 2021
8f3844c
new constructor
DhairyaLGandhi Apr 19, 2021
bf34b73
test conflicts
DhairyaLGandhi Apr 19, 2021
14a6372
conflicts
DhairyaLGandhi Apr 19, 2021
d82c3d3
conflicts
DhairyaLGandhi Apr 19, 2021
8aadf1e
rebase
DhairyaLGandhi Jun 23, 2021
28521c1
rebase
DhairyaLGandhi Jun 23, 2021
19b91b2
size fix
DhairyaLGandhi Jun 24, 2021
0d4605d
space cleanups + show
DhairyaLGandhi Jun 24, 2021
36084e5
add layer norm show methods
DhairyaLGandhi Jun 24, 2021
3c6f1ce
whitespace
DhairyaLGandhi Jun 24, 2021
8f6de19
change some tests
DhairyaLGandhi Jun 25, 2021
c525d4f
use affine as function
DhairyaLGandhi Jun 29, 2021
aa39039
rebase
DhairyaLGandhi Aug 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/cuda/cudnn.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import NNlibCUDA: batchnorm, ∇batchnorm

function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
cache=nothing) where T<:Union{Float32, Float64}
function (BN::Flux.BatchNorm)(x::CuArray{T},
cache = nothing) where T<:Union{Float32, Float64}
DhairyaLGandhi marked this conversation as resolved.
Show resolved Hide resolved

@assert BN.affine "BatchNorm: only affine=true supported on gpu"
@assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
@assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wronng number of channels"
@assert BN.affine throw(ArgumentError("BatchNorm: only affine = true supported on gpu"))
@assert BN.track_stats throw(ArgumentError("BatchNorm: only track_stats = true supported on gpu"))
return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
cache=cache, alpha=1, beta=0, eps=BN.ϵ,
training=Flux._isactive(BN)))
cache = cache, alpha = 1, beta = 0, eps = BN.ϵ,
training = Flux._isactive(BN)))
end

@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
Expand Down
Loading