Change cells' return to out, state (#2551)
CarloLucibello authored Dec 15, 2024
1 parent 6041cf5 commit 0683976
Showing 6 changed files with 123 additions and 136 deletions.
13 changes: 10 additions & 3 deletions NEWS.md
@@ -2,10 +2,17 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.15.3
* Add `WeightNorm` normalization layer.
## v0.16.0 (15 December 2024)
This release has a single **breaking change**:

## v0.15.0 (December 2024)
- The forward pass of the recurrent cells `RNNCell`, `LSTMCell`, and `GRUCell` has been changed to
$y_t, state_t = cell(x_t, state_{t-1})$. Previously, it was $state_t = cell(x_t, state_{t-1})$.
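
  A minimal sketch of the new calling convention (the cell type, sizes, and data are arbitrary, chosen only for illustration):

  ```julia
  using Flux

  cell = RNNCell(2 => 3)            # same convention for LSTMCell and GRUCell
  x_t = rand(Float32, 2)            # input at a single time step
  state = Flux.initialstates(cell)  # zero initial state

  y_t, state = cell(x_t, state)     # v0.16: returns the output and the updated state
  # For RNNCell the two coincide; for LSTMCell the state is the tuple (h, c).
  # In v0.15 the same call returned only the updated state.
  ```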

Other highlights include:
* Added `WeightNorm` normalization layer.
* Added `Recurrence` layer, turning a recurrent layer into a layer processing the entire sequence at once.

## v0.15.0 (5 December 2024)
This release includes two **breaking changes**:
- The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details.
- Flux now defines and exports its own `gradient` function. Consequently, calling `gradient` in an unqualified manner (e.g., after `using Flux, Zygote`) can result in an ambiguity error.
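
  A sketch of the qualified call (the model, data, and loss below are placeholders, not taken from the release notes):

  ```julia
  using Flux, Zygote

  model = Dense(2 => 1)             # placeholder model
  x = rand(Float32, 2, 8)           # placeholder data

  # Both packages now provide `gradient`, so qualify the call explicitly:
  grads = Flux.gradient(m -> sum(abs2, m(x)), model)[1]
  ```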
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.15.2"
version = "0.16.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
23 changes: 11 additions & 12 deletions docs/src/guide/models/recurrence.md
@@ -21,7 +21,7 @@ b = zeros(Float32, output_size)

function rnn_cell(x, h)
h = tanh.(Wxh * x .+ Whh * h .+ b)
return h
return h, h
end

seq_len = 3
@@ -33,14 +33,14 @@ h0 = zeros(Float32, output_size)
y = []
ht = h0
for xt in x
ht = rnn_cell(xt, ht)
y = [y; [ht]] # concatenate in non-mutating (AD friendly) way
yt, ht = rnn_cell(xt, ht)
y = [y; [yt]] # concatenate in non-mutating (AD friendly) way
end
```

Notice how the above is essentially a `Dense` layer that acts on two inputs, `xt` and `ht`.

The output at each time step, called the hidden state, is used as the input to the next time step and is also the output of the model.
The result of the forward pass at each time step is a tuple containing the output `yt` and the updated state `ht`. The updated state is used as an input in the next iteration. In the simple case of a vanilla RNN, the
output and the state are the same. In more complex cells, such as `LSTMCell`, the state can contain multiple arrays, as the sketch below illustrates.
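
For instance (reusing `input_size`, `output_size`, `x` and `h0` defined above; the choice of `LSTMCell` is for illustration only):

```julia
lstm_cell = Flux.LSTMCell(input_size => output_size)

# For LSTMCell the state is a tuple of two arrays: the hidden state and the cell state.
state = (h0, zeros(Float32, output_size))

yt, (ht, ct) = lstm_cell(x[1], state)   # output, and a state made of two arrays
```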

There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](../../reference/models/layers.md). The hand-written example above can be replaced with:

Expand All @@ -58,8 +58,8 @@ rnn_cell = Flux.RNNCell(input_size => output_size)
y = []
ht = h0
for xt in x
ht = rnn_cell(xt, ht)
y = [y; [ht]]
yt, ht = rnn_cell(xt, ht)
y = [y; [yt]]
end
```
The entire output `y` or just the last output `y[end]` can be used for further processing, such as classification or regression.
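
For example, a sketch of a classification head applied to the last output (`num_classes` is a made-up size, not defined in the guide above):

```julia
num_classes = 4                            # hypothetical number of classes
head = Dense(output_size => num_classes)

logits = head(y[end])                      # use only the final time step
probs = softmax(logits)                    # class probabilities
```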
@@ -78,7 +78,7 @@ struct RecurrentCellModel{H,C,D}
end

# we choose to not train the initial hidden state
Flux.@layer RecurrentCellModel trainable=(cell,dense)
Flux.@layer RecurrentCellModel trainable=(cell, dense)

function RecurrentCellModel(input_size::Int, hidden_size::Int)
return RecurrentCellModel(
@@ -91,8 +91,8 @@ function (m::RecurrentCellModel)(x)
z = []
ht = m.h0
for xt in x
ht = m.cell(xt, ht)
z = [z; [ht]]
yt, ht = m.cell(xt, ht)
z = [z; [yt]]
end
z = stack(z, dims=2) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
ŷ = m.dense(z) # [1, seq_len, batch_size] or [1, seq_len]
@@ -109,7 +109,6 @@ using Optimisers: AdamW

function loss(model, x, y)
ŷ = model(x)
y = stack(y, dims=2)
return Flux.mse(ŷ, y)
end

@@ -123,7 +122,7 @@ model = RecurrentCellModel(input_size, 5)
opt_state = Flux.setup(AdamW(1e-3), model)

# compute the gradient and update the model
g = gradient(m -> loss(m, x, y),model)[1]
g = gradient(m -> loss(m, x, y), model)[1]
Flux.update!(opt_state, model, g)
```

78 changes: 36 additions & 42 deletions src/layers/recurrent.jl
@@ -1,12 +1,8 @@
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

function scan(cell, x, state)
y = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
out = out_from_state(state)
y = vcat(y, [out])
yt, state = cell(x_t, state)
y = vcat(y, [yt])
end
return stack(y, dims = 2)
end
@@ -85,7 +81,6 @@ In the forward pass, implements the function
```math
h^\prime = \sigma(W_i x + W_h h + b)
```
and returns `h'`.
See [`RNN`](@ref) for a layer that processes entire sequences.
@@ -107,6 +102,9 @@ The arguments of the forward pass are:
- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns a tuple `(output, state)`, where both elements are given by the updated state `h'`,
a tensor of size `out` or `out x batch_size`.
# Examples
```julia
@@ -123,10 +121,10 @@ h = zeros(Float32, 5)
ŷ = []
for x_t in x
h = r(x_t, h)
ŷ = [ŷ..., h] # Cannot use `push!(ŷ, h)` here since mutation
# is not automatic differentiation friendly yet.
# Can use `y = vcat(y, [h])` as an alternative.
yt, h = r(x_t, h)
ŷ = [ŷ..., yt] # Cannot use `push!(ŷ, h)` here since mutation
# is not automatic differentiation friendly yet.
# Can use `y = vcat(y, [h])` as an alternative.
end
h # The final hidden state
@@ -155,40 +153,37 @@ using Flux
rnn = RNNCell(10 => 20)
# Get the initial hidden state
h0 = initialstates(rnn)
state = initialstates(rnn)
# Get some input data
x = rand(Float32, 10)
# Run forward
res = rnn(x, h0)
out, state = rnn(x, state)
```
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

function RNNCell(
(in, out)::Pair,
σ = tanh;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)
(in, out)::Pair,
σ = tanh;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)
Wi = init_kernel(out, in)
Wh = init_recurrent_kernel(out, out)
b = create_bias(Wi, bias, size(Wi, 1))
return RNNCell(σ, Wi, Wh, b)
end

function (rnn::RNNCell)(x::AbstractVecOrMat)
state = initialstates(rnn)
return rnn(x, state)
end
(rnn::RNNCell)(x::AbstractVecOrMat) = rnn(x, initialstates(rnn))

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
σ = NNlib.fast_act(m.σ, x)
h = σ.(m.Wi * x .+ m.Wh * h .+ m.bias)
return h
return h, h
end

function Base.show(io::IO, m::RNNCell)
@@ -278,10 +273,7 @@ function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
return RNN(cell)
end

function (rnn::RNN)(x::AbstractArray)
state = initialstates(rnn)
return rnn(x, state)
end
(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))

function (m::RNN)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
@@ -315,7 +307,6 @@ o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
h_t = o_t \odot \tanh(c_t)
```
The `LSTMCell` returns the new hidden state `h_t` and cell state `c_t` for a single time step.
See also [`LSTM`](@ref) for a layer that processes entire sequences.
# Arguments
@@ -336,7 +327,8 @@ The arguments of the forward pass are:
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
Returns a tuple `(output, state)`, where `output = h'` is the new hidden state and `state = (h', c')` contains the new hidden and cell states.
These are tensors of size `out` or `out x batch_size`.
# Examples
@@ -350,9 +342,9 @@ julia> c = zeros(Float32, 5); # cell state
julia> x = rand(Float32, 3, 4); # in x batch_size
julia> h′, c′ = l(x, (h, c));
julia> y, (h′, c′) = l(x, (h, c));
julia> size(h′) # out x batch_size
julia> size(y) # out x batch_size
(5, 4)
```
"""
@@ -389,9 +381,9 @@ function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
b = m.bias
g = m.Wi * x .+ m.Wh * h .+ b
input, forget, cell, output = chunk(g, 4; dims = 1)
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
return h′, c′
c = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h = @. sigmoid_fast(output) * tanh_fast(c)
return h, (h, c)
end

Base.show(io::IO, m::LSTMCell) =
@@ -522,7 +514,8 @@ The arguments of the forward pass are:
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
Returns the tuple `(output, state)`, where `output = h'` and `state = h'`.
The new hidden state `h'` is an array of size `out` or `out x batch_size`.
# Examples
@@ -534,7 +527,7 @@ julia> h = zeros(Float32, 5); # hidden state
julia> x = rand(Float32, 3, 4); # in x batch_size
julia> h′ = g(x, h);
julia> y, h = g(x, h);
```
"""
struct GRUCell{I, H, V}
@@ -577,8 +570,8 @@ function (m::GRUCell)(x::AbstractVecOrMat, h)
r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
h′ = @. (1 - z) * h̃ + z * h
return h′
h = @. (1 - z) * h̃ + z * h
return h, h
end

Base.show(io::IO, m::GRUCell) =
@@ -693,7 +686,8 @@ The arguments of the forward pass are:
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
Returns the tuple `(output, state)`, where `output = h'` and `state = h'`.
The new hidden state `h'` is an array of size `out` or `out x batch_size`.
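
# Examples

A minimal usage sketch (the sizes below are arbitrary, for illustration only):

```julia
g = GRUv3Cell(3 => 5)
h = zeros(Float32, 5)       # hidden state
x = rand(Float32, 3, 4)     # in x batch_size
y, h = g(x, h)              # y and h are both the new hidden state, size 5 x 4
```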
"""
struct GRUv3Cell{I, H, V, HH}
Wi::I
@@ -736,8 +730,8 @@ function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
h̃ = tanh_fast.(gxs[3] .+ (m.Wh_h̃ * (r .* h)) .+ bs[3])
h′ = @. (1 - z) * h̃ + z * h
return h′
h = @. (1 - z) * h̃ + z * h
return h, h
end

Base.show(io::IO, m::GRUv3Cell) =
18 changes: 10 additions & 8 deletions test/ext_common/recurrent_gpu_ad.jl
@@ -1,11 +1,9 @@
out_from_state(state::Tuple) = state[1]
out_from_state(state) = state
cell_loss(cell, x, state) = mean(cell(x, state)[1])

function recurrent_cell_loss(cell, seq, state)
out = []
for xt in seq
state = cell(xt, state)
yt = out_from_state(state)
yt, state = cell(xt, state)
out = vcat(out, [yt])
end
return mean(stack(out, dims = 2))
@@ -18,7 +16,8 @@ end
h = zeros(Float32, d_out)
# Single Step
@test test_gradients(r, x[1], h; test_gpu=true,
compare_finite_diff=false) broken = :rnncell_single BROKEN_TESTS
compare_finite_diff=false,
loss=cell_loss) broken = :rnncell_single BROKEN_TESTS
# Multiple Steps
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff=false,
@@ -51,7 +50,7 @@ end
c = zeros(Float32, d_out)
# Single Step
@test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single BROKEN_TESTS
loss = cell_loss) broken = :lstmcell_single BROKEN_TESTS
# Multiple Steps
@test test_gradients(cell, x, (h, c); test_gpu=true,
compare_finite_diff = false,
@@ -84,7 +83,9 @@ end
r = GRUCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single BROKEN_TESTS
@test test_gradients(r, x[1], h; test_gpu=true,
compare_finite_diff=false,
loss = cell_loss) broken = :grucell_single BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff = false,
loss = recurrent_cell_loss) broken = :grucell_multiple BROKEN_TESTS
@@ -116,7 +117,8 @@ end
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true,
compare_finite_diff=false) broken = :gruv3cell_single BROKEN_TESTS
compare_finite_diff=false,
loss=cell_loss) broken = :gruv3cell_single BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff=false,
loss = recurrent_cell_loss) broken = :gruv3cell_multiple BROKEN_TESTS