Skip to content

Commit

Permalink
Merge pull request #2234 from darsnack/rnn-eltypes
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack authored Apr 22, 2023
2 parents 3abda8f + e3a44c7 commit bd392e4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 15 deletions.
8 changes: 4 additions & 4 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ end
RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T}
function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {F,I,H,V,T}
Wi, Wh, b = m.Wi, m.Wh, m.b
_size_check(m, x, 1 => size(Wi,2))
σ = NNlib.fast_act(m.σ, x)
Expand Down Expand Up @@ -307,7 +307,7 @@ function LSTMCell((in, out)::Pair;
return cell
end

function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::AbstractVecOrMat) where {I,H,V,T}
_size_check(m, x, 1 => size(m.Wi,2))
b, o = m.b, size(h, 1)
xT = _match_eltype(m, T, x)
Expand Down Expand Up @@ -380,7 +380,7 @@ end
GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {I,H,V,T}
_size_check(m, x, 1 => size(m.Wi,2))
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
xT = _match_eltype(m, T, x)
Expand Down Expand Up @@ -450,7 +450,7 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state =
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))

function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,HH,T}
function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {I,H,V,HH,T}
_size_check(m, x, 1 => size(m.Wi,2))
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
xT = _match_eltype(m, T, x)
Expand Down
13 changes: 2 additions & 11 deletions src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ const nil = Nil()
Nil(::T) where T<:Number = nil
(::Type{T})(::Nil) where T<:Number = nil
Base.convert(::Type{Nil}, ::Number) = nil
Base.convert(::Type{T}, ::Nil) where {T<:Number} = zero(T)
Base.convert(::Type{Nil}, ::Nil) = nil

Base.float(::Type{Nil}) = Nil

Expand Down Expand Up @@ -157,17 +159,6 @@ for (fn, Dims) in ((:conv, DenseConvDims),)
end
end

# Recurrent layers: just convert to the type they like & convert back.

for Cell in [:RNNCell, :LSTMCell, :GRUCell, :GRUv3Cell]
@eval function (m::Recur{<:$Cell})(x::AbstractArray{Nil})
xT = fill!(similar(m.cell.Wi, size(x)), 0)
_, y = m.cell(m.state, xT) # discard the new state
return similar(x, size(y))
end
end


"""
@autosize (size...,) Chain(Layer(_ => 2), Layer(_), ...)
Expand Down

0 comments on commit bd392e4

Please sign in to comment.