transpose everything going into mxnet
MikeInnes committed Mar 8, 2017
1 parent 3b004ba commit 9d1d176
Showing 3 changed files with 34 additions and 50 deletions.
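In short: the MXNet backend used to move the batch dimension last with a permutedims call (the rebatch_last helper deleted below) before handing arrays to MXNet; after this commit the arrays are instead transposed on their way in, as the commit title says. For the 2-D case the old permutation and a plain transpose coincide, which a one-liner confirms (an illustration of mine, not part of the commit):

    xs = rand(3, 20)                             # a batch of three 20-vectors, batch-first
    permutedims(xs, (2:ndims(xs)..., 1)) == xs'  # true: for 2-D arrays the permutation is a transpose
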
8 changes: 4 additions & 4 deletions src/backend/mxnet/graph.jl
@@ -19,7 +19,7 @@ node(x::mx.SymbolicNode) = x

graph(::typeof(tuple), args...) = (args...,)
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
-graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching
+graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
@@ -38,10 +38,10 @@ graph(ctx::Context, d::Affine, x) =
register(ctx,
mx.FullyConnected(mx.SymbolicNode, data = x,
num_hidden = size(d.W.x, 2),
-weight = var(ctx, AlterParam(d.W, false, false)),
-bias = var(ctx, AlterParam(d.b, true, false))))
+weight = var(ctx, AlterParam(d.W, x->x', nothing)),
+bias = var(ctx, AlterParam(d.b, x->squeeze(x, 1), nothing))))

-# TODO: use actual params}
+# TODO: use actual params
graph(ctx::Context, c::Conv2D, x) =
mx.Convolution(x,
kernel = size(c.filter, 1, 2),
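The reversed mx.dot(reverse(xs)...) above leans on the transpose identity (A*B)' == B'*A': if both operands reach MXNet already transposed, multiplying them in the opposite order yields the transpose of the original product. A quick sanity check in plain Julia (a sketch, not part of the commit):

    A = rand(4, 6)
    B = rand(6, 3)
    isapprox((A*B)', B'*A')   # true: the identity behind mx.dot(reverse(xs)...)
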
72 changes: 28 additions & 44 deletions src/backend/mxnet/model.jl
@@ -1,25 +1,12 @@
using Flux: batchone, unbatchone, rebatch

-# MNet batches on last dimension
-rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
-rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
-rebatch_first(xs::Tuple) = rebatch_first.(xs)
-
-paramvalue(p) = rebatch_last(p)
-paramvalue(p::Flux.Param) = paramvalue(p.x)
-
-# Basically a kludge to make Affine work
-# Hopefully will go away with more inference
type AlterParam
-param::Flux.Param
-strip::Bool
-rebatch::Bool
+param
+load
+store
end

-function paramvalue(p::AlterParam)
-val = p.rebatch ? paramvalue(p.param) : p.param.x
-p.strip ? squeeze(val, 1) : val
-end
+Base.size(p::AlterParam) = size(p.load(p.param.x))

type Graph
output
@@ -28,34 +15,32 @@ type Graph
end

function mxparams(g::Graph)
-params = Dict{Symbol,mx.NDArray}()
+params = Dict{Symbol,MXArray}()
for (name, param) in g.params
-params[name] = mx.zeros(size(paramvalue(param)))
+params[name] = MXArray(size(param))
end
return params
end

-function loadparams!(g::Graph, args)
-for (id, param) in g.params
-haskey(args, id) && copy!(args[id], paramvalue(param))
+function copyargs!(as, bs)
+for id in intersect(keys(as), keys(bs))
+copy!(as[id], bs[id])
end
end

-function storeparams!(g::Graph, args)
-for (id, param) in g.params
-haskey(args, id) && copy!(param.x, rebatch_first(copy(args[id])))
-end
-end
+ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)

type Model <: Flux.Model
model::Any
graph::Graph
-grads::Dict{Symbol,Any}
+args::Dict{Symbol,MXArray}
+grads::Dict{Symbol,MXArray}
+outs::Vector{MXArray}
exec::mx.Executor
end

-loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
-storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
+loadparams!(model::Model) = copyargs!(model.args, model.graph.params)
+storeparams!(model::Model) = copyargs!(model.graph.params, model.args)

mxgroup(x) = x
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
@@ -64,35 +49,34 @@ mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)

function mxnet(model::Flux.Model, input)
graph = tograph(model, mx.Variable(:input))
-args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
-grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
-model = @mxerr graph.stacks Model(model, graph, grads,
-mx.bind(mxgroup(graph.output), args = args,
-args_grad = grads,
-grad_req = mx.GRAD_ADD))
+args = merge(mxparams(graph), Dict(:input => MXArray(input)))
+grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
+exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
+args = ndparams(args),
+args_grad = ndparams(grads),
+grad_req = mx.GRAD_ADD)
+model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
loadparams!(model)
return model
end

function runmodel(model::Model, input)
-copy!(model.exec.arg_dict[:input], input)
+copy!(model.args[:input], input)
mx.forward(model.exec, is_train = true)
-mxungroup(model.graph.output, copy(model.exec.outputs))
+mxungroup(model.graph.output, copy(model.outs))
end

-(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
+(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))

(m::Model)(x) = unbatchone(m(batchone(x)))

-tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
-
function runback!(model::Model, Δ)
model.grads[:input][:] = 0
-mx.backward(model.exec, tond(Δ))
+mx.backward(model.exec, MXArray(Δ).data)
copy(model.grads[:input])
end

-Flux.back!(m::Model, Δ::Batch, x) = rebatch(rebatch_first(runback!(m, rebatch_last(rawbatch(Δ)))))
+Flux.back!(m::Model, Δ::Batch, x) = rebatch(runback!(m, rawbatch(Δ)))

Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))

@@ -126,6 +110,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont
model = rewrite_softmax(model, label)
graph = tograph(model, mx.Variable(input), feedforward=true)
ff = mx.FeedForward(graph.output, context = context)
-isempty(graph.params) || (ff.arg_params = mxparams(graph))
+isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
return ff
end
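Two helpers introduced above are worth spelling out: copyargs! collapses the old loadparams!/storeparams! loops into one function that copies values between any two keyed collections, restricted to the keys they share, so the same code moves parameters in either direction; ndparams then maps each MXArray to its .data field (presumably the raw mx.NDArray) before the dict is handed to mx.bind. A minimal stand-alone sketch of the copyargs! pattern with plain Dicts of arrays, no Flux or MXNet types needed:

    a = Dict(:W => zeros(2, 2), :b => zeros(2))
    b = Dict(:W => ones(2, 2),  :c => ones(3))
    for id in intersect(keys(a), keys(b))   # only the keys present on both sides
      copy!(a[id], b[id])                   # in-place copy, as in copyargs!
    end
    a[:W]                                   # now all ones; :b untouched, :c ignored
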
4 changes: 2 additions & 2 deletions test/backend/mxnet.jl
@@ -6,11 +6,11 @@ Flux.loadmx()
xs = rand(20)
d = Affine(20, 10)

-dm = mxnet(d, (20, 1))
+dm = mxnet(d, (1, 20))
@test d(xs) ≈ dm(xs)

m = Multi(20, 15)
-mm = mxnet(m, (20, 1))
+mm = mxnet(m, (1, 20))
@test all(isapprox.(mm(xs), m(xs)))

@testset "Backward Pass" begin
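The test change mirrors the commit title: the input shape handed to mxnet is now given transposed, (1, 20) instead of (20, 1). Reversing a 2-D size tuple is exactly the size of the transpose (a trivial check of mine, not from the commit):

    A = rand(20, 1)
    size(A') == reverse(size(A)) == (1, 20)   # true
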
