nicer mxnet api
MikeInnes committed Mar 8, 2017
1 parent 9c9feb9 commit 854a1e1
Showing 5 changed files with 82 additions and 53 deletions.
docs/src/apis/backends.md (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ directly won't have great performance. In order to run a computationally intensi
 This is easy to do. Just call either `mxnet` or `tf` on a model to convert it to a model of that kind:

 ```julia
-mxmodel = mxnet(model, (10, 1))
+mxmodel = mxnet(model)
 mxmodel(xs) #> [0.0650, 0.0655, ...]
 # or
 tfmodel = tf(model)
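As a side note on the doc change above: the input shape is no longer passed to `mxnet`, because executors are now bound lazily per input size (see the model.jl changes below). A minimal usage sketch under the new interface; the model and sizes here are illustrative, not taken from the commit:

```julia
# Hypothetical usage of the shape-free conversion API (sizes are illustrative).
model = Affine(10, 20)     # any Flux model
mxmodel = mxnet(model)     # no input shape argument any more
mxmodel(rand(10))          # an MXNet executor is bound lazily for this input size
mxmodel(rand(10))          # a call with the same size reuses the cached executor
```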
docs/src/models/debugging.md (16 changes: 9 additions & 7 deletions)
@@ -14,18 +14,20 @@ end

 model = TLP(Affine(10, 20), Affine(21, 15))

-mxmodel = mxnet(model, (10, 1))
+mxmodel = mxnet(model)
+
+mxmodel(rand(10))
 ```

 Unfortunately, this model has a (fairly obvious) typo, which means that the code above won't run. Instead we get an error message:

 ```julia
-InferShape Error in dot5: [20:37:39] src/operator/./matrix_op-inl.h:271:
-Check failed: (lshape[1]) == (rshape[0]) dot shape error: (15,21) X (20,1)
-in Flux.Affine at affine.jl:8
-in TLP at test.jl:6
-in mxnet(::TLP, ::Tuple{Int64,Int64}) at model.jl:40
-in mxnet(::TLP, ::Vararg{Any,N} where N) at backend.jl:20
+Error in operator dot2: [21:28:21] src/operator/tensor/./matrix_op-inl.h:460:
+Check failed: lshape[1] == rshape[0] (20 vs. 21) dot shape error: (1,20) X (21,15)
+Flux.Affine at affine.jl:8
+TLP at basic.jl:6
+(::Flux.MX.Model)(::Flux.Batch{Array{Float64,1},Array{Float64,2}}) at model.jl:105
+(::Flux.MX.Model)(::Array{Float64,1}) at model.jl:107
 ```

 Most frameworks would only give the error message here – not so helpful if you have thousands of nodes in your computational graph. However, Flux is able to give good error reports *even when no Julia code has been run*, e.g. when running on a backend like MXNet. This enables us to pinpoint the source of the error very quickly even in a large model.
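For reference, the fix implied by the shape error above is to make the second layer's input size match the first layer's output size. A sketch of the corrected model; this fix is not part of the commit:

```julia
# The error reports (20 vs. 21): Affine(21, 15) expects 21 inputs but receives 20.
model = TLP(Affine(10, 20), Affine(20, 15))  # corrected second layer

mxmodel = mxnet(model)
mxmodel(rand(10))  # now runs without a shape error
```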
src/backend/backend.jl (4 changes: 2 additions & 2 deletions)
@@ -15,7 +15,7 @@ function loadmx()
   @eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
 end

-function mxnet(args...)
+function mxnet(m)
   loadmx()
-  eval(:(MX.mxnet($(args...))))
+  eval(:(MX.mxnet($m)))
 end
src/backend/mxnet/model.jl (106 changes: 66 additions & 40 deletions)
@@ -8,6 +8,12 @@ end

 Base.size(p::AlterParam) = size(p.load(p.param.x))

+function copyargs!(as, bs)
+  for id in intersect(keys(as), keys(bs))
+    copy!(as[id], bs[id])
+  end
+end
+
 type Graph
   output
   params::Dict{Symbol,Any}
@@ -22,75 +28,95 @@ function mxparams(g::Graph)
   return params
 end

-function copyargs!(as, bs)
-  for id in intersect(keys(as), keys(bs))
-    copy!(as[id], bs[id])
-  end
-end
-
 ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)

-type Model <: Flux.Model
-  model::Any
+type Exec <: Flux.Model
   graph::Graph
+  exec::mx.Executor
   args::Dict{Symbol,MXArray}
   grads::Dict{Symbol,MXArray}
   outs::Vector{MXArray}
-  exec::mx.Executor
 end

-loadparams!(model::Model) = copyargs!(model.args, model.graph.params)
-storeparams!(model::Model) = copyargs!(model.graph.params, model.args)
+loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params)
+storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args)

 mxgroup(x) = x
 mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
 mxungroup(x, outs) = copy(shift!(outs))
 mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)

-function mxnet(model::Flux.Model, input)
-  graph = tograph(model, mx.Variable(:input))
+function executor(graph::Graph, input)
   args = merge(mxparams(graph), Dict(:input => MXArray(input)))
   grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
-  exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
-                                     args = ndparams(args),
-                                     args_grad = ndparams(grads),
-                                     grad_req = mx.GRAD_ADD)
-  model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
-  loadparams!(model)
-  return model
+  exec = mx.bind(mxgroup(graph.output),
+                 args = ndparams(args),
+                 args_grad = ndparams(grads),
+                 grad_req = mx.GRAD_ADD)
+  exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
+  loadparams!(exec)
+  return exec
 end

-function runmodel(model::Model, input)
-  copy!(model.args[:input], input)
-  mx.forward(model.exec, is_train = true)
-  mxungroup(model.graph.output, copy(model.outs))
+function (exec::Exec)(input)
+  copy!(exec.args[:input], input)
+  mx.forward(exec.exec, is_train = true)
+  mxungroup(exec.graph.output, copy(exec.outs))
 end

-(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
-
-(m::Model)(x) = unbatchone(m(batchone(x)))
-
-function runback!(model::Model, Δ)
-  model.grads[:input][:] = 0
-  mx.backward(model.exec, MXArray(Δ).data)
-  copy(model.grads[:input])
+function Flux.back!(exec::Exec, Δ)
+  exec.grads[:input][:] = 0
+  mx.backward(exec.exec, MXArray(Δ).data)
+  copy(exec.grads[:input])
 end

-Flux.back!(m::Model, Δ::Batch, x) = rebatch(runback!(m, rawbatch(Δ)))
-
-Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
-
-function Flux.update!(model::Model, η)
-  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
+function Flux.update!(exec::Exec, η)
+  for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
     mx.@nd_as_jl rw = (arg, grad) begin
       arg .-= grad .* η
       grad[:] = 0
     end
   end
-  storeparams!(model)
-  return model
+  storeparams!(exec)
+  return exec
 end

+# TODO: if `last` changes, update params appropriately
+
+type Model
+  model::Any
+  graph::Graph
+  execs::Dict{Tuple,Exec}
+  last::Exec
+  Model(model, graph, execs) = new(model, graph, execs)
+end
+
+function mxnet(model)
+  graph = tograph(model, mx.Variable(:input))
+  Model(model, graph, Dict())
+end
+
+import Base: @get!
+
+executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
+
+function (m::Model)(x::Batch)
+  x′ = rawbatch(x)
+  m.last = exec = @mxerr m.graph.stacks executor(m, size(x′))
+  rebatch(exec(x′))
+end
+
+(m::Model)(x) = unbatchone(m(batchone(x)))
+
+function Flux.back!(m::Model, Δ::Batch, x::Batch)
+  m.last = exec = m.execs[size(rawbatch(x))]
+  rebatch(back!(exec, rawbatch(Δ)))
+end
+
+Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x)))
+
+Flux.update!(m::Model, η) = (update!(m.last, η); m)
+
 # MX FeedForward interface

 type SoftmaxOutput
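A rough usage sketch of the two-level design introduced above: an outer `Model` wraps the graph and lazily binds one `Exec` per input size, caching it in `execs` and tracking the most recent one in `last`. The model, sizes, gradient, and learning rate below are illustrative assumptions, not part of the commit:

```julia
# Sketch: the outer Model builds one Exec per input size, on demand.
m = mxnet(Affine(10, 20))        # builds the graph only; no executor is bound yet

y = m(rand(10))                  # first call: binds an Exec for this input size and caches it
y = m(rand(10))                  # same size again: the cached Exec is reused

# Backward and update go through the executor recorded by the last forward call.
Δ = ones(20)                     # gradient w.r.t. the output (illustrative)
dx = Flux.back!(m, Δ, rand(10))  # backward pass on the Exec for this input size
Flux.update!(m, 0.1)             # arg .-= grad .* η on `m.last`, then params are stored back
```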
test/backend/mxnet.jl (7 changes: 4 additions & 3 deletions)
@@ -6,11 +6,11 @@ Flux.loadmx()
 xs = rand(20)
 d = Affine(20, 10)

-dm = mxnet(d, (1, 20))
+dm = mxnet(d)
 @test d(xs) ≈ dm(xs)

 m = Multi(20, 15)
-mm = mxnet(m, (1, 20))
+mm = mxnet(m)
 @test all(isapprox.(mm(xs), m(xs)))

 @testset "Backward Pass" begin
@@ -40,7 +40,8 @@ end
 @testset "Stack Traces" begin
   model = TLP(Affine(10, 20), Affine(21, 15))
   info("The following warning is normal")
-  e = try mxnet(model, (10, 1))
+  dm = mxnet(model)
+  e = try dm(rand(10))
   catch e e end

   @test isa(e, DataFlow.Interpreter.Exception)
