Skip to content

Commit

Permalink
added initialstates
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 5, 2024
1 parent 74c3a63 commit 9e1b1bb
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ end

@layer RNNCell

initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

Check warning on line 72 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L72

Added line #L72 was not covered by tests

function RNNCell(
(in, out)::Pair,
σ = tanh;
Expand All @@ -82,7 +84,10 @@ function RNNCell(
return RNNCell(σ, Wi, Wh, b)
end

(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
function (rnn::RNNCell)(x::AbstractVecOrMat)
state = initialstates(rnn)
rnn(x, state)

Check warning on line 89 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L87-L89

Added lines #L87 - L89 were not covered by tests
end

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -261,6 +266,10 @@ end

@layer LSTMCell

function initialstates(lstm:: LSTMCell)
return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))

Check warning on line 270 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L269-L270

Added lines #L269 - L270 were not covered by tests
end

function LSTMCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -274,10 +283,9 @@ function LSTMCell(
return cell
end

function (m::LSTMCell)(x::AbstractVecOrMat)
h = zeros_like(x, size(m.Wh, 2))
c = zeros_like(h)
return m(x, (h, c))
function (lstm::LSTMCell)(x::AbstractVecOrMat)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))

Check warning on line 288 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L286-L288

Added lines #L286 - L288 were not covered by tests
end

function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
Expand Down Expand Up @@ -447,6 +455,8 @@ end

@layer GRUCell

initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))

Check warning on line 458 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L458

Added line #L458 was not covered by tests

function GRUCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -459,7 +469,10 @@ function GRUCell(
return GRUCell(Wi, Wh, b)
end

(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
function (gru::GRUCell)(x::AbstractVecOrMat)
state = initialstates(gru)
return gru(x, state)

Check warning on line 474 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L472-L474

Added lines #L472 - L474 were not covered by tests
end

function (m::GRUCell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -603,6 +616,8 @@ end

@layer GRUv3Cell

initialstates(gru::GRUv3Cell) = zeros_like(gru.Wh, size(gru.Wh, 2))

Check warning on line 619 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L619

Added line #L619 was not covered by tests

function GRUv3Cell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -616,7 +631,10 @@ function GRUv3Cell(
return GRUv3Cell(Wi, Wh, b, Wh_h̃)
end

(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
function (gru::GRUv3Cell)(x::AbstractVecOrMat)
state = initialstates(gru)
return gru(x, state)

Check warning on line 636 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L634-L636

Added lines #L634 - L636 were not covered by tests
end

function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down

0 comments on commit 9e1b1bb

Please sign in to comment.