
Commit

use NNlib.bias_act
mcabbott committed Sep 4, 2023
1 parent 2ac01e0 commit 368cec2
Showing 4 changed files with 14 additions and 12 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.14.4"
+version = "0.14.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -42,7 +42,7 @@ Functors = "0.4"
MLUtils = "0.4"
MacroTools = "0.5"
Metal = "0.5"
-NNlib = "0.9.1"
+NNlib = "0.9.5"
OneHotArrays = "0.2.4"
Optimisers = "0.2.12, 0.3.0"
Preferences = "1"
5 changes: 2 additions & 3 deletions src/layers/basic.jl
@@ -169,9 +169,8 @@ end

function (a::Dense)(x::AbstractVecOrMat)
_size_check(a, x, 1 => size(a.weight, 2))
-σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
xT = _match_eltype(a, x) # fixes Float64 input, etc.
-return σ.(a.weight * xT .+ a.bias)
+NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths
end

function (a::Dense)(x::AbstractArray)
@@ -446,7 +445,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
Z = reshape(Wyx, (d_z, :))

# @einsum out[o,s] := σ(Z[o,i] + b[o])
-σ.(Z .+ b)
+NNlib.bias_act!(σ, Z, b) # σ.(Z .+ b)
end

(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
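For readers skimming the diff: a minimal sketch (not part of the commit) of what the new `Dense` forward pass computes, assuming NNlib ≥ 0.9.5 as required by the compat bump above; the sizes are arbitrary.

```julia
using NNlib   # provides bias_act!

W, b = randn(Float32, 4, 3), randn(Float32, 4)
x = randn(Float32, 3, 8)

old = tanh.(W * x .+ b)                # previous forward pass: broadcasts into a fresh array
new = NNlib.bias_act!(tanh, W * x, b)  # new path: σ.(W*x .+ b), writing into the W*x buffer

old ≈ new   # results agree (up to any fast-activation path); only allocations differ
```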
15 changes: 9 additions & 6 deletions src/layers/conv.jl
@@ -196,10 +196,11 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)

function (c::Conv)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
-σ = NNlib.fast_act(c.σ, x)
+# σ = NNlib.fast_act(c.σ, x)
cdims = conv_dims(c, x)
xT = _match_eltype(c, x)
-σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+# σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
end

_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -331,10 +332,11 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

function (c::ConvTranspose)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
-σ = NNlib.fast_act(c.σ, x)
+# σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
xT = _match_eltype(c, x)
-σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+# σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
end

function Base.show(io::IO, l::ConvTranspose)
@@ -473,10 +475,11 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)

function (c::CrossCor)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
-σ = NNlib.fast_act(c.σ, x)
+# σ = NNlib.fast_act(c.σ, x)
cdims = crosscor_dims(c, x)
xT = _match_eltype(c, x)
-σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+# σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
end

function Base.show(io::IO, l::CrossCor)
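The same substitution is applied to `Conv`, `ConvTranspose`, and `CrossCor`. A hedged sanity-check sketch (not part of the commit), reusing the unexported helpers `Flux.conv_dims` and `Flux.conv_reshape_bias` that appear in the hunks above; the layer and input sizes are arbitrary.

```julia
using Flux, NNlib

c = Conv((3, 3), 2 => 4, relu)    # 3×3 kernel, 2 => 4 channels, relu activation
x = randn(Float32, 8, 8, 2, 5)    # WHCN input: 8×8, 2 channels, batch of 5

# Rebuild the forward pass by hand with the same helpers the layer uses:
cdims  = Flux.conv_dims(c, x)
manual = relu.(NNlib.conv(x, c.weight, cdims) .+ Flux.conv_reshape_bias(c))

c(x) ≈ manual   # true; bias_act! changes how bias + activation are fused, not the result
```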
2 changes: 1 addition & 1 deletion src/layers/normalise.jl
@@ -244,7 +244,7 @@ function _norm_layer_forward(
β = reshape(l.β, affine_shape)

scale = γ ./ sqrt.(σ² .+ eps)
-bias = -scale .* μ .+ β
+bias = .-scale .* μ .+ β
l.λ.(scale .* x .+ bias)
end

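The normalisation tweak is a one-character broadcast fix: `-scale` is an ordinary call that materialises a negated copy of the array before the fused broadcast runs, while `.-scale` folds the negation into the same single broadcast as `.*` and `.+`. A small sketch (not part of the commit; shapes are arbitrary):

```julia
scale = rand(Float32, 1, 1, 4, 1)
μ, β  = rand(Float32, 1, 1, 4, 1), rand(Float32, 1, 1, 4, 1)

bias_old = -scale .* μ .+ β    # allocates -scale first, then runs the fused broadcast
bias_new = .-scale .* μ .+ β   # one fused broadcast, no intermediate array

bias_old ≈ bias_new            # true; only the extra allocation goes away
```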
