Skip to content

Commit

Permalink
fixed small errors
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 7, 2024
1 parent a9bc95f commit d54dcd4
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,14 @@ end

@layer RNN

initialstates(rnn::RNN) = zeros_like(x, size(rnn.cell.Wh, 1))
initialstates(rnn::RNN) = zeros_like(rnn.cell.Wh, size(rnn.cell.Wh, 1))

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
return RNN(cell)
end

function (m::RNN)(x::AbstractArray)
function (rnn::RNN)(x::AbstractArray)
state = initialstates(rnn)
return rnn(x, state)
end
Expand Down Expand Up @@ -410,7 +410,7 @@ end
@layer LSTM

function initialstates(lstm::LSTM)
state = zeros_like(x, size(lstm.cell.Wh, 2))
state = zeros_like(lstm.cell.Wh, size(lstm.cell.Wh, 2))
cstate = zeros_like(state)
return state, cstate
end
Expand Down Expand Up @@ -590,7 +590,7 @@ end

@layer GRU

initialstates(gru::GRU) = zeros_like(x, size(gru.cell.Wh, 2))
initialstates(gru::GRU) = zeros_like(gru.cell.Wh, size(gru.cell.Wh, 2))

function GRU((in, out)::Pair; cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
Expand Down Expand Up @@ -737,7 +737,7 @@ end

@layer GRUv3

initialstates(gru::GRUv3) = zeros_like(x, size(gru.cell.Wh, 2))
initialstates(gru::GRUv3) = zeros_like(gru.cell.Wh, size(gru.cell.Wh, 2))

function GRUv3((in, out)::Pair; cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
Expand Down

0 comments on commit d54dcd4

Please sign in to comment.