Merge pull request #1472 from FluxML/dg/acttests
Add activation tests for GPU layers
DhairyaLGandhi authored Mar 24, 2021
2 parents 3e3f9d7 + ae699d0 commit 5c40716
Showing 3 changed files with 115 additions and 75 deletions.
1 change: 0 additions & 1 deletion src/optimise/train.jl
@@ -87,7 +87,6 @@ If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`.
A callback is given with the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
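As a rough sketch of combining the two (the validation pair `(x_val, y_val)` and the threshold below are placeholders, not part of this docstring), a throttled callback can stop training once the loss is small enough:
stop_cb = () -> loss(x_val, y_val) < 0.01 && Flux.stop()
train!(loss, params, data, opt, cb = throttle(stop_cb, 10))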
186 changes: 113 additions & 73 deletions test/cuda/layers.jl
@@ -13,71 +13,98 @@ end
# TODO: These layers get into scalar indexing
# `AlphaDropout` throws a compilation error on GPUs,
# whereas, the rest are scalar indexing issues.
# The norm layers behave differently on the CPU and
# the GPU too.
const BROKEN_LAYERS = Union{DepthwiseConv,
AlphaDropout}
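# Side note (not part of this diff): scalar-indexing failures like the ones above
# typically surface when the test environment forbids scalar GPU indexing, e.g. with
#   CUDA.allowscalar(false)
# which turns each scalar CuArray getindex/setindex! into an error rather than a slow fallback.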

function gpu_gradtest(name::String, layers::Vector, x_cpu, args...;
setmode=false, test_cpu=true, rtol=1e-5, atol=1e-5)
const ACTIVATIONS = [identity, relu, tanh,
sigmoid, exp, softplus,
elu, selu]

function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; test_cpu = true)
isnothing(x_cpu) && error("Missing input to test the layers against.")
@testset "$name GPU grad tests" begin
for layer in layers
@testset "$layer GPU grad test" begin
@testset "$layer Layer GPU grad test" begin

# compute output and grad of parameters
l_cpu = layer(args...)
if l_cpu isa BROKEN_LAYERS
l_gpu, x_gpu = l_cpu |> gpu, x_cpu |> gpu
@test_broken gradient(() -> sum(l_gpu(x_gpu)), Flux.params(l_gpu)) isa Flux.Zygote.Grads
ps_cpu = Flux.params(l_cpu)
y_cpu, back_cpu = pullback(() -> sum(l_cpu(x_cpu)), ps_cpu)
gs_cpu = back_cpu(1f0)

x_gpu = gpu(x_cpu)
l_gpu = l_cpu |> gpu
ps_gpu = Flux.params(l_gpu)

if typeof(l_gpu) <: BROKEN_LAYERS
@test_broken gradient(() -> sum(l_gpu(x_gpu)), ps_gpu) isa Flux.Zygote.Grads
else
gpu_autodiff_test(l_cpu, x_cpu,
test_equal=test_cpu, rtol=rtol, atol=atol)
if setmode
testmode!(l_cpu)
gpu_autodiff_test(l_cpu, x_cpu,
test_equal=test_cpu, rtol=rtol, atol=atol)
end
y_gpu, back_gpu = pullback(() -> sum(l_gpu(x_gpu)), ps_gpu)
gs_gpu = back_gpu(1f0) # TODO: many layers error out when backpropagating an Int 1; should fix

# compute grad of input
xg_cpu = gradient(x -> sum(l_cpu(x)), x_cpu)[1]
xg_gpu = gradient(x -> sum(l_gpu(x)), x_gpu)[1]

# test
if test_cpu
@test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3
@test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
end
@test gs_gpu isa Flux.Zygote.Grads
for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu)
@test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
if test_cpu
@test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
end
end
end
end
end
end
end
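# For orientation (hypothetical call, not part of the test suite): each layer is built as
# `layer(args...)`, so a standalone invocation could look like
#   gpu_gradtest("Dense", [Dense], rand(Float32, 3, 4), 3, 2, relu, test_cpu = true)
# i.e. Dense(3, 2, relu) applied to a 3×4 Float32 batch, with GPU results checked against the CPU.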

# Just to give testset in gpu_gradtest meaningful labels
ConvNoBias(args...) = Conv(args...; bias = false)
ConvTransposeNoBias(args...) = ConvTranspose(args...; bias = false)
CrossCorNoBias(args...) = CrossCor(args...; bias = false)
DepthwiseConvNoBias(args...) = DepthwiseConv(args...; bias = false)

for act in ACTIVATIONS
r = rand(Float32, 28, 28, 1, 1)
conv_layers = [Conv, ConvNoBias,
ConvTranspose, ConvTransposeNoBias,
CrossCor, CrossCorNoBias,
DepthwiseConv, DepthwiseConvNoBias]
gpu_gradtest("Convolution with $act", conv_layers, r, (2,2), 1=>3, act, test_cpu = false)

batch_norm = [BatchNorm]
gpu_gradtest("BatchNorm 1 with $act", batch_norm, rand(Float32, 28,28,3,4), 3, act, test_cpu = false) #TODO fix errors
gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = false)

instancenorm = [InstanceNorm]
gpu_gradtest("InstanceNorm with $act", instancenorm, r, 1, act, test_cpu = false)

groupnorm = [GroupNorm]
gpu_gradtest("GroupNorm with $act", groupnorm, rand(Float32, 28,28,3,1), 3, 1, act, test_cpu = false)
end

# Just to give testset in gradtest meaningful labels
ConvNoBias(args...) = Conv(args...; bias=false)
ConvTransposeNoBias(args...) = ConvTranspose(args...; bias=false)
CrossCorNoBias(args...) = CrossCor(args...; bias=false)
DepthwiseConvNoBias(args...) = DepthwiseConv(args...; bias=false)
r = rand(Float32, 28, 28, 1, 1)
conv_layers = [Conv, ConvNoBias, ConvTranspose, ConvTransposeNoBias, CrossCor, CrossCorNoBias, DepthwiseConv, DepthwiseConvNoBias]
gpu_gradtest("Conv", conv_layers, r, (2,2), 1=>3)

pooling_layers = [MaxPool, MeanPool]
gpu_gradtest("Pooling", pooling_layers, r, (2,2))

adaptive_pooling_layers = [AdaptiveMaxPool, AdaptiveMeanPool]
gpu_gradtest("AdaptivePooling", adaptive_pooling_layers, r, (7,7))
gpu_gradtest("AdaptivePooling", adaptive_pooling_layers, r, (7,7), test_cpu = false)

dropout_layers = [Dropout, AlphaDropout]
gpu_gradtest("Dropout", dropout_layers, r, 0.5f0; test_cpu=false, setmode=true) # dropout is not deterministic

layer_norm = [i -> LayerNorm(i; affine=false), i -> LayerNorm(i; affine=true)]
gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 8, 8, 3, 4), 8)
gpu_gradtest("LayerNorm 2", layer_norm, rand(Float32, 8, 8, 3, 4), (8,8))
gpu_gradtest("LayerNorm 3", layer_norm, rand(Float32, 5, 4), 5)

batch_norm = [BatchNorm]
gpu_gradtest("BatchNorm 3d", batch_norm, rand(Float32, 8, 8, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode
gpu_gradtest("BatchNorm 2d", batch_norm, rand(Float32, 8, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode
gpu_gradtest("BatchNorm 1d", batch_norm, rand(Float32, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode
gpu_gradtest("BatchNorm fullyconn", batch_norm, rand(Float32, 5,4), 5, setmode=false)
gpu_gradtest("Dropout", dropout_layers, r, 0.5f0; test_cpu = false) # dropout is not deterministic

instancenorm = [i -> InstanceNorm(i; affine=false), i -> InstanceNorm(i; affine=true)]
gpu_gradtest("InstanceNorm 3d", instancenorm, rand(Float32, 8, 8, 8, 3, 4), 3, setmode=true)
gpu_gradtest("InstanceNorm 2d", instancenorm, rand(Float32, 8, 8, 3, 4), 3, setmode=true)
gpu_gradtest("InstanceNorm 1d", instancenorm, rand(Float32, 8, 3, 4), 3, setmode=true)

groupnorm = [(i, j) -> GroupNorm(i, j; affine=false), (i, j) -> GroupNorm(i, j; affine=true)]
gpu_gradtest("GroupNorm 3d", groupnorm, rand(Float32, 8, 8, 8, 12, 4), 12, 3, setmode=true)
gpu_gradtest("GroupNorm 2d", groupnorm, rand(Float32, 8, 8, 12, 4), 12, 3, setmode=true)
gpu_gradtest("GroupNorm 1d", groupnorm, rand(Float32, 8, 3, 12, 4), 12, 3, setmode=true)
layer_norm = [LayerNorm]
gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 28,28,3,4), 1, test_cpu = false) #TODO fix errors
gpu_gradtest("LayerNorm 2", layer_norm, rand(Float32, 5,4), 5)

upsample = [x -> Upsample(scale=x)]
gpu_gradtest("Upsample 2d", upsample, rand(Float32, 3, 4, 2, 3), (2,2))
@@ -87,27 +114,48 @@ pixelshuffle = [PixelShuffle]
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)


@testset "function layers" begin
x = rand(Float32, 3,3)
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=2)), x)
gpu_autodiff_test(x -> sum(Flux.normalise(x)), x)
end

@testset "BatchNorm mix stuff" begin
@testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv)
l = cl((2,2), 1=>3, bias = false) |> gpu
ip = zeros(Float32, 28,28,1,1) |> gpu
if typeof(l) <: BROKEN_LAYERS
@test_broken sum(l(ip)) ≈ 0.f0
@test_broken gradient(() -> sum(l(ip)), Flux.params(l)) isa Flux.Zygote.Grads
else
@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.bias ∉ gs.params
end
end

@testset "Dense with Zeros bias" begin
l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu
ip = zeros(Float32, 3, 7) |> gpu

@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.b ∉ gs.params
end

@testset "Extended BatchNorm" begin
m_cpu = BatchNorm(2)
m_gpu = m_cpu |> gpu
x_cpu = rand(Float32, 3, 2, 2)
x_gpu = x_cpu |> gpu

## In :auto mode, track statistics only in gradient context
μ_cpu = copy(m_cpu.μ)
m_cpu(x_cpu)
@test m_cpu.μ ≈ μ_cpu
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
@test !(m_cpu.μ ≈ μ_cpu)

μ_gpu = copy(m_gpu.μ)
m_gpu(x_gpu)
@test m_gpu.μ ≈ μ_gpu
@@ -123,7 +171,7 @@ end
@test m_cpu.μ ≈ μ_cpu
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
@test m_cpu.μ ≈ μ_cpu

testmode!(m_gpu)
μ_gpu = copy(m_gpu.μ)
m_gpu(x_gpu)
@@ -139,7 +187,7 @@ end
μ_cpu = copy(m_cpu.μ)
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
@test !(m_cpu.μ ≈ μ_cpu)

trainmode!(m_gpu)
μ_gpu = copy(m_gpu.μ)
m_gpu(x_gpu)
@@ -149,36 +197,28 @@ end
@test !(m_gpu.μ ≈ μ_gpu)

## No errors if input type mismatch
x_cpu = rand(Float64, 3, 2, 2)
x_gpu = x_cpu |> gpu
m_cpu(x_cpu)
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
m_gpu(x_gpu)
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
# x_cpu = rand(Float64, 3, 2, 2)
# x_gpu = x_cpu |> gpu
# m_cpu(x_cpu)
# gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
# m_gpu(x_gpu)
# gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
end

@testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv)
l = cl((2,2), 1=>3, bias = false) |> gpu
ip = zeros(Float32, 28,28,1,1) |> gpu
if l isa BROKEN_LAYERS
@test_broken sum(l(ip)) ≈ 0.f0
@test_broken gradient(() -> sum(l(ip)), Flux.params(l)) isa Flux.Zygote.Grads
else
@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.bias ∉ gs.params
@testset "Two-streams Bilinear" begin
x = zeros(Float32,10,9) |> gpu
y = zeros(Float32,2,9) |> gpu
b = Flux.Bilinear(10, 2, 3) |> gpu
@test size(b(x,y)) == (3,9)
@test sum(abs2, b(x,y)) ≈ 0f0
gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
for (pgpu, pcpu) in zip(params(b), params(b_cpu))
@test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
end
end

@testset "Dense with Zeros bias" begin
l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu
ip = zeros(Float32, 3, 7) |> gpu

@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.b ∉ gs.params
end

@testset "Two-streams Bilinear" begin
x = zeros(Float32,10,9) |> gpu
y = zeros(Float32,2,9) |> gpu
3 changes: 2 additions & 1 deletion test/cuda/runtests.jl
@@ -1,4 +1,5 @@
using Flux, Test, CUDA
using Zygote
using Zygote: pullback

@info "Testing GPU Support"
@@ -15,4 +16,4 @@ if CUDA.has_cudnn()
include("curnn.jl")
else
@warn "CUDNN unavailable, not testing GPU DNN support"
end
end
