Skip to content

Commit

Permalink
rm chainseq
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Sep 11, 2017
1 parent c80fb99 commit 7041ab9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
7 changes: 7 additions & 0 deletions docs/src/models/recurrence.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ m(seq) # returns a new Seq of length 10

When we apply the model `m` to a seq, it gets mapped over every item in the sequence in order. This is just like the code above, but often more convenient.

You can get this behaviour more generally with the `Over` wrapper.

```julia
m = Over(Dense(10,5))
m(seq) # returns a new Seq of length 10
```

## Truncating Gradients

By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive.
Expand Down
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module Flux
using Juno
using Lazy: @forward

export Chain, Dense, Seq, ChainSeq, RNN, LSTM,
export Chain, Dense, Seq, Over, RNN, LSTM,
SGD, params

using NNlib
Expand Down
24 changes: 6 additions & 18 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,14 @@ Seq(xs) = Seq(collect(xs))

Base.getindex(s::Seq, i) = s.data[i]

type ChainSeq
layers::Vector{Any}
ChainSeq(xs...) = new([xs...])
struct Over{T}
m::T
end

@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
@forward ChainSeq.layers Base.start, Base.next, Base.done
(m::Over)(xs...) = m.m(xs...)
(m::Over)(xs::Seq) = Seq(map(m, xs.data))

Optimise.children(c::ChainSeq) = c.layers

(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])

Base.getindex(c::ChainSeq, i::AbstractArray) = Chain(c.layers[i]...)

function Base.show(io::IO, c::ChainSeq)
print(io, "ChainSeq(")
join(io, c.layers, ", ")
print(io, ")")
end
Base.show(io::IO, m::Over) = print(io, "Over(", m.m, ")")

# Stateful recurrence

Expand All @@ -49,7 +37,7 @@ function (m::Recur)(xs...)
return y
end

(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
(m::Recur)(s::Seq) = Seq(map(m, x.data))

Optimise.children(m::Recur) = (m.cell,)

Expand Down

0 comments on commit 7041ab9

Please sign in to comment.