
feature requests for RNNs #2514

Open · 2 of 9 tasks
CarloLucibello opened this issue Nov 4, 2024 · 9 comments
@CarloLucibello (Member) commented Nov 4, 2024

After the redesign in #2500, here is a list of potential improvements for recurrent layers and recurrent cells

@MartinuzziFrancesco (Contributor)

Would it be possible to add to the list the option to use different initializers for the input matrix and the recurrent matrix? Both Keras/TF and Flax provide this.

This should be as straightforward as

function RNNCell((in, out)::Pair, σ=relu;
    kernel_init = glorot_uniform,            # init for the input matrix
    recurrent_kernel_init = glorot_uniform,  # init for the recurrent matrix
    bias = true)
    Wi = kernel_init(out, in)
    U = recurrent_kernel_init(out, out)
    b = create_bias(Wi, bias, size(Wi, 1))
    return RNNCell(σ, Wi, U, b)
end
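
For illustration, usage of the extended constructor could then look like this (just a sketch, assuming the keyword names above are adopted; glorot_uniform and Flux.orthogonal are existing initializers):

cell = RNNCell(4 => 8; kernel_init = glorot_uniform, recurrent_kernel_init = Flux.orthogonal)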

I can also open a quick PR on this if needed

@CarloLucibello (Member, Author)

yes! PR welcome

@MartinuzziFrancesco (Contributor)

Following up on this, should we also have an option to choose the init for the bias?

@CarloLucibello (Member, Author)

We don't do it for feedforward layers; if someone wants a non-zero bias they can just change it manually after construction, layer.bias .= ...
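
A minimal sketch of that pattern (using a Dense layer purely for illustration; the same applies to recurrent cells):

layer = Dense(4 => 3)    # bias is created as zeros by default
layer.bias .= 0.1f0      # overwrite the bias in place after construction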

@MartinuzziFrancesco (Contributor)

Hey, a couple of questions about this feature request again. Would the initialstates function be like the current (rnn::RNN)(x::AbstractVecOrMat) method but returning just the state?

Would something simple like

function initialstates(rnn::RNNCell; init_state = zeros)
    # Wh is (out × out), so size(rnn.Wh, 2) is the hidden dimension
    state = init_state(size(rnn.Wh, 2))
    return state
end

function initialstates(lstm::LSTMCell; init_state = zeros, init_cstate = zeros)
    # the LSTM needs both a hidden state and a cell state
    state = init_state(size(lstm.Wh, 2))
    cstate = init_cstate(size(lstm.Wh, 2))
    return state, cstate
end

suffice, or were you looking for something more? Maybe more control over the element type would be needed.

@CarloLucibello (Member, Author)

I would just have

initialstates(c::RNNCell) = zeros_like(c.Wh, size(c.Wh, 2))
initialstates(c::LSTMCell) = (zeros_like(c.Wh, size(c.Wh, 2)), zeros_like(c.Wh, size(c.Wh, 2)))

If different initializations are needed, we could add an init_state to the constructor, but maybe it's better to let the user handle it as part of the model, for simplicity and flexibility (e.g. making the initial state trainable). I don't have a strong opinion though.
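
As a sketch of the "handle it in the model" option (illustrative names; the Wh field naming follows the snippets above), a trainable initial state could be stored alongside the cell:

struct StatefulRNN{C,S}
    cell::C
    state0::S   # kept as a field, so it is collected as a trainable parameter
end

Flux.@layer StatefulRNN

StatefulRNN(cell) = StatefulRNN(cell, zeros(Float32, size(cell.Wh, 2)))

(m::StatefulRNN)(x) = m.cell(x, m.state0)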

@MartinuzziFrancesco (Contributor)

So this way we would simply do

function (rnn::RNNCell)(inp::AbstractVecOrMat)
    state = initialstates(rnn)
    return rnn(inp, state)
end

to keep compatibility with the current version, right?

I think your point is good; additionally, no other library provides a specific init for the state, so that's probably a little overkill. I'll push something along these lines later.

@MartinuzziFrancesco (Contributor)

Trying to tackle adding num_layers and dropout, with Bidirectional as well once I figure out the general direction. I just wanted to ask if the current approach is in line with what you had in mind:

struct TestRNN{A, B}
    cells::A
    dropout_layer::B
end

Flux.@layer :expand TestRNN

function TestRNN((in_size, out_size)::Pair;
    n_layers::Int=1,
    dropout::Float64=0.0,
    kwargs...)
    cells = []

    for i in 1:n_layers
        tin_size = i == 1 ? in_size : out_size
        push!(cells, RNNCell(tin_size => out_size; kwargs...))
    end

    if dropout > 0.0
        dropout_layer = Dropout(dropout)
    else
        dropout_layer = nothing
    end

    return TestRNN(cells, dropout_layer)
end

function (rnn::TestRNN)(inp, state)
    @assert ndims(inp) == 2 || ndims(inp) == 3
    output = []
    num_layers = length(rnn.cells)

    for inp_t in eachslice(inp, dims=2)
        new_states = []
        for (idx_cell, cell) in enumerate(rnn.cells)
            new_state = cell(inp_t, state[:, idx_cell])
            new_states = vcat(new_states, [new_state])
            inp_t = new_state

            # apply dropout between layers, but not after the last one
            if rnn.dropout_layer isa Dropout && idx_cell < num_layers
                inp_t = rnn.dropout_layer(inp_t)
            end
        end
        state = stack(new_states)
        output = vcat(output, [inp_t])
    end
    output = stack(output, dims=2)
    return output, state
end

@CarloLucibello (Member, Author)

I think we don't need this additional complexity. A simple stacked RNN can be constructed as a Chain:

stacked_rnn = Chain(LSTM(3 => 3), Dropout(0.5), LSTM(3 => 3))
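
Assuming the post-redesign convention that recurrent layers take a whole sequence of size (features, timesteps) or (features, timesteps, batch), this can be used directly:

x = rand(Float32, 3, 10, 16)   # 3 features, 10 timesteps, batch of 16
y = stacked_rnn(x)             # hidden sequence of size (3, 10, 16)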

If control of the initial states is also needed, it is not hard to define a custom struct for the job:

struct StackedRNN{L,S}
    layers::L
    states0::S
end

function StackedRNN(d, num_layers)
    layers = [LSTM(d => d) for _ in 1:num_layers]
    states0 = [Flux.initialstates(l) for l in layers]
    return StackedRNN(layers, states0)
end

function (m::StackedRNN)(x)
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end
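
Usage would then be along the same lines (a sketch, with the same sequence layout as above):

m = StackedRNN(3, 2)              # two LSTM(3 => 3) layers with their own initial states
y = m(rand(Float32, 3, 10, 16))   # hidden sequence of size (3, 10, 16)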

I think it is enough to document this in the guide
https://github.com/FluxML/Flux.jl/blob/master/docs/src/guide/models/recurrence.md
