diff --git a/NEWS.md b/NEWS.md
index 863107aa8c..db05c18067 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -11,6 +11,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
 * The `Flux.Optimise` module has been deprecated in favor of the Optimisers.jl package.
   Now Flux re-exports the optimisers from Optimisers.jl. Most users will be uneffected by this change.
   The module is still available for now, but will be removed in a future release.
+* Most Flux layers will [re-use memory via `NNlib.bias_act!`](https://github.com/FluxML/Flux.jl/pull/2327), when possible.
 
 ## v0.14.22
 * Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl).
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 3c615ae06d..8aec354716 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -186,9 +186,8 @@ end
 
 function (a::Dense)(x::AbstractVecOrMat)
   _size_check(a, x, 1 => size(a.weight, 2))
-  σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
   xT = _match_eltype(a, x)  # fixes Float64 input, etc.
-  return σ.(a.weight * xT .+ a.bias)
+  return NNlib.bias_act!(a.σ, a.weight * xT, a.bias)  # does σ.(W*x .+ b), with fast paths
 end
 
 function (a::Dense)(x::AbstractArray)
@@ -466,7 +465,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
   Z = reshape(Wyx, (d_z, :))
 
   # @einsum out[o,s] := σ(Z[o,i] + b[o])
-  σ.(Z .+ b)
+  NNlib.bias_act!(σ, Z, b)  # σ.(Z .+ b)
 end
 
 (a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 14ed11e319..b2186f9abf 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -196,10 +196,9 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
 
 function (c::Conv)(x::AbstractArray)
   _conv_size_check(c, x)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = conv_dims(c, x)
   xT = _match_eltype(c, x)
-  σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+  NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
 end
 
 _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -350,10 +349,9 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
 
 function (c::ConvTranspose)(x::AbstractArray)
   _conv_size_check(c, x)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = conv_transpose_dims(c, x)
   xT = _match_eltype(c, x)
-  σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+  NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::ConvTranspose)
@@ -493,10 +491,9 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
 
 function (c::CrossCor)(x::AbstractArray)
   _conv_size_check(c, x)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = crosscor_dims(c, x)
   xT = _match_eltype(c, x)
-  σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+  NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::CrossCor)
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 99092f9756..9d294e3e6e 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -245,7 +245,7 @@ function _norm_layer_forward(
   β = reshape(l.β, affine_shape)
 
   scale = γ ./ sqrt.(σ² .+ eps)
-  bias = -scale .* μ .+ β
+  bias = .-scale .* μ .+ β
   l.λ.(scale .* x .+ bias)
 end
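
A minimal sketch of the equivalence the diff relies on, assuming only that NNlib is loaded (Flux depends on it). `NNlib.bias_act!(σ, W*x, b)` fuses the bias add and activation into the freshly allocated `W*x` buffer, producing the same values as `σ.(W*x .+ b)` while re-using that memory; per the in-line comment above, it also takes the fast-activation paths (e.g. `tanh` => `tanh_fast`), which is why the explicit `NNlib.fast_act` calls are dropped.

```julia
using NNlib

W, b = randn(Float32, 3, 4), randn(Float32, 3)
x = randn(Float32, 4, 8)

y = W * x                        # freshly allocated intermediate, safe to overwrite
z = NNlib.bias_act!(tanh, y, b)  # writes tanh.(y .+ b) back into y's memory
z ≈ tanh.(W * x .+ b)            # true (up to tanh_fast rounding)
```

The intermediate `W * x` (or `conv(...)` output) is already a temporary, so overwriting it saves one allocation per forward pass, which is the memory re-use advertised in the NEWS entry.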