Skip to content

Commit

Permalink
Merge pull request #2081 from Saransh-cpp/create_bias
Browse files Browse the repository at this point in the history
Back to create_bias
  • Loading branch information
ToucheSir authored Oct 12, 2022
2 parents b08cb67 + 2a1e4d2 commit c7ed5fe
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
6 changes: 3 additions & 3 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ struct Dense{F, M<:AbstractMatrix, B}
bias::B
σ::F
function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
b = _create_bias(W, bias, size(W,1))
b = create_bias(W, bias, size(W,1))
new{F,M,typeof(b)}(W, b, σ)
end
end
Expand Down Expand Up @@ -228,7 +228,7 @@ struct Scale{F, A<:AbstractArray, B}
bias::B
σ::F
function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
b = _create_bias(scale, bias, size(scale)...)
b = create_bias(scale, bias, size(scale)...)
new{F, A, typeof(b)}(scale, b, σ)
end
end
Expand Down Expand Up @@ -403,7 +403,7 @@ struct Bilinear{F,A,B}
σ::F
function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
b = _create_bias(W, bias, size(W,1))
b = create_bias(W, bias, size(W,1))
new{F,A,typeof(b)}(W, b, σ)
end
end
Expand Down
6 changes: 3 additions & 3 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function Conv(w::AbstractArray{T,N}, b = true, σ = identity;
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
bias = _create_bias(w, b, size(w, N))
bias = create_bias(w, b, size(w, N))
return Conv(σ, w, bias, stride, pad, dilation, groups)
end

Expand Down Expand Up @@ -293,7 +293,7 @@ function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
b = _create_bias(w, bias, size(w, N-1) * groups)
b = create_bias(w, bias, size(w, N-1) * groups)
return ConvTranspose(σ, w, b, stride, pad, dilation, groups)
end

Expand Down Expand Up @@ -441,7 +441,7 @@ function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride)
b = _create_bias(w, bias, size(w, N))
b = create_bias(w, bias, size(w, N))
return CrossCor(σ, w, b, stride, pad, dilation)
end

Expand Down
9 changes: 3 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ randn32(rng::AbstractRNG, dims::Integer...) = Base.randn(rng, Float32, dims...)
randn32(rng::AbstractRNG) = (dims...,) -> Base.randn(rng, Float32, dims...)

"""
_create_bias(weights, bias, size...)
create_bias(weights, bias, size...)
Return a bias parameter for a layer, based on the value given
to the constructor's keyword `bias=bias`.
Expand All @@ -514,17 +514,14 @@ to the constructor's keyword `bias=bias`.
* `bias::AbstractArray` uses the array provided, as long as it has the correct size.
It does not at present correct the `eltype` to match that of `weights`.
"""
function _create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
# Build a bias for a layer from a `Bool` flag: `true` gives a zero-filled
# array matching the eltype/backend of `weights` with the requested `dims`;
# `false` disables the bias entirely (the literal `false` acts as a no-op
# additive identity in the layer's forward pass).
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
    if bias
        return fill!(similar(weights, dims...), 0)
    else
        return false
    end
end
function _create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
# Accept a user-supplied bias array, validating only its size against the
# requested `dims`; the array is returned as-is (no copy, and — per the
# docstring above — no eltype correction to match `weights`).
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
    if size(bias) != dims
        throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
    end
    return bias
end

# TODO figure out whether we want to document or deprecate this
const create_bias = _create_bias


# Other

Expand Down

0 comments on commit c7ed5fe

Please sign in to comment.