NewRecur experimental interface (#11)
mkschleg authored Aug 9, 2023
1 parent 62e7d06 commit 9dcae27
Showing 5 changed files with 333 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -3,6 +3,7 @@ uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
version = "0.1.3"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2 changes: 2 additions & 0 deletions src/Fluxperimental.jl
@@ -13,4 +13,6 @@ include("chain.jl")

include("compact.jl")

include("new_recur.jl")

end # module Fluxperimental
140 changes: 140 additions & 0 deletions src/new_recur.jl
@@ -0,0 +1,140 @@
import Flux: ChainRulesCore
import Compat: stack

##### Helper scan function which can likely be put into NNlib. #####
"""
scan_full
Recreates `jax.lax.scan` functionality in Julia. Takes a function, an initial carry, and a sequence, then returns the full output sequence and the final carry. See `scan_partial` to return only the final output of the sequence.
"""
function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray})
# Recurrence operation used in the fold. Takes the state of the
# fold and the next input, returns the new state.
function recurrence_op((carry, outputs), input)
carry, out = func(carry, input)
return carry, vcat(outputs, [out])
end
# Fold left to right.
return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs)
end

function scan_full(func, init_carry, x_block)
# x_block is an abstractarray and we want to scan over the last dimension.
xs_ = Flux.eachlastdim(x_block)

# this is needed due to a bug in eachlastdim which produces a vector in a
# gradient context, but a generator otherwise.
xs = if xs_ isa Base.Generator
collect(xs_) # eachlastdim produces a generator in non-gradient environment
else
xs_
end
scan_full(func, init_carry, xs)
end
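# Illustrative sketch of calling `scan_full` (the running-sum "cell" below is
# hypothetical; like an RNN cell it returns `(new_carry, output)`):
#
#   cumsum_cell(c, x) = (c .+ x, c .+ x)
#   carry, ys = scan_full(cumsum_cell, [0.0f0], [[1.0f0], [2.0f0], [3.0f0]])
#   # carry == [6.0f0]
#   # ys    == [[1.0f0], [3.0f0], [6.0f0]]   # one output per step, in order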

# Chain Rule for Base.mapfoldl_impl
function ChainRulesCore.rrule(
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(Base.mapfoldl_impl),
::typeof(identity),
op::G,
init,
x::Union{AbstractArray, Tuple};
) where {G}
hobbits = Vector{Any}(undef, length(x)) # Unfortunately Zygote needs this
accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b
c, back = ChainRulesCore.rrule_via_ad(config, op, a, b)
end
y = first(last(hobbits))
axe = axes(x)
project = ChainRulesCore.ProjectTo(x)
function unfoldl(dy)
trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
end
dop = sum(first, trio)
dx = map(last, Iterators.reverse(trio))
d_init = trio[end][2]
return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe)))
end
return y, unfoldl
end
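# Rough differentiation sketch: with the rrule above, gradients can flow through
# `scan_full` (hypothetical running-sum cell, illustrative numbers):
#
#   cumsum_cell(c, x) = (c .+ x, c .+ x)
#   g, = Flux.gradient([0.0f0]) do c0
#       carry, _ = scan_full(cumsum_cell, c0, [[1.0f0], [2.0f0]])
#       sum(carry)
#   end
#   # expected g == [1.0f0], since the final carry depends linearly on the initial carry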


"""
scan_partial
Recreates `jax.lax.scan` functionality in Julia. Takes a function, an initial carry, and a sequence, then returns only the final output of the sequence and the final carry. See `scan_full` to return the entire output sequence.
"""
function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray})
x_init, x_rest = Iterators.peel(xs)
(carry, y) = func(init_carry, x_init)
for x in x_rest
(carry, y) = func(carry, x)
end
carry, y
end

function scan_partial(func, init_carry, x_block)
# x_block is an abstractarray and we want to scan over the last dimension.
xs_ = Flux.eachlastdim(x_block)

# this is needed due to a bug in eachlastdim which produces a vector in a
# gradient context, but a generator otherwise.
xs = if xs_ isa Base.Generator
collect(xs_) # eachlastdim produces a generator in non-gradient environment
else
xs_
end
scan_partial(func, init_carry, xs)
end
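# Illustrative sketch of calling `scan_partial` (hypothetical running-sum cell):
# only the final step's output is kept, alongside the final carry.
#
#   cumsum_cell(c, x) = (c .+ x, c .+ x)
#   carry, y = scan_partial(cumsum_cell, [0.0f0], [[1.0f0], [2.0f0], [3.0f0]])
#   # carry == [6.0f0];  y == [6.0f0]   # earlier outputs are discarded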


"""
NewRecur
An experimental `Recur` interface that removes statefulness from recurrent architectures in Flux. The struct has two type parameters: the first, `RET_SEQUENCE`, is a `Bool` that determines whether `scan_full` (`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence; the second is the type of the wrapped cell. This structure holds no internal state; instead it is used as follows:
```julia
l = NewRNN(1, 2)
xs          # some input array of size Input × BatchSize × Time
init_carry  # the initial carry of the cell
l(xs)             # returns the output of the RNN, using cell.state0 as init_carry
l(init_carry, xs) # returns (final_carry, output), where the size of output is determined by RET_SEQUENCE
```
"""
struct NewRecur{RET_SEQUENCE, T}
cell::T
# state::S
function NewRecur(cell; return_sequence::Bool=false)
new{return_sequence, typeof(cell)}(cell)
end
function NewRecur{true}(cell)
new{true, typeof(cell)}(cell)
end
function NewRecur{false}(cell)
new{false, typeof(cell)}(cell)
end
end

Flux.@functor NewRecur
Flux.trainable(a::NewRecur) = (; cell = a.cell)
Base.show(io::IO, m::NewRecur) = print(io, "Recur(", m.cell, ")")
NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence)

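# Guard methods: a bare matrix or numeric vector is ambiguous (is the last
# dimension time or batch?), so such inputs are rejected.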
(l::NewRecur)(init_carry, x_mat::AbstractMatrix) = MethodError("Matrix is ambiguous with NewRecur")
(l::NewRecur)(init_carry, x_mat::AbstractVector{T}) where {T<:Number} = MethodError("Vector is ambiguous with NewRecur")

function (l::NewRecur)(xs::AbstractArray)
results = l(l.cell.state0, xs)
results[2] # Only return the output here.
end

function (l::NewRecur{false})(init_carry, xs)
results = scan_partial(l.cell, init_carry, xs)
results[1], results[2]
end

function (l::NewRecur{true})(init_carry, xs)
results = scan_full(l.cell, init_carry, xs)
results[1], stack(results[2], dims=3)
end
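# Hypothetical end-to-end usage sketch (sizes are illustrative; NewRNN wraps Flux.RNNCell):
#
#   layer = NewRNN(3, 5; return_sequence=true)
#   xs = rand(Float32, 3, 7, 10)              # Input × BatchSize × Time
#   ys = layer(xs)                            # 5 × 7 × 10, uses cell.state0 as the carry
#   carry, ys = layer(layer.cell.state0, xs)  # explicitly pass the initial carry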
188 changes: 188 additions & 0 deletions test/new_recur.jl
@@ -0,0 +1,188 @@
@testset "NewRecur RNN" begin
@testset "Forward Pass" begin
# tanh is needed so the forward check is sensitive to the ordering of inputs.
cell = Flux.RNNCell(1, 1, tanh)
layer = Fluxperimental.NewRecur(cell; return_sequence=true)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = reshape([2.0f0, 3.0f0], 1, 1, 2)

# Let's make sure the output is correct.
h = cell.state0
h, out = cell(h, [2.0f0])
h, out = cell(h, [3.0f0])

@test eltype(layer(x)) <: Float32
@test size(layer(x)) == (1, 1, 2)
@test layer(x)[1, 1, 2] ≈ out[1,1]

@test length(layer(cell.state0, x)) == 2 # should return a tuple. Maybe better test is needed.
@test layer(cell.state0, x)[2][1,1,2] ≈ out[1,1]

@test_throws MethodError layer([2.0f0])
@test_throws MethodError layer([2.0f0;; 3.0f0])
end

@testset "gradients-implicit" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2
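# Derivation of the closed-form gradients above (two-step linear RNN, identity activation; h0 = state0):
#   h1 = Wh*h0 + Wi*x1 + b
#   y  = Wh*h1 + Wi*x2 + b = Wh*(Wh*h0 + Wi*x1 + b) + Wi*x2 + b
# hence, with b = 0:
#   ∂y/∂Wi = Wh*x1 + x2,  ∂y/∂Wh = 2*Wh*h0 + Wi*x1,
#   ∂y/∂b  = Wh + 1,      ∂y/∂h0 = Wh^2.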

nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
ps = Flux.params(nm_layer)
x_block = reshape(vcat(x...), 1, 1, length(x))
e, g = Flux.withgradient(ps) do
out = nm_layer(x_block)
sum(out[1, 1, 2])
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]
end

@testset "gradients-explicit" begin

cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2


x_block = reshape(vcat(x...), 1, 1, length(x))
nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
e, g = Flux.withgradient(nm_layer) do layer
out = layer(x_block)
sum(out[1, 1, 2])
end
grads = g[1][:cell]

@test primal[1] ≈ e
@test ∇Wi ≈ grads[:Wi]
@test ∇Wh ≈ grads[:Wh]
@test ∇b ≈ grads[:b]
@test ∇state0 ≈ grads[:state0]
end
end

@testset "New Recur RNN Partial Sequence" begin
@testset "Forward Pass" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Fluxperimental.NewRecur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = reshape([2.0f0, 3.0f0], 1, 1, 2)

h = cell.state0
h, out = cell(h, [2.0f0])
h, out = cell(h, [3.0f0])

@test eltype(layer(x)) <: Float32
@test size(layer(x)) == (1, 1)
@test layer(x)[1, 1] ≈ out[1,1]

@test length(layer(cell.state0, x)) == 2
@test layer(cell.state0, x)[2][1,1] ≈ out[1,1]

@test_throws MethodError layer([2.0f0])
@test_throws MethodError layer([2.0f0;; 3.0f0])
end

@testset "gradients-implicit" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2

nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
ps = Flux.params(nm_layer)
x_block = reshape(vcat(x...), 1, 1, length(x))
e, g = Flux.withgradient(ps) do
out = (nm_layer)(x_block)
sum(out)
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]
end

@testset "gradients-explicit" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2

x_block = reshape(vcat(x...), 1, 1, length(x))
nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
e, g = Flux.withgradient(nm_layer) do layer
out = layer(x_block)
sum(out)
end
grads = g[1][:cell]

@test primal[1] ≈ e
@test ∇Wi ≈ grads[:Wi]
@test ∇Wh ≈ grads[:Wh]
@test ∇b ≈ grads[:b]
@test ∇state0 ≈ grads[:state0]

end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -8,4 +8,6 @@ using Flux, Fluxperimental

include("compact.jl")

include("new_recur.jl")

end
