How can I integrate an RNN into an ODEProblem in Lux.jl? #952

Open · disadone opened this issue Sep 28, 2024 · 6 comments
Labels: question (Further information is requested)

Comments

@disadone

Hi!
Just wondering how an RNN can be mixed into an ODEProblem.

Back in the Flux days, it seems a Recur layer needed to be created; see Training of UDEs with recurrent networks. However, there is already a Recurrence layer in Lux.jl.

How can Lux.jl do the job now? I defined my own GRU cell, and it runs well combined with the beginner tutorial Training a Simple LSTM:

using ConcreteStructs: @concrete
using Lux
using Static
using Random

IntegerType = Union{Integer,Static.StaticInteger}
BoolType = Union{StaticBool, Bool, Val{true},Val{false}}

@concrete struct FastGRUCell <:Lux.AbstractRecurrentCell
    train_state <: StaticBool
    in_dims <: IntegerType
    out_dims <: IntegerType
    init_bias
    init_weight
    init_state
    dynamics_nonlinearity
    gating_nonlinearity
    α<:AbstractFloat
    layernormQ::StaticBool
end

function FastGRUCell(
        (in_dims,out_dims)::Pair{<:Lux.IntegerType,<:Lux.IntegerType},
        Δt::T, τ::T,layernormQ::BoolType;
        train_state::BoolType=False(),
        init_weight=Lux.glorot_normal,
        init_bias=Lux.zeros32,
        init_state=zeros32,
        # standard GRU convention: the candidate state uses a tanh, the gates a sigmoid
        dynamics_nonlinearity = Lux.tanh_fast,
        gating_nonlinearity = Lux.sigmoid_fast) where T<:AbstractFloat
    init_weight = ntuple(Returns(init_weight),3)
    init_bias = ntuple(Returns(init_bias),3)
    α = Δt/τ
    return FastGRUCell(
        static(train_state),
        in_dims,out_dims,init_bias,init_weight,init_state,
        dynamics_nonlinearity,gating_nonlinearity,α,static(layernormQ)
    )
end

function Lux.initialparameters(rng::AbstractRNG,gru::FastGRUCell)
    # hidden to hidden
    Wz,Wr,Wh = (Lux.init_rnn_weight(
        rng,init_weight,gru.out_dims,(gru.out_dims,gru.out_dims)) for init_weight in gru.init_weight)
    # input to hidden
    Uz,Ur,Uh = (Lux.init_rnn_weight(
        rng,init_weight,gru.out_dims,(gru.out_dims,gru.in_dims)) for init_weight in gru.init_weight)

    ps = (; Wz,Wr,Wh,Uz,Ur,Uh)

    biasz,biasr,biash = (Lux.init_rnn_weight(rng,init_bias,gru.out_dims,gru.out_dims) for init_bias in gru.init_bias)

    ps = merge(ps, (; biasz,biasr,biash))
    Lux.has_train_state(gru) &&  (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),)))
    return ps
end
Lux.initialstates(rng::AbstractRNG,::FastGRUCell) = (rng=Lux.Utils.sample_replicate(rng),)

function (gru::FastGRUCell{True})(x::AbstractMatrix,ps,st::NamedTuple)
    hidden_state = Lux.init_trainable_rnn_hidden_state(ps.hidden_state, x)
    return gru((x, (hidden_state,)), ps, st)
end

function (gru::FastGRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple)
    rng = Lux.replicate(st.rng)
    st = merge(st, (; rng))
    hidden_state = Lux.init_rnn_hidden_state(rng, gru, x)
    return gru((x, (hidden_state,)), ps, st)
end

const _FastGRUCellInputType = Tuple{
    <:AbstractMatrix, Tuple{<:AbstractMatrix}}

function (m::FastGRUCell)(
    (x,(h,))::_FastGRUCellInputType, ps,st::NamedTuple)

    Wzh =  fused_dense_bias_activation(identity,ps.Wz,h,ps.biasz)
    Wrh =  fused_dense_bias_activation(identity,ps.Wr,h,ps.biasr)
    Uzx =  fused_dense_bias_activation(identity,ps.Uz,x,nothing)
    Urx =  fused_dense_bias_activation(identity,ps.Ur,x,nothing)
    
    z = dynamic(m.layernormQ) ? (m.gating_nonlinearity.(layernorm(Wzh,nothing,nothing) .+ Uzx)) : (@. m.gating_nonlinearity(Wzh+Uzx))
    r = dynamic(m.layernormQ) ? (m.gating_nonlinearity.(layernorm(Wrh,nothing,nothing) .+ Urx)) : (@. m.gating_nonlinearity(Wrh+Urx))

    Whh = fused_dense_bias_activation(identity,ps.Wh, h .* r ,ps.biash)
    Uhh = fused_dense_bias_activation(identity,ps.Uh, x ,nothing)
    h̃ = dynamic(m.layernormQ) ? (m.dynamics_nonlinearity.(layernorm(Whh,nothing,nothing) .+ Uhh)) : (@. m.dynamics_nonlinearity(Whh+Uhh))
    h′ = @. (1 - m.α * z) * h + m.α * z * h̃
    return (h′, (h′,)), st
end
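
Since the original post mentions `Recurrence`: once a cell subtypes `AbstractRecurrentCell`, it can be wrapped in `Lux.Recurrence` so that a whole `(features, time, batch)` array is processed in a single call instead of hand-writing the time loop. A minimal sketch, assuming the `FastGRUCell` above is already defined (the dimensions and hyperparameters here are arbitrary):

```julia
using Lux, Random

# Wrap the custom cell so it consumes the full sequence in one call;
# return_sequence=true yields one output per time step.
cell = FastGRUCell(2 => 8, 0.01f0, 1.0f0, true)
seq_model = Recurrence(cell; return_sequence=true)

rng = Xoshiro(0)
ps, st = Lux.setup(rng, seq_model)

x = randn(rng, Float32, 2, 50, 16)   # (features, time, batch)
ys, st_ = seq_model(x, ps, st)       # ys: vector of 50 outputs, each (8, 16)
```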


# --------------------------------------------------------------------------------------------------
# adapted from https://lux.csail.mit.edu/stable/tutorials/beginner/3_SimpleRNN#Creating-a-Classifier
using Lux, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end

struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:fastgru_cell, :classifier)}
    fastgru_cell::L
    classifier::C
end
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        FastGRUCell(in_dims => hidden_dims, 0.01f0, 1.0f0, true), 
        Dense(hidden_dims => out_dims, sigmoid))
end

function (s::SpiralClassifier)(
    x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}

    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_fastgru = s.fastgru_cell(x_init, ps.fastgru_cell, st.fastgru_cell)

    for x in x_rest
        (y, carry), st_fastgru = s.fastgru_cell((x, carry), ps.fastgru_cell, st_fastgru)
    end

    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, fastgru_cell = st_fastgru))

    return vec(y), st
end

# ----- loss
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# ----- training

function main(model_type)
    dev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)

    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            # x: (2, 50, 128), y: (128,)  # (features, time, batch)
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)

When I try to transfer my self-defined GRU cell to the tutorial MNIST Classification using Neural ODEs, I don't know how to start.
I'd really appreciate it if anyone could help me!
Thanks!

@disadone disadone added the question Further information is requested label Sep 28, 2024
@disadone disadone changed the title How I integrate an RNN into an ODEProblem in Lux.jl? How can I integrate an RNN into an ODEProblem in Lux.jl? Sep 28, 2024
@ChrisRackauckas (Member)

The question does not make much sense. Having hidden state which is carried over to the next call makes the equation not an ODE, and thus the solve is not convergent. If you do what you do here, where you initialize the hidden state on each call, the model is equivalent to just calling the NN that is supposed to be recurrent, and so you might as well call that NN directly. So I don't quite get what you're trying to do?
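
A tiny illustration of this point (my sketch, using Lux's built-in `GRUCell`, not code from the thread): when the hidden state is initialized fresh on every call, two calls with the same input give identical outputs, i.e. nothing is remembered between solver steps and the cell reduces to a feedforward map of its input:

```julia
using Lux, Random

rng = Xoshiro(0)
cell = GRUCell(2 => 4)
ps, st = Lux.setup(rng, cell)

u = randn(rng, Float32, 2, 1)

# Calling the cell with only a matrix initializes the hidden state internally.
(y1, _), _ = cell(u, ps, st)
(y2, _), _ = cell(u, ps, st)
y1 == y2   # true: no state is carried across calls
```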

@disadone (Author) commented Oct 6, 2024

Sorry for the confusion. I would like to train a sequence-to-sequence model where an RNN first derives a series of values, and those values are then fed into a sequence-to-sequence neural ODE as the inhomogeneous input to the equation. The RNN weights and the neural ODE parameters are trained together.

Maybe the question can be simplified to: how can I train a sequence-to-sequence neural ODE with a series of inputs?
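
One possible shape for this (my sketch under assumptions, not an established pattern from the docs): run the RNN over the whole input sequence first, then consume its per-step outputs as the inhomogeneous forcing of the dynamics. To keep the sketch self-contained and fully Zygote-differentiable, a hand-rolled explicit-Euler loop stands in for the ODE solve; with a real `ODEProblem` the RNN outputs would also have to be packed into the problem's parameter object (e.g. via ComponentArrays) so that the adjoint can propagate gradients back to the RNN weights. All model names, dimensions, and the step size `Δt` below are made up:

```julia
using Lux, Random, Zygote

rng = Xoshiro(0)

rnn = Recurrence(GRUCell(2 => 4); return_sequence=true)  # inputs -> forcing sequence
dyn = Dense(4 + 4 => 4, tanh)                            # (u, forcing) -> du/dt

ps_rnn, st_rnn = Lux.setup(rng, rnn)
ps_dyn, st_dyn = Lux.setup(rng, dyn)
ps = (; rnn=ps_rnn, dyn=ps_dyn)   # joint parameters, trained together

const Δt = 0.02f0

function seq2seq_loss(ps, x, targets)
    ys, _ = rnn(x, ps.rnn, st_rnn)        # vector of (4, batch) forcing terms
    u = zeros(Float32, 4, size(x, 3))     # initial condition of the ODE state
    l = 0.0f0
    for (forcing, target) in zip(ys, targets)
        du, _ = dyn(vcat(u, forcing), ps.dyn, st_dyn)
        u = u .+ Δt .* du                 # one explicit-Euler step
        l += sum(abs2, u .- target)       # sequence-to-sequence loss
    end
    return l
end

x = randn(rng, Float32, 2, 50, 8)               # (features, time, batch)
targets = [zeros(Float32, 4, 8) for _ in 1:50]  # dummy per-step targets
grads = Zygote.gradient(p -> seq2seq_loss(p, x, targets), ps)[1]
# grads.rnn and grads.dyn are both populated: the RNN and the dynamics
# network are trained jointly through the rollout.
```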

@ChrisRackauckas (Member)

Maybe @avik-pal has an example

@avik-pal (Member)

> The RNN weights and the neural ODE parameters are trained together.

Do you mean the RNN weights and the neural network weights are shared?

@disadone (Author)

> The RNN weights and the neural ODE parameters are trained together.
>
> Do you mean the RNN weights and the neural network weights are shared?

No, I mean the output of the RNN would be the input of the neural ODE at each time point.

@ChrisRackauckas (Member)

But without state? Then it's not an RNN?
