Workaround Flux#1027 (#4)
* add example (broken)

* add Flux workarounds

* add tests

* oops! actually use BatchNorm in example

* add comments to make example more helpful

* Add Random dependency to test

* add comment to address review

* switch to `fcollect`-based implementation

* rename file, clean up a little

* update comment

* simplify API: `collect` not needed

* update compat

* add Flux compat

* add link to issue

* Apply suggestions from code review

Co-authored-by: Alex Arslan <[email protected]>

* better errors

* rename `loadweights!` to `load_weights!`

* use stablerng

* rename `weights` function to `fetch_weights`

* add terminology note

Co-authored-by: Hannah Robertson <[email protected]>
Co-authored-by: Alex Arslan <[email protected]>
3 people authored Jul 8, 2021
1 parent 002298b commit cd965c8
Showing 8 changed files with 276 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1 +1 @@
-/Manifest.toml
+Manifest.toml
10 changes: 8 additions & 2 deletions Project.toml
@@ -1,22 +1,28 @@
name = "LegolasFlux"
uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
authors = ["Beacon Biosignals, Inc."]
version = "0.1.0"
version = "0.1.1"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Arrow = "1"
Flux = "0.12"
Functors = "0.2.1"
Legolas = "0.1, 0.2"
Tables = "1"
julia = "1.5"

[extras]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Flux"]
test = ["Test", "Flux", "StableRNGs", "Statistics", "Random"]
18 changes: 11 additions & 7 deletions README.md
@@ -4,10 +4,11 @@
[![codecov](https://codecov.io/gh/beacon-biosignals/LegolasFlux.jl/branch/main/graph/badge.svg?token=NHYUL22HCC)](https://codecov.io/gh/beacon-biosignals/LegolasFlux.jl)

LegolasFlux provides some simple functionality to use [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/)'s
-extensible Arrow schemas as means to serialize Flux models using Flux's `params` and `loadparams!`.
+extensible Arrow schemas as a means to serialize Flux models, similarly to Flux's `params` and `loadparams!`
+(instead, we export the analogous functions `fetch_weights` and `load_weights!`, which handle layers like `BatchNorm` correctly for this purpose).

The aim is to serialize only the numeric weights, *not* the code defining the model. This is a very different approach
-from e.g. BSON.jl, and hopefully much more robust.
+from e.g. BSON.jl, and hopefully much more robust. Note that in this package, we use `weights` to refer to the numeric arrays that are modified over the course of training a model; this includes biases, as well as the means and variances of e.g. `BatchNorm` layers (but not e.g. configuration settings).
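
For instance, here is a minimal sketch of the difference (assuming Flux 0.12's `BatchNorm`; the exact counts depend on the layer's fields):

using Flux, LegolasFlux

bn = BatchNorm(2)
length(Flux.params(bn))   # 2: only the trainable shift β and scale γ
length(fetch_weights(bn)) # 4: also the running mean μ and variance σ²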

With this approach, however, if you change the code such that the weights are no longer valid (e.g. add a layer),
you will not be able to load back the same model.
@@ -28,16 +29,16 @@ my_model = make_my_model()
using LegolasFlux

# We can save whatever other columns we'd like to as well as the `weights`.
-model_row = ModelRow(; weights = collect(params(cpu(my_model))), architecture_version = 1, loss = 0.5)
+model_row = ModelRow(; weights = fetch_weights(cpu(my_model)),
+                       architecture_version=1, loss=0.5)
write_model_row("my_model.model.arrow", model_row)

# Great! Later on, we want to re-load our model weights.
fresh_model = make_my_model()

model_row = read_model_row("my_model.model.arrow")
-Flux.loadparams!(fresh_model, collect(model_row.weights))
-# Now our params have been loaded back into `fresh_model`.
-# Note we needed to `collect` the weights before we use them.
+load_weights!(fresh_model, model_row.weights)
+# Now our weights have been loaded back into `fresh_model`.

# We can also check out our other columns:
model_row.loss # 0.5
@@ -47,6 +48,8 @@ model_row.loss # 0.5
We can use the `architecture_version` column to specify a version number for the architecture, in order
to keep track of which architectures the weights are valid for.
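
For example, one might guard weight loading on the stored version (a sketch; `EXPECTED_ARCHITECTURE_VERSION` is a hypothetical application-defined constant):

model_row = read_model_row("my_model.model.arrow")
if model_row.architecture_version != EXPECTED_ARCHITECTURE_VERSION
    error("Weights were saved for architecture version $(model_row.architecture_version)")
end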

+See [examples/digits.jl](examples/digits.jl) for a larger example.

## `LegolasFlux.ModelRow`

A `LegolasFlux.ModelRow` is the central object of LegolasFlux. It acts as a Tables.jl-compatible row that can store the weights
@@ -78,4 +81,5 @@ one might name files produced by this row as e.g. `training_run.digits.model.arrow`
Note in this example the schema is called `digits.model` instead of just, say, `digits`, since the package Digits might want to
create other Legolas schemas as well at some point.

-Check out the [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/) repo to see more about how its extensible schema system works.
+Check out the [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/) repo to see more about how its extensible schema system works,
+and the example at [examples/digits.jl](examples/digits.jl).
7 changes: 7 additions & 0 deletions examples/Project.toml
@@ -0,0 +1,7 @@
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
140 changes: 140 additions & 0 deletions examples/digits.jl
@@ -0,0 +1,140 @@
# Model modified from
# https://discourse.julialang.org/t/how-to-drop-the-dropout-layers-in-flux-jl-when-assessing-model-performance/19924

using Flux, Statistics, Random, Test
# Uncomment to use MNIST data
# using MLDatasets: MNIST
using StableRNGs
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Legolas, LegolasFlux

# This should store all the information needed
# to construct the model.
Base.@kwdef struct DigitsConfig
    seed::Int = 5
    dropout_rate::Float32 = 0f0
end

# Here's our model object itself, just a `DigitsConfig` and
# a `chain`. We keep the config around so it's easy to save out
# later.
struct DigitsModel
    chain::Chain
    config::DigitsConfig
end

# Ensure Flux can recurse into our model to find params, etc.
Flux.@functor DigitsModel (chain,)

# Construct the actual model from a config object. This is the only
# constructor that should be used, to ensure the model is created just
# from the config object alone.
function DigitsModel(config::DigitsConfig=DigitsConfig())
    dropout_rate = config.dropout_rate
    Random.seed!(config.seed)
    chain = Chain(Dropout(dropout_rate),
                  Conv((3, 3), 1 => 32, relu),
                  BatchNorm(32, relu),
                  MaxPool((2, 2)),
                  Dropout(dropout_rate),
                  Conv((3, 3), 32 => 16, relu),
                  Dropout(dropout_rate),
                  MaxPool((2, 2)),
                  Dropout(dropout_rate),
                  Conv((3, 3), 16 => 10, relu),
                  Dropout(dropout_rate),
                  x -> reshape(x, :, size(x, 4)),
                  Dropout(dropout_rate),
                  Dense(90, 10),
                  softmax)
    return DigitsModel(chain, config)
end
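
# For instance, a non-default configuration might look like (a sketch;
# `m_custom` is hypothetical):
#     m_custom = DigitsModel(DigitsConfig(; seed=42, dropout_rate=0.1f0))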

# Our model acts on input just by applying the chain.
(m::DigitsModel)(x) = m.chain(x)

# Here, we define a schema extension of the `legolas-flux.model` schema.
# We add our `DigitsConfig` object, as well as the epoch and accuracy.
const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1",
                               config::DigitsConfig,
                               epoch::Union{Missing, Int},
                               accuracy::Union{Missing, Float32})

# Construct a `DigitsRow` from a model by collecting the weights.
# This can then be saved with e.g. `LegolasFlux.write_model_row`.
function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing)
    return DigitsRow(; weights=fetch_weights(model), model.config, epoch, accuracy)
end

# Construct a `DigitsModel` from a row satisfying the `DigitsRow` schema,
# i.e. one with `weights` and a `config::DigitsConfig`.
# This could be the result of `LegolasFlux.read_model_row`.
function DigitsModel(row)
    m = DigitsModel(row.config)
    load_weights!(m, row.weights)
    return m
end


# Increase to get more training/test data
N_train = 1_000
N_test = 50

##
# to use MNIST data, uncomment these
# train_x, train_y = MNIST.traindata(Float32, 1:N_train)
# test_x, test_y = MNIST.testdata(Float32, 1:N_test)

# Random data:
rng = StableRNG(735)
train_x = rand(rng, Float32, 28, 28, N_train)
train_y = rand(rng, 0:9, N_train)
test_x = rand(rng, Float32, 28, 28, N_test)
test_y = rand(rng, 0:9, N_test)
##

# Partition into batches of size 32
batch_size = 32
train = [(reshape(train_x[:, :, I], 28, 28, 1, :), onehotbatch(train_y[I], 0:9))
         for I in partition(1:N_train, batch_size)]
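
# Each entry is an (input, label) pair; with `batch_size = 32`, e.g.
# size(train[1][1]) == (28, 28, 1, 32) and size(train[1][2]) == (10, 32).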

tX = reshape(test_x, 28, 28, 1, :)
tY = onehotbatch(test_y, 0:9)

function accuracy(m, x, y)
    testmode!(m)  # deterministic evaluation: no dropout, fixed BatchNorm statistics
    val = mean(onecold(m(x)) .== onecold(y))
    trainmode!(m)  # restore training behavior
    return val
end

function train_model!(m; N=N_train)
    loss = (x, y) -> crossentropy(m(x), y)
    opt = ADAM()
    evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5)
    Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt; cb=evalcb)
    return accuracy(m, tX, tY)
end

m = DigitsModel()

# increase N to actually train more than a tiny amount
acc = train_model!(m; N=10)

# Let's serialize out the weights into a `DigitsRow`.
# We could save this here with `write_model_row`.
row = DigitsRow(m; epoch=1, accuracy=acc)

testmode!(m)
input = tX[:, :, :, 1:1]
output = m(input)
label = tY[:, 1]

# Let's now reconstruct the model from the `row` and check that we get
# the same outputs.
m2 = DigitsModel(row)
testmode!(m2)
output2 = m2(input)

@test output ≈ output2
4 changes: 4 additions & 0 deletions src/LegolasFlux.jl
@@ -1,11 +1,14 @@
module LegolasFlux

export write_model_row, read_model_row
+export fetch_weights, load_weights!

using Legolas
using Arrow
using Arrow.ArrowTypes
using Tables
+using Functors
+using Base: IdSet

const LEGOLAS_SCHEMA = Legolas.Schema("legolas-flux.model@1")

@@ -110,5 +113,6 @@ function read_model_row(io_or_path)
    return only(rows)
end

include("functors.jl")

end # module
45 changes: 45 additions & 0 deletions src/functors.jl
@@ -0,0 +1,45 @@
# Modified version of `fcollect` to use an `IdSet` cache so that
# distinct arrays whose values happen to be duplicates are each kept.
# <https://github.com/FluxML/Functors.jl/issues/16>
function fcollect2(x; output=[], cache=IdSet(), exclude=_ -> false)
    x in cache && return output
    if !exclude(x)
        push!(cache, x)
        push!(output, x)
        foreach(y -> fcollect2(y; cache=cache, output=output, exclude=exclude), Functors.children(x))
    end
    return output
end
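
# For instance (a hedged sketch of the issue linked above; `w1`, `w2`, and `m`
# are hypothetical, and NamedTuples are functors in Functors 0.2):
#
#     w1 = zeros(Float32, 3)
#     w2 = zeros(Float32, 3)                  # a distinct array with equal values
#     m = (layer1 = w1, layer2 = w2)
#     filter(x -> x isa Array, fcollect2(m))  # keeps both w1 and w2
#
# A value-based cache would treat `w2` as already seen (since `w2 == w1`),
# whereas the `IdSet` keys on object identity (`===`) and keeps both.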

"""
fetch_weights(m) -> Vector{Array}
Returns the weights of a model by using `Functors.children` to recurse
through the model, keeping any arrays found. The `@functor` macro defines
`Functors.children` automatically so that should be sufficient to support
custom types.
Note that this function does not copy the results, so that e.g. mutating `fetch_weights(m)[1]` modifies the model.
"""
fetch_weights(m) = filter(x -> x isa Array, fcollect2(m))
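
# A quick sketch of the no-copy behavior (assumes Flux 0.12's `Dense`; the order
# of the returned arrays follows `Functors.children`):
#
#     using Flux
#     m = Dense(2, 2)
#     ws = fetch_weights(m)   # e.g. [m.weight, m.bias]
#     ws[1] .= 0              # mutates the model's weight matrix in place
#     @assert iszero(m.weight)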

"""
load_weights!(m, xs)
Load weights `xs` into the model `m`, using [`fetch_weights`](@ref).
"""
function load_weights!(m, xs)
model_weights = fetch_weights(m)
if length(model_weights) != length(xs)
throw(ArgumentError("Number of weights given ($(length(xs))) does not match number of weights model expects ($(length(model_weights)))"))
end
for (i, (p, x)) in enumerate(zip(model_weights, xs))
if size(p) != size(x)
throw(ArgumentError("For the $(i)th weight expected param size $(size(p)), got $(size(x))"))
end
copyto!(p, x)
end
return nothing
end

load_weights!(m, xs::Weights) = load_weights!(m, collect(xs))
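
# A sketch of the error behavior (assuming Flux's `Dense`, which has two weight arrays):
#
#     m = Flux.Dense(2, 2)
#     load_weights!(m, [zeros(Float32, 2, 2)])
#     # => ArgumentError: Number of weights given (1) does not match the number
#     #    of weights the model expects (2)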

2 comments on commit cd965c8

@ericphanson (Member, Author) commented:

@JuliaRegistrator register

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/40502

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.1.1 -m "<description of version>" cd965c8c5fea856d282fdf4a32dd4e6ad5bc55d7
git push origin v0.1.1
