diff --git a/docs/src/models/recurrence.md b/docs/src/models/recurrence.md index 456d250294..ad53ea9943 100644 --- a/docs/src/models/recurrence.md +++ b/docs/src/models/recurrence.md @@ -103,6 +103,13 @@ m(seq) # returns a new Seq of length 10 When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient. +You can get this behaviour more generally with the `Over` wrapper. + +```julia +m = Over(Dense(10,5)) +m(seq) # returns a new Seq of length 10 +``` + ## Truncating Gradients By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive. diff --git a/src/Flux.jl b/src/Flux.jl index ec5238ff70..e641db3b52 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,7 +7,7 @@ module Flux using Juno using Lazy: @forward -export Chain, Dense, Seq, ChainSeq, RNN, LSTM, +export Chain, Dense, Seq, Over, RNN, LSTM, SGD, params using NNlib diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 08b19c926d..a9b75d976b 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -13,26 +13,14 @@ Seq(xs) = Seq(collect(xs)) Base.getindex(s::Seq, i) = s.data[i] -type ChainSeq - layers::Vector{Any} - ChainSeq(xs...) = new([xs...]) +struct Over{T} + m::T end -@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push! -@forward ChainSeq.layers Base.start, Base.next, Base.done +(m::Over)(xs...) = m.m(xs...) +(m::Over)(xs::Seq) = Seq(map(m, xs.data)) -Optimise.children(c::ChainSeq) = c.layers - -(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers) -(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data]) - -Base.getindex(c::ChainSeq, i::AbstractArray) = Chain(c.layers[i]...) - -function Base.show(io::IO, c::ChainSeq) - print(io, "ChainSeq(") - join(io, c.layers, ", ") - print(io, ")") -end +Base.show(io::IO, m::Over) = print(io, "Over(", m.m, ")") # Stateful recurrence @@ -49,7 +37,7 @@ function (m::Recur)(xs...) return y end -(m::Recur)(s::Seq) = Seq([m(x) for x in s.data]) +(m::Recur)(s::Seq) = Seq(map(m, x.data)) Optimise.children(m::Recur) = (m.cell,)