rm Flux.Zeros, take N+1
mcabbott committed Feb 19, 2022
1 parent f49e81e commit bf66393
Showing 11 changed files with 54 additions and 151 deletions.
1 change: 0 additions & 1 deletion src/Flux.jl
@@ -36,7 +36,6 @@ using CUDA
const use_cuda = Ref{Union{Nothing,Bool}}(nothing)

include("utils.jl")
include("zeros.jl")
include("onehot.jl")
include("functor.jl")

8 changes: 8 additions & 0 deletions src/deprecations.jl
@@ -19,6 +19,14 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,

@deprecate frequencies(xs) group_counts(xs)

struct Zeros
function Zeros()
Base.depwarn("Flux.Zeros is no more, has ceased to be, is bereft of life, is an ex-boondoggle... please use bias=false instead", :Zeros)
false
end
end
Zeros(args...) = Zeros() # was used as both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())

# Channel notation: Changed to match Conv, but very softly deprecated!
# Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
Dense(in::Integer, out::Integer, σ = identity; kw...) =
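A minimal migration sketch (illustration only, not part of the diff; constructor forms are assumed from the surrounding hunks): old Zeros-based call sites now map onto bias=false, which is exactly what the deprecation shim above returns.

using Flux

# Old style: an explicit Zeros() bias now triggers the depwarn and collapses to `false`.
d_old = Dense(rand(Float32, 2, 10), Flux.Zeros())

# New style: pass `false` directly, or use the keyword form.
d_new = Dense(rand(Float32, 2, 10), false)
d_kw  = Dense(10 => 2; bias=false)

d_new.bias === false  # true: no trainable bias is stored
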
6 changes: 5 additions & 1 deletion src/layers/basic.jl
@@ -167,7 +167,7 @@ end
function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.weight, 2), " => ", size(l.weight, 1))
l.σ == identity || print(io, ", ", l.σ)
l.bias == Zeros() && print(io, "; bias=false")
l.bias == false && print(io, "; bias=false")
print(io, ")")
end

@@ -394,7 +394,7 @@ function Base.show(io::IO, l::Bilinear)
print(io, "Bilinear((", size(l.weight, 2), ", ", size(l.weight, 3), ") => ", size(l.weight, 1))
end
l.σ == identity || print(io, ", ", l.σ)
l.bias == Flux.Zeros() && print(io, "; bias=false")
l.bias === false && print(io, ", bias=false")
print(io, ")")
end

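For illustration (not part of the diff), a sketch of what the updated show methods print for a bias-free layer, assuming the Pair constructor implied by the hunk above:

d = Dense(3 => 2; bias=false)
print(d)  # Dense(3 => 2; bias=false)
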
22 changes: 11 additions & 11 deletions src/layers/conv.jl
@@ -6,6 +6,10 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)

conv_reshape_bias(c) = c.bias isa AbstractVector ?
reshape(c.bias, map(_->1, c.stride)..., :, 1) :
c.bias

"""
SamePad()
@@ -61,8 +65,8 @@ Then:
Keywords to control initialization of the layer:
* `init` - Function used to generate initial weights. Defaults to `glorot_uniform`.
* `bias` - Initial bias is zero by default, this can be disabled entirely by setting it to
`false`, or another vector explicitly as `bias = randn(Float32, out)`.
* `bias` - The initial bias vector is all zero by default. Trainable bias can be disabled entirely
by setting this to `false`, or another vector can be provided such as `bias = randn(Float32, out)`.
See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref).
@@ -159,10 +163,9 @@ end
@functor Conv

function (c::Conv)(x::AbstractArray)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
σ.(conv(x, c.weight, cdims) .+ b)
σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
end

_channels_in(l ::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -183,7 +186,7 @@ function _print_conv_opt(io::IO, l)
if hasproperty(l, :groups)
(l.groups == 1) || print(io, ", groups=", l.groups)
end
(l.bias isa Zeros) && print(io, ", bias=false")
(l.bias === false) && print(io, ", bias=false")
end

"""
@@ -277,10 +280,9 @@ end
@nograd conv_transpose_dims

function (c::ConvTranspose)(x::AbstractArray)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
σ.(∇conv_data(x, c.weight, cdims) .+ b)
σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
end

function Base.show(io::IO, l::ConvTranspose)
@@ -372,10 +374,9 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])

function (c::DepthwiseConv)(x)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
σ.(depthwiseconv(x, c.weight, cdims) .+ conv_reshape_bias(c))
end

function Base.show(io::IO, l::DepthwiseConv)
@@ -453,10 +454,9 @@ function crosscor(x, w, ddims::DenseConvDims)
end

function (c::CrossCor)(x::AbstractArray)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(crosscor(x, c.weight, cdims) .+ b)
σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
end

function Base.show(io::IO, l::CrossCor)
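A small sketch (not part of the diff) of what the new conv_reshape_bias helper computes for a 2-d convolution: a length-C bias vector is reshaped to size (1, 1, C, 1) so it broadcasts over width, height and batch in WHCN layout, while a `false` bias passes through and broadcasts as adding zero.

bias   = randn(Float32, 16)  # one bias per output channel
stride = (1, 1)              # Conv stores stride as a tuple, one entry per spatial dimension
reshape(bias, map(_ -> 1, stride)..., :, 1)  # size (1, 1, 16, 1)

# With c.bias === false the helper returns `false` unchanged, and
# `conv(x, w, cdims) .+ false` simply adds zero under broadcasting.
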
6 changes: 3 additions & 3 deletions src/utils.jl
@@ -441,17 +441,17 @@ rand32(dims...) = Base.rand(Float32, dims...)
randn32(dims...) = Base.randn(Float32, dims...)

"""
create_bias(weights, bias, length)
create_bias(weights, bias, size...)
Return a bias parameter for a layer, based on the value given
to the constructor's keyword `bias=bias`.
* `bias == true` creates a zero vector, of the same type as weights.
* `bias == false` returns `Zeros()`, a special struct which exists only to encode the absence of bias.
* `bias == false` returns `false` now, which is understood by AD to be non-differentiable.
* `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted.
"""
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
bias ? fill!(similar(weights, dims...), 0) : Zeros()
bias ? fill!(similar(weights, dims...), 0) : false
end
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
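A behaviour sketch (not part of the diff) for the three create_bias cases named in the updated docstring; create_bias is internal, so it is reached through the Flux module here:

W = randn(Float32, 4, 3)

Flux.create_bias(W, true, 4)               # 4-element zero vector with the same eltype as W
Flux.create_bias(W, false, 4)              # `false`: no trainable bias
Flux.create_bias(W, zeros(Float64, 4), 4)  # provided array is used; size is checked, eltype converted per the docstring
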
52 changes: 0 additions & 52 deletions src/zeros.jl

This file was deleted.

4 changes: 2 additions & 2 deletions test/cuda/layers.jl
@@ -155,8 +155,8 @@ end
end
end

@testset "Dense with Zeros bias" begin
l = Dense(ones(Float32, 4, 3), Flux.Zeros()) |> gpu
@testset "Dense without bias" begin
l = Dense(ones(Float32, 4, 3), false) |> gpu
ip = zeros(Float32, 3, 7) |> gpu

@test sum(l(ip)) ≈ 0.f0
2 changes: 1 addition & 1 deletion test/layers/basic.jl
@@ -175,7 +175,7 @@ import Flux: activations
@test b1.σ == identity

b2 = Flux.Bilinear(randn(3,4,5), false)
@test b2.bias == Flux.Zeros()
@test b2.bias === false

b3 = Flux.Bilinear(randn(Float16, 3,4,5), true, tanh)
@test b3.σ == tanh
2 changes: 1 addition & 1 deletion test/layers/conv.jl
@@ -273,7 +273,7 @@ end

@testset "constructors: $fun" for fun in [Conv, CrossCor, ConvTranspose, DepthwiseConv]
@test fun(rand(2,3,4)).bias isa Vector{Float64}
@test fun(rand(2,3,4,5), false).bias isa Flux.Zeros
@test fun(rand(2,3,4,5), false).bias === false
if fun == Conv
@test fun(rand(2,3,4,5,6), rand(6)).bias isa Vector{Float64}
@test_skip fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
2 changes: 1 addition & 1 deletion test/optimise.jl
@@ -15,7 +15,7 @@ using Random
Nesterov(), RMSProp(), Momentum()]
Random.seed!(42)
w′ = randn(10, 10)
b = Flux.Zeros()
b = false
loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
for t = 1: 10^5
θ = params([w′, b])
100 changes: 22 additions & 78 deletions test/utils.jl
@@ -1,8 +1,8 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
sparse_init, stack, unstack, Zeros, batch, unbatch,
unsqueeze, params
sparse_init, stack, unstack, batch, unbatch,
unsqueeze, params, loadparams!
using StatsBase: var, std
using Statistics, LinearAlgebra
using Random
@@ -263,88 +263,36 @@ end
@test eltype(f32(f64(m))[1].weight) == Float32
end

@testset "Zeros" begin
@testset "zero bias" begin
m = Dense(3,2; bias=false)
@test f64(m).bias === m.bias === Zeros()
@test f32(m).bias === m.bias === Zeros()
@test f64(m).bias === m.bias === false
@test f32(m).bias === m.bias === false

@testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3))
o = ones(s)
z = zeros(s)
Z = Zeros()

@testset "Explicit" begin
gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
g = gfun(o, z)
@test gfun(o, Z) == (g[1], nothing)
@test gfun(o, false) == (g[1], nothing)

g = gfun(z, o)
@test gfun(Z, o) == (nothing, g[2])
@test gfun(false, o) == (nothing, g[2])
end

@testset "Implicit" begin
gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args)))
g = gfun(o, z)

gres = gfun(o, Z)
gres = gfun(o, false)
@test gres[o] == g[o]
@test Z ∉ gres.params
@test false ∉ gres.params

g = gfun(z, o)
gres = gfun(Z, o)
gres = gfun(false, o)
@test gres[o] == g[o]
@test Z ∉ gres.params
end
end

@testset "Gradients for broadcasted / with sizes $s" for s in ((1,), (2,3))
o = ones(s)
z = zeros(s)
Z = Zeros() # Only defined for 0-dim

@testset "Explicit" begin
gfun(args...) = gradient((x, y) -> sum(x ./ y), args...)
g = gfun(z, o)
@test gfun(Z, o) == (nothing, g[2])
end

@testset "Implicit" begin
gfun(x,y) = gradient(() -> sum(x ./ y), params([x,y]))

g = gfun(z, o)
gres = gfun(Z, o)
@test gres[o] == g[o]
@test Z ∉ gres.params
end
end

@testset "Gradients for $op with sizes $s" for op in (+,-), s in (tuple(), (1,), (2,3))
o = ones(s)
z = zeros(s)
Z = Zeros()


@testset "Explicit" begin
gfun(args...) = gradient((x, y) -> sum(op(x,y)), args...)

g = gfun(o, z)
@test gfun(o, Z) == (g[1], nothing)

g = gfun(z, o)
@test gfun(Z, o) == (nothing, g[2])
end

@testset "Implicit" begin
gfun(args...) = gradient(() -> sum(op(args...)), params(collect(args)))
g = gfun(o, z)
gres = gfun(o, Z)
@test gres[o] == g[o]
@test Z ∉ gres.params

g = gfun(z, o)
gres = gfun(Z, o)
@test gres[o] == g[o]
@test Z ∉ gres.params
@test false ∉ gres.params
end
end
end
@@ -385,19 +333,15 @@ end
dl(4, 3, bias)
)

nobias(n) = Zeros()
nobias(n) = false
testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
@test l1.weight == l2.weight
@test l1.bias == l2.bias
@test_skip typeof(l1.bias) === typeof(l2.bias)
end

@testset "loadparams!" begin
import Flux: loadparams!
pars(w, b) = [w, b]
import Flux: loadparams!, Zeros

pars(w, b::Zeros) = [w, Flux.zeros32(size(w,1))]
pars(l) = pars(l.weight, l.bias)
pararray(m) = mapreduce(pars, vcat, m)
weights(m) = mapreduce(l -> [l.weight], vcat, m)
@@ -407,16 +351,16 @@ end
testdense(m, bt)
end

@testset "$b1 to $b2" for (b1, b2, be) in (
(Flux.zeros32, Flux.ones32, Flux.ones32), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
(Flux.ones32, nobias, Flux.zeros32), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias
(nobias, Flux.ones32, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change
)
m1 = dm(b1)
m2 = dm(b2)
loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2))
testdense(m1, be)
end
# @testset "$b1 to $b2" for (b1, b2, be) in (
# (Flux.zeros32, Flux.ones32, Flux.ones32), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
# (Flux.ones32, nobias, Flux.zeros32), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias
# (nobias, Flux.ones32, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change
# )
# m1 = dm(b1)
# m2 = dm(b2)
# loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2))
# testdense(m1, be)
# end
end

@testset "destructure" begin
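A condensed sketch (not part of the diff) of the gradient behaviour these tests exercise, assuming Zygote as Flux's AD backend: a `false` bias is non-differentiable, so its explicit gradient is `nothing` and it is never collected into implicit Params.

using Flux

x = ones(3)

# Explicit gradients: the `false` slot comes back as `nothing`.
g = gradient((w, b) -> sum(w .+ b), x, false)
g[1]  # == ones(3)
g[2]  # === nothing

# Implicit style: `false` never enters Params, matching `@test false ∉ gres.params` above.
gs = gradient(() -> sum(x .+ false), params([x, false]))
false ∉ gs.params  # true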
