From cd965c8c5fea856d282fdf4a32dd4e6ad5bc55d7 Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Thu, 8 Jul 2021 14:42:57 +0200 Subject: [PATCH] Workaround Flux#1027 (#4) * add example (broken) * add Flux workarounds * add tests * oops! actually use BatchNorm in example * add comments to make example more helpful * Add Random dependency to test * add comment to address review * switch to `fcollect`-based implementation * rename file, clean up a little * update comment * simplify API: `collect` not needed * update compat * add Flux compat * add link to issue * Apply suggestions from code review Co-authored-by: Alex Arslan * better errors * rename `loadweights!` to `load_weights!` * use stablerng * rename `weights` function to `fetch_weights` * add terminology note Co-authored-by: Hannah Robertson Co-authored-by: Alex Arslan --- .gitignore | 2 +- Project.toml | 10 ++- README.md | 18 +++--- examples/Project.toml | 7 +++ examples/digits.jl | 140 ++++++++++++++++++++++++++++++++++++++++++ src/LegolasFlux.jl | 4 ++ src/functors.jl | 45 ++++++++++++++ test/runtests.jl | 65 ++++++++++++++++++-- 8 files changed, 276 insertions(+), 15 deletions(-) create mode 100644 examples/Project.toml create mode 100644 examples/digits.jl create mode 100644 src/functors.jl diff --git a/.gitignore b/.gitignore index b067edd..ba39cc5 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1 @@ -/Manifest.toml +Manifest.toml diff --git a/Project.toml b/Project.toml index aeae137..4cb6429 100644 --- a/Project.toml +++ b/Project.toml @@ -1,22 +1,28 @@ name = "LegolasFlux" uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" authors = ["Beacon Biosignals, Inc."] -version = "0.1.0" +version = "0.1.1" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Arrow = "1" +Flux = "0.12" +Functors = "0.2.1" Legolas = "0.1, 0.2" Tables = "1" julia = "1.5" [extras] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Flux"] +test = ["Test", "Flux", "StableRNGs", "Statistics", "Random"] diff --git a/README.md b/README.md index 4b64369..630cc19 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,11 @@ [![codecov](https://codecov.io/gh/beacon-biosignals/LegolasFlux.jl/branch/main/graph/badge.svg?token=NHYUL22HCC)](https://codecov.io/gh/beacon-biosignals/LegolasFlux.jl) LegolasFlux provides some simple functionality to use [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/)'s -extensible Arrow schemas as means to serialize Flux models using Flux's `params` and `loadparams!`. +extensible Arrow schemas as means to serialize Flux models similarly to using Flux's `params` and `loadparams!` +(instead, we export similar functions `fetch_weights` and `load_weights!` which handle layers like `BatchNorm` correctly for this purpose). The aim is to serialize only the numeric weights, *not* the code defining the model. This is a very different approach -from e.g. BSON.jl, and hopefully much more robust. +from e.g. BSON.jl, and hopefully much more robust. Note that in this package, we use `weights` to refer to the numeric arrays that are modified over the course of training a model; that includes biases as well as means and variances in e.g. BatchNorms (but not e.g. configuration settings). With this approach, however, if you change the code such that the weights are no longer valid (e.g. add a layer), you will not be able to load back the same model. @@ -28,16 +29,16 @@ my_model = make_my_model() using LegolasFlux # We can save whatever other columns we'd like to as well as the `weights`. -model_row = ModelRow(; weights = collect(params(cpu(my_model))), architecture_version = 1, loss = 0.5) +model_row = ModelRow(; weights = fetch_weights(cpu(my_model)), + architecture_version=1, loss=0.5) write_model_row("my_model.model.arrow", model_row) # Great! Later on, we want to re-load our model weights. fresh_model = make_my_model() model_row = read_model_row("my_model.model.arrow") -Flux.loadparams!(fresh_model, collect(model_row.weights)) -# Now our params have been loaded back into `fresh_model`. -# Note we needed to `collect` the weights before we use them. +load_weights!(fresh_model, model_row.weights) +# Now our weights have been loaded back into `fresh_model`. # We can also check out our other columns: model_row.loss # 0.5 @@ -47,6 +48,8 @@ model_row.loss # 0.5 We can make use of the `architecture_version` column to specify a version number for the architectures, in order to keep track of for which architectures the weights are valid for. +See [examples/digits.jl](examples/digits.jl) for a larger example. + ## `LegolasFlux.ModelRow` A `LegolasFlux.ModelRow` is the central object of LegolasFlux. It acts as a Tables.jl-compatible row that can store the weights @@ -78,4 +81,5 @@ one might name files produced by this row as e.g. `training_run.digits.model.arr Note in this example the schema is called `digits.model` instead of just say `digits`, since the package Digits might want to create other Legolas schemas as well at some point. -Check out the [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/) repo to see more about how its extensible schema system works. +Check out the [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/) repo to see more about how its extensible schema system works, +and the example at [examples/digits.jl](examples/digits.jl). diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 0000000..142443c --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,7 @@ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" +LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/examples/digits.jl b/examples/digits.jl new file mode 100644 index 0000000..611813f --- /dev/null +++ b/examples/digits.jl @@ -0,0 +1,140 @@ +# Model modified from +# https://discourse.julialang.org/t/how-to-drop-the-dropout-layers-in-flux-jl-when-assessing-model-performance/19924 + +using Flux, Statistics, Random, Test +# Uncomment to use MNIST data +# using MLDatasets: MNIST +using StableRNGs +using Flux: onehotbatch, onecold, crossentropy, throttle +using Base.Iterators: repeated, partition +using Legolas, LegolasFlux + +# This should store all the information needed +# to construct the model. +Base.@kwdef struct DigitsConfig + seed::Int = 5 + dropout_rate::Float32 = 0f1 +end + +# Here's our model object itself, just a `DigitsConfig` and +# a `chain`. We keep the config around so it's easy to save out +# later. +struct DigitsModel + chain::Chain + config::DigitsConfig +end + +# Ensure Flux can recurse into our model to find params etc +Flux.@functor DigitsModel (chain,) + +# Construct the actual model from a config object. This is the only +# constructor that should be used, to ensure the model is created just +# from the config object alone. +function DigitsModel(config::DigitsConfig=DigitsConfig()) + dropout_rate = config.dropout_rate + Random.seed!(config.seed) + chain = Chain(Dropout(dropout_rate), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((2, 2)), + Dropout(dropout_rate), + Conv((3, 3), 32 => 16, relu), + Dropout(dropout_rate), + MaxPool((2, 2)), + Dropout(dropout_rate), + Conv((3, 3), 16 => 10, relu), + Dropout(dropout_rate), + x -> reshape(x, :, size(x, 4)), + Dropout(dropout_rate), + Dense(90, 10), + softmax) + return DigitsModel(chain, config) +end + +# Our model acts on input just by applying the chain. +(m::DigitsModel)(x) = m.chain(x) + +# Here, we define a schema extension of the `legolas-flux.model` schema. +# We add our `DigitsConfig` object, as well as the epoch and accuracy. +const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1", + config::DigitsConfig, + epoch::Union{Missing, Int}, + accuracy::Union{Missing, Float32}) + +# Construct a `DigitsRow` from a model by collecting the weights. +# This can then be saved with e.g. `LegolasFlux.write_model_row`. +function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing) + return DigitsRow(; weights=fetch_weights(model), model.config, epoch, accuracy) +end + +# Construct a `DigitsModel` from a row satisfying the `DigitsRow` schema, +# i.e. one with a `weights` and `config::DigitsConfig`. +# This could be the result of `LegolasFlux.read_model_row`. +function DigitsModel(row) + m = DigitsModel(row.config) + load_weights!(m, row.weights) + return m +end + + +# Increase to get more training/test data +N_train = 1_000 +N_test = 50 + +## +# to use MNIST data, uncomment these +# train_x, train_y = MNIST.traindata(Float32, 1:N_train) +# test_x, test_y = MNIST.testdata(Float32, 1:N_test) + +# Random data: +rng = StableRNG(735) +train_x = rand(rng, Float32, 28, 28, N_train) +train_y = rand(rng, 0:9, N_train) +test_x = rand(rng, Float32, 28, 28, N_test) +test_y = rand(rng, 0:9, N_test) +## + +# Partition into batches of size 32 +batch_size = 32 +train = [(reshape(train_x[:, :, I], 28, 28, 1, :), onehotbatch(train_y[I], 0:9)) + for I in partition(1:N_train, batch_size)] + +tX = reshape(test_x, 28, 28, 1, :) +tY = onehotbatch(test_y, 0:9) + +function accuracy(m, x, y) + testmode!(m) + val = mean(onecold(m(x)) .== onecold(y)) + trainmode!(m) + return val +end + +function train_model!(m; N = N_train) + loss = (x, y) -> crossentropy(m(x), y) + opt = ADAM() + evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5) + Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt; cb=evalcb) + return accuracy(m, tX, tY) +end + +m = DigitsModel() + +# increase N to actually train more than a tiny amount +acc = train_model!(m; N=10) + +# Let's serialize out the weights into a `DigitsRow`. +# We could save this here with `write_model_row`. +row = DigitsRow(m; epoch=1, accuracy=acc) + +testmode!(m) +input = tX[:, :, :, 1:1] +output = m(input) +label = tY[:, 1] + +# Let's now reconstruct the model from the `row` and check that we get +# the same outputs. +m2 = DigitsModel(row) +testmode!(m2) +output2 = m2(input) + +@test output ≈ output2 diff --git a/src/LegolasFlux.jl b/src/LegolasFlux.jl index d2152eb..616d6d2 100644 --- a/src/LegolasFlux.jl +++ b/src/LegolasFlux.jl @@ -1,11 +1,14 @@ module LegolasFlux export write_model_row, read_model_row +export fetch_weights, load_weights! using Legolas using Arrow using Arrow.ArrowTypes using Tables +using Functors +using Base: IdSet const LEGOLAS_SCHEMA = Legolas.Schema("legolas-flux.model@1") @@ -110,5 +113,6 @@ function read_model_row(io_or_path) return only(rows) end +include("functors.jl") end # module diff --git a/src/functors.jl b/src/functors.jl new file mode 100644 index 0000000..175d8e3 --- /dev/null +++ b/src/functors.jl @@ -0,0 +1,45 @@ +# Modified version of `fcollect` to use an `IdSet` cache so that +# distinct arrays whose values happen to be duplicates are each kept. +# +function fcollect2(x; output=[], cache=IdSet(), exclude=_ -> false) + x in cache && return output + if !exclude(x) + push!(cache, x) + push!(output, x) + foreach(y -> fcollect2(y; cache=cache, output=output, exclude=exclude), Functors.children(x)) + end + return output +end + +""" + fetch_weights(m) -> Vector{Array} + +Returns the weights of a model by using `Functors.children` to recurse +through the model, keeping any arrays found. The `@functor` macro defines +`Functors.children` automatically so that should be sufficient to support +custom types. + +Note that this function does not copy the results, so that e.g. mutating `fetch_weights(m)[1]` modifies the model. +""" +fetch_weights(m) = filter(x -> x isa Array, fcollect2(m)) + +""" + load_weights!(m, xs) + +Load weights `xs` into the model `m`, using [`fetch_weights`](@ref). +""" +function load_weights!(m, xs) + model_weights = fetch_weights(m) + if length(model_weights) != length(xs) + throw(ArgumentError("Number of weights given ($(length(xs))) does not match number of weights model expects ($(length(model_weights)))")) + end + for (i, (p, x)) in enumerate(zip(model_weights, xs)) + if size(p) != size(x) + throw(ArgumentError("For the $(i)th weight expected param size $(size(p)), got $(size(x))")) + end + copyto!(p, x) + end + return nothing +end + +load_weights!(m, xs::Weights) = load_weights!(m, collect(xs)) diff --git a/test/runtests.jl b/test/runtests.jl index 1032534..82ae7c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,8 @@ using Test using Flux, LegolasFlux using LegolasFlux: Weights, FlatArray, ModelRow using Arrow +using Random +using StableRNGs function make_my_model() return Chain(Dense(1, 10), Dense(10, 10), Dense(10, 1)) @@ -13,18 +15,20 @@ function test_weights() return [reshape(Float32.(1:prod(s)), s) for s in shapes] end -@testset begin +# This simple model should work with both Flux's `params/loadparams!` and +# our `weights/load_weights!`. The only difference is in layers with `!isempty(other_weights(layer))`. +@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(fetch_weights, load_weights!, params, Flux.loadparams!)] my_model = make_my_model() - Flux.loadparams!(my_model, test_weights()) + load_weights(my_model, test_weights()) - model_row = ModelRow(; weights=collect(params(my_model))) + model_row = ModelRow(; weights=collect(get_weights(my_model))) write_model_row("my_model.model.arrow", model_row) fresh_model = make_my_model() model_row = read_model_row("my_model.model.arrow") weights = collect(model_row.weights) - Flux.loadparams!(fresh_model, weights) + load_weights(fresh_model, weights) @test collect(params(fresh_model)) == weights == test_weights() @@ -33,8 +37,20 @@ end rm("my_model.model.arrow") end +@testset "Errors" begin + my_model = make_my_model() + w = test_weights() + w[end] = [] + @test_throws ArgumentError load_weights!(my_model, w) + + w = test_weights() + push!(w, []) + @test_throws ArgumentError load_weights!(my_model, w) +end + @testset "`Weights`" begin - v = [rand(Int8, 5), rand(Float32, 5, 5)] + rng = StableRNG(245) + v = [rand(rng, Int8, 5), rand(rng, Float32, 5, 5)] @test Weights(v) isa Weights{Float32} @test Weights(FlatArray{Float32}.(v)) isa Weights{Float32} @test Weights(FlatArray{Float64}.(v)) isa Weights{Float64} @@ -43,3 +59,42 @@ end tbl = [(; weights = w)] @test Arrow.Table(Arrow.tobuffer(tbl)).weights[1] == w end + +@testset "`flux_workarounds`" begin + @testset "layer $layer" for layer in [BatchNorm, InstanceNorm, (c) -> GroupNorm(c, 1), c -> identity] + mk_model = () -> (Random.seed!(1); Chain(Dense(1, 10), Dense(10, 10), layer(1), Dense(10, 1))) + model = mk_model() + trainmode!(model) + x = reshape([1f0], 1, 1, 1) + for i in 1:10 + x = model(x) + end + testmode!(model) + w = fetch_weights(model) + p = collect(params(model)) + output = model(x) + + r1 = mk_model() + load_weights!(r1, w) + testmode!(r1) + + @test output ≈ r1(x) + + if layer == BatchNorm + r2 = mk_model() + Flux.loadparams!(r2, p) + testmode!(r2) + + # If this test *fails*, meaning `output ≈ r2(x)`, + # then perhaps we should revisit `load_weights!` + # and could consider switching to `Flux.loadparams`. + # See https://github.com/beacon-biosignals/LegolasFlux.jl/pull/4 + # for more. + @test_broken output ≈ r2(x) + end + end +end + +@testset "Example" begin + include("../examples/digits.jl") +end