Stacked RNN in Flux.jl? #2452

Closed
NeroBlackstone opened this issue Jun 3, 2024 · 1 comment · Fixed by #2549

@NeroBlackstone

Motivation and description

How do I build a stacked RNN in Flux.jl?

Is the following code the correct way to do it?

julia> using Flux

julia> model = Chain(GRUv3(27 => 32), GRUv3(32 => 32), Dense(32 => 27))
Chain(
  Recur(
    GRUv3Cell(27 => 32),                # 5_792 parameters
  ),
  Recur(
    GRUv3Cell(32 => 32),                # 6_272 parameters
  ),
  Dense(32 => 27),                      # 891 parameters
)         # Total: 12 trainable arrays, 12_955 parameters,
          # plus 2 non-trainable, 64 parameters, summarysize 1.938 KiB.

There is no documentation mentioning this.

Possible Implementation

No response

@CarloLucibello
Member

CarloLucibello commented Dec 12, 2024

Starting with Flux v0.15, a stacked RNN can be defined as follows:

stacked_rnn = Chain(LSTM(3 => 3), Dropout(0.5), LSTM(3 => 3))
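A minimal sketch of how this chain might be called, assuming Flux v0.15 or later and an input laid out as features × time steps × batch. The sizes used here (10 time steps, batch of 4) are made up for illustration:

```julia
using Flux  # assumes Flux v0.15 or later

stacked_rnn = Chain(LSTM(3 => 3), Dropout(0.5), LSTM(3 => 3))

# Hypothetical input: 3 features, 10 time steps, batch of 4.
x = rand(Float32, 3, 10, 4)

# Each LSTM returns its hidden state at every time step,
# so the output of one layer can feed directly into the next.
y = stacked_rnn(x)

println(size(y))
```

Because each layer emits the full sequence of hidden states, the output keeps the time and batch dimensions of the input.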

If control of the initial states is also needed, define a custom struct:

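struct StackedRNN{L,S}
    layers::L
    states0::S
end

# Register the struct with Flux so its layers and initial states are collected as parameters.
Flux.@layer StackedRNN

# Typed arguments keep this method distinct from the default two-argument constructor.
function StackedRNN(d::Int, num_layers::Int)
    layers = [LSTM(d => d) for _ in 1:num_layers]
    states0 = [Flux.initialstates(l) for l in layers]
    return StackedRNN(layers, states0)
end

function (m::StackedRNN)(x)
    # Feed the output sequence of each layer into the next, starting each from its own initial state.
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end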
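The building block the struct relies on is calling a single recurrent layer with an explicit initial state. A small sketch of that call in isolation, assuming Flux v0.15 or later; the layer sizes and input shape are arbitrary choices for illustration:

```julia
using Flux  # assumes Flux v0.15 or later

lstm = LSTM(3 => 5)

# For an LSTM the initial state is a tuple (h, c).
state0 = Flux.initialstates(lstm)

# Hypothetical input: 3 features, 10 time steps, batch of 4.
x = rand(Float32, 3, 10, 4)

# Pass the initial state explicitly as the second argument.
y = lstm(x, state0)

println(size(y))
```

Storing `state0` in a field, as the struct does, makes it possible to replace or train the initial state instead of always starting from the default.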

This should be documented in the recurrence guide:
https://github.com/FluxML/Flux.jl/blob/master/docs/src/guide/models/recurrence.md
