From fd9167e410e941d95d27765cd03c8f04bbb784d3 Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Tue, 6 Jul 2021 18:24:43 +0200
Subject: [PATCH 01/20] add example (broken)

---
 .gitignore            |   2 +-
 Project.toml          |   5 +-
 README.md             |   5 +-
 examples/Project.toml |   6 +++
 examples/digits.jl    | 104 ++++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl      |   4 ++
 6 files changed, 123 insertions(+), 3 deletions(-)
 create mode 100644 examples/Project.toml
 create mode 100644 examples/digits.jl

diff --git a/.gitignore b/.gitignore
index b067edd..ba39cc5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1 @@
-/Manifest.toml
+Manifest.toml
diff --git a/Project.toml b/Project.toml
index aeae137..be73fee 100644
--- a/Project.toml
+++ b/Project.toml
@@ -16,7 +16,10 @@ julia = "1.5"
 
 [extras]
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Flux"]
+test = ["Test", "Flux", "MLDatasets", "Statistics", "Random"]
diff --git a/README.md b/README.md
index 4b64369..7aeee1c 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,8 @@ model_row.loss # 0.5
 We can make use of the `architecture_version` column to specify a version number for the architectures,
 in order to keep track of which architectures the weights are valid for.
 
+See [examples/digits.jl](examples/digits.jl) for a larger example.
+
 ## `LegolasFlux.ModelRow`
 
 A `LegolasFlux.ModelRow` is the central object of LegolasFlux. It acts as a Tables.jl-compatible row that can store the weights
@@ -78,4 +80,5 @@ one might name files produced by this row as e.g. `training_run.digits.model.arr
 Note in this example the schema is called `digits.model` instead of just, say, `digits`, since the package
 Digits might want to create other Legolas schemas as well at some point.
 
-Check out the [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/) repo to see more about how its extensible schema system works.
+Check out the [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/) repo to see more about how its extensible schema system works,
+and the example at [examples/digits.jl](examples/digits.jl).
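The `architecture_version` column mentioned in the README hunk above is purely informational; nothing in the package enforces it. A minimal sketch of how a caller might guard on it when re-loading (this snippet is illustrative and not part of the patch; `make_my_model` is the README's example constructor, and at this point in the series loading still goes through `Flux.loadparams!`):

```julia
using Flux, LegolasFlux

row = read_model_row("my_model.model.arrow")
row.architecture_version == 1 ||
    error("weights were saved for architecture version $(row.architecture_version)")
model = make_my_model()
Flux.loadparams!(model, collect(row.weights))
```
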
diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 0000000..131386c --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,6 @@ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" +LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/examples/digits.jl b/examples/digits.jl new file mode 100644 index 0000000..2e696d5 --- /dev/null +++ b/examples/digits.jl @@ -0,0 +1,104 @@ +# modified from +# https://discourse.julialang.org/t/how-to-drop-the-dropout-layers-in-flux-jl-when-assessing-model-performance/19924 + +using Flux, Statistics, Random, Test +using MLDatasets: MNIST +using Flux: onehotbatch, onecold, crossentropy, throttle +using Base.Iterators: repeated, partition +using Legolas, LegolasFlux + +Base.@kwdef struct DigitsConfig + seed::Int = 5 + dropout_rate::Float32 = 0f1 +end + +struct DigitsModel + chain::Chain + config::DigitsConfig +end + +Flux.@functor DigitsModel (chain,) + +function DigitsModel(config::DigitsConfig = DigitsConfig()) + dropout_rate = config.dropout_rate + Random.seed!(config.seed) + chain = Chain( + Dropout(dropout_rate), + Conv((3, 3), 1=>32, relu), + BatchNorm(32, relu), + x -> maxpool(x, (2,2)), + Dropout(dropout_rate), + Conv((3, 3), 32=>16, relu), + Dropout(dropout_rate), + x -> maxpool(x, (2,2)), + Dropout(dropout_rate), + Conv((3, 3), 16=>10, relu), + Dropout(dropout_rate), + x -> reshape(x, :, size(x, 4)), + Dropout(dropout_rate), + Dense(90, 10), softmax) + return DigitsModel(chain, config) +end + +(m::DigitsModel)(x) = m.chain(x) + +const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1", + config::DigitsConfig, + epoch::Union{Missing, Int}, + accuracy::Union{Missing, Float32}) + +function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing) + weights = collect(params(model)) + return DigitsRow(; weights, model.config, epoch, accuracy) +end + +function DigitsModel(row) + m = DigitsModel(row.config) + Flux.loadparams!(m, collect(row.weights)) + return m +end + +N_train = 10_000 +N_test = 500 + +train_x, train_y = MNIST.traindata(Float32, 1:N_train) +test_x, test_y = MNIST.testdata(Float32, 1:N_test) + +# Partition into batches of size 32 +batch_size = 32 +train = [(reshape(train_x[:, :, I], 28, 28, 1, :), onehotbatch(train_y[I], 0:9)) + for I in partition(1:N_train, batch_size)] + +tX = reshape(test_x, 28, 28, 1, :) +tY = onehotbatch(test_y, 0:9) + +function accuracy(m, x, y) + testmode!(m) + val = mean(onecold(m(x)) .== onecold(y)) + trainmode!(m) + return val +end + +function train_model!(m) + loss = (x, y) -> crossentropy(m(x), y) + opt = ADAM() + evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5) + Flux.@epochs 1 Flux.train!(loss, params(m), train, opt, cb = evalcb) + return accuracy(m, tX, tY) +end + +m = DigitsModel() +acc = train_model!(m) + +row = DigitsRow(m; epoch=1, accuracy=acc) + +testmode!(m) +input = tX[:, :, :, 1:1] +output = m(input) +label = tY[:, 1] + +m2 = DigitsModel(row) +testmode!(m2) +output2 = m2(input) + +@test_broken output ≈ output2 diff --git a/test/runtests.jl b/test/runtests.jl index 1032534..3eeb4de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,3 +43,7 @@ end tbl = [(; weights = w)] @test Arrow.Table(Arrow.tobuffer(tbl)).weights[1] == w end + +@testset "Example" begin + include("../examples/digits.jl") +end From c80a2fac9ed4518250a012099f8eb3b42ca5bc30 Mon Sep 17 00:00:00 2001 
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Tue, 6 Jul 2021 20:58:34 +0200 Subject: [PATCH 02/20] add Flux workarounds --- Project.toml | 8 +++--- README.md | 7 ++--- examples/Project.toml | 1 + examples/digits.jl | 48 ++++++++++++++++++++++------------ src/LegolasFlux.jl | 1 + src/flux_workarounds.jl | 57 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 8 +++--- 7 files changed, 105 insertions(+), 25 deletions(-) create mode 100644 src/flux_workarounds.jl diff --git a/Project.toml b/Project.toml index be73fee..bb34f93 100644 --- a/Project.toml +++ b/Project.toml @@ -1,25 +1,27 @@ name = "LegolasFlux" uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" authors = ["Beacon Biosignals, Inc."] -version = "0.1.0" +version = "0.1.1" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Arrow = "1" +Flux = "0.12" Legolas = "0.1, 0.2" Tables = "1" julia = "1.5" [extras] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Flux", "MLDatasets", "Statistics", "Random"] +test = ["Test", "Flux", "StableRNGs", "Statistics", "Random"] diff --git a/README.md b/README.md index 7aeee1c..de1b2de 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,8 @@ [![codecov](https://codecov.io/gh/beacon-biosignals/LegolasFlux.jl/branch/main/graph/badge.svg?token=NHYUL22HCC)](https://codecov.io/gh/beacon-biosignals/LegolasFlux.jl) LegolasFlux provides some simple functionality to use [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/)'s -extensible Arrow schemas as means to serialize Flux models using Flux's `params` and `loadparams!`. +extensible Arrow schemas as means to serialize Flux models similarly to using Flux's `params` and `loadparams!` +(instead, we export similar functions `weights` and `loadweights!` which handle layers like `BatchNorm` correctly for this purpose). The aim is to serialize only the numeric weights, *not* the code defining the model. This is a very different approach from e.g. BSON.jl, and hopefully much more robust. @@ -28,14 +29,14 @@ my_model = make_my_model() using LegolasFlux # We can save whatever other columns we'd like to as well as the `weights`. -model_row = ModelRow(; weights = collect(params(cpu(my_model))), architecture_version = 1, loss = 0.5) +model_row = ModelRow(; weights = collect(weights(cpu(my_model))), architecture_version = 1, loss = 0.5) write_model_row("my_model.model.arrow", model_row) # Great! Later on, we want to re-load our model weights. fresh_model = make_my_model() model_row = read_model_row("my_model.model.arrow") -Flux.loadparams!(fresh_model, collect(model_row.weights)) +loadweights!(fresh_model, collect(model_row.weights)) # Now our params have been loaded back into `fresh_model`. # Note we needed to `collect` the weights before we use them. 
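The README change above is the heart of this patch: `Flux.params`/`loadparams!` only round-trip *trainable* arrays, so a `BatchNorm`'s running statistics are silently dropped. A small illustration of the gap, assuming Flux 0.12 (this snippet is not part of the diff):

```julia
using Flux

bn = BatchNorm(2)
length(Flux.params(bn))  # 2: only the trainable β (bias) and γ (scale)
# The running statistics live in fields that `params` never visits;
# these are exactly what `other_weights` below adds back:
bn.μ, bn.σ²
```
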
diff --git a/examples/Project.toml b/examples/Project.toml index 131386c..142443c 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -3,4 +3,5 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/examples/digits.jl b/examples/digits.jl index 2e696d5..2d0ad8b 100644 --- a/examples/digits.jl +++ b/examples/digits.jl @@ -1,8 +1,10 @@ -# modified from +# Model modified from # https://discourse.julialang.org/t/how-to-drop-the-dropout-layers-in-flux-jl-when-assessing-model-performance/19924 using Flux, Statistics, Random, Test -using MLDatasets: MNIST +# Uncomment to use MNIST data +# using MLDatasets: MNIST +using StableRNGs using Flux: onehotbatch, onecold, crossentropy, throttle using Base.Iterators: repeated, partition using Legolas, LegolasFlux @@ -25,12 +27,12 @@ function DigitsModel(config::DigitsConfig = DigitsConfig()) chain = Chain( Dropout(dropout_rate), Conv((3, 3), 1=>32, relu), - BatchNorm(32, relu), - x -> maxpool(x, (2,2)), + # BatchNorm(32, relu), + MaxPool((2,2)), Dropout(dropout_rate), Conv((3, 3), 32=>16, relu), Dropout(dropout_rate), - x -> maxpool(x, (2,2)), + MaxPool((2,2)), Dropout(dropout_rate), Conv((3, 3), 16=>10, relu), Dropout(dropout_rate), @@ -48,21 +50,33 @@ const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1", accuracy::Union{Missing, Float32}) function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing) - weights = collect(params(model)) - return DigitsRow(; weights, model.config, epoch, accuracy) + w = collect(weights(model)) + return DigitsRow(; weights=w, model.config, epoch, accuracy) end function DigitsModel(row) m = DigitsModel(row.config) - Flux.loadparams!(m, collect(row.weights)) + loadweights!(m, collect(row.weights)) return m end -N_train = 10_000 -N_test = 500 -train_x, train_y = MNIST.traindata(Float32, 1:N_train) -test_x, test_y = MNIST.testdata(Float32, 1:N_test) +# Increase to get more training/test data +N_train = 1_000 +N_test = 50 + +## +# to use MNIST data, uncomment these +# train_x, train_y = MNIST.traindata(Float32, 1:N_train) +# test_x, test_y = MNIST.testdata(Float32, 1:N_test) + +# Random data: +rng = StableRNG(735) +train_x = rand(rng, Float32, 28, 28, N_train) +train_y = rand(rng, 0:9, N_train) +test_x = rand(rng, Float32, 28, 28, N_test) +test_y = rand(rng, 0:9, N_test) +## # Partition into batches of size 32 batch_size = 32 @@ -79,16 +93,18 @@ function accuracy(m, x, y) return val end -function train_model!(m) +function train_model!(m; N = N_train) loss = (x, y) -> crossentropy(m(x), y) opt = ADAM() evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5) - Flux.@epochs 1 Flux.train!(loss, params(m), train, opt, cb = evalcb) + Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt, cb = evalcb) return accuracy(m, tX, tY) end m = DigitsModel() -acc = train_model!(m) + +# increase N to actually train more than a tiny amount +acc = train_model!(m; N = 10) row = DigitsRow(m; epoch=1, accuracy=acc) @@ -101,4 +117,4 @@ m2 = DigitsModel(row) testmode!(m2) output2 = m2(input) -@test_broken output ≈ output2 +@test output ≈ output2 diff --git a/src/LegolasFlux.jl b/src/LegolasFlux.jl index d2152eb..483203c 100644 --- a/src/LegolasFlux.jl +++ b/src/LegolasFlux.jl @@ -110,5 +110,6 @@ function read_model_row(io_or_path) 
return only(rows) end +include("flux_workarounds.jl") end # module diff --git a/src/flux_workarounds.jl b/src/flux_workarounds.jl new file mode 100644 index 0000000..30f4c71 --- /dev/null +++ b/src/flux_workarounds.jl @@ -0,0 +1,57 @@ +using Flux: BatchNorm, InstanceNorm, GroupNorm, Params, trainable +using Base: IdSet +export weights, loadweights! + +""" + LegolasFlux.other_weights(layer) -> Vararg{Array} + +Given a layer with params that are not captured by `Flux.trainable`, produce +a tuple of arrays corresponding to these parameters (analogous to `Flux.trainable`). +""" +function other_weights end + +other_weights(layer) = () +other_weights(layer::BatchNorm) = (layer.μ, layer.σ²) +other_weights(layer::InstanceNorm) = (layer.μ, layer.σ²) +other_weights(layer::GroupNorm) = (layer.μ, layer.σ²) + +##### +##### `weights` +##### + +# The following is a copy of +# with `params` changed to `weights` and the addition of the lines +# ```julia +# for child in other_weights(x) +# weights!(p, child, seen) +# end +# ``` +# to `weights!(p::Params, x, seen = IdSet())`. + +weights!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) + +function weights!(p::Params, x, seen = IdSet()) + x in seen && return + push!(seen, x) + for child in trainable(x) + weights!(p, child, seen) + end + + for child in other_weights(x) + weights!(p, child, seen) + end +end + +function weights(m...) + ps = Params() + weights!(ps, m) + return ps +end + +function loadweights!(m, xs) + for (p, x) in zip(weights(m), xs) + size(p) == size(x) || + error("Expected param size $(size(p)), got $(size(x))") + copyto!(p, x) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3eeb4de..ca3a500 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,18 +13,20 @@ function test_weights() return [reshape(Float32.(1:prod(s)), s) for s in shapes] end -@testset begin +# This simple model should work with both Flux's `params/loadparams!` and +# our `weights/loadweights!`. The only difference is in layers with `!isempty(other_weights(layer))`. 
+@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(weights, loadweights!), (params, Flux.loadparams!)]
     my_model = make_my_model()
     Flux.loadparams!(my_model, test_weights())
-    model_row = ModelRow(; weights=collect(params(my_model)))
+    model_row = ModelRow(; weights=collect(get_weights(my_model)))
     write_model_row("my_model.model.arrow", model_row)
 
     fresh_model = make_my_model()
     model_row = read_model_row("my_model.model.arrow")
     weights = collect(model_row.weights)
-    Flux.loadparams!(fresh_model, weights)
+    load_weights(fresh_model, weights)
 
     @test collect(params(fresh_model)) == weights == test_weights()

From 5683081a723aa5c34836905ed40dae7284be65a5 Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Tue, 6 Jul 2021 21:21:58 +0200
Subject: [PATCH 03/20] add tests

---
 test/runtests.jl | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/test/runtests.jl b/test/runtests.jl
index ca3a500..4e0fdf1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -46,6 +46,36 @@ end
     @test Arrow.Table(Arrow.tobuffer(tbl)).weights[1] == w
 end
 
+@testset "`flux_workarounds`" begin
+    @testset "layer $layer" for layer in [BatchNorm, InstanceNorm, (c) -> GroupNorm(c, 1), c -> identity]
+        mk_model = () -> (Random.seed!(1); Chain(Dense(1, 10), Dense(10, 10), layer(1), Dense(10, 1)))
+        model = mk_model()
+        trainmode!(model)
+        x = reshape([1f0], 1, 1, 1)
+        for i = 1:10
+            x = model(x)
+        end
+        testmode!(model)
+        w = collect(weights(model))
+        p = collect(params(model))
+        output = model(x)
+
+        r1 = mk_model()
+        loadweights!(r1, w)
+        testmode!(r1)
+
+        @test output ≈ r1(x)
+
+        if layer == BatchNorm
+            r2 = mk_model()
+            Flux.loadparams!(r2, p)
+            testmode!(r2)
+
+            @test_broken output ≈ r2(x)
+        end
+    end
+end
+
 @testset "Example" begin
     include("../examples/digits.jl")
 end

From 2e43927d1617ae9340320390b521af9eacfe3e05 Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Tue, 6 Jul 2021 21:52:53 +0200
Subject: [PATCH 04/20] oops! actually use BatchNorm in example

---
 examples/digits.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/digits.jl b/examples/digits.jl
index 2d0ad8b..84c6661 100644
--- a/examples/digits.jl
+++ b/examples/digits.jl
@@ -27,7 +27,7 @@ function DigitsModel(config::DigitsConfig = DigitsConfig())
     chain = Chain(
         Dropout(dropout_rate),
         Conv((3, 3), 1=>32, relu),
-        # BatchNorm(32, relu),
+        BatchNorm(32, relu),
         MaxPool((2,2)),

From 2d5513f854b277c957084a8227d3097f0574fd22 Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Tue, 6 Jul 2021 21:58:49 +0200
Subject: [PATCH 05/20] add comments to make example more helpful

---
 examples/digits.jl | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/examples/digits.jl b/examples/digits.jl
index 84c6661..90385bc 100644
--- a/examples/digits.jl
+++ b/examples/digits.jl
@@ -9,18 +9,27 @@ using Flux: onehotbatch, onecold, crossentropy, throttle
 using Base.Iterators: repeated, partition
 using Legolas, LegolasFlux
 
+# This should store all the information needed
+# to construct the model.
 Base.@kwdef struct DigitsConfig
     seed::Int = 5
     dropout_rate::Float32 = 0f1
 end
 
+# Here's our model object itself, just a `DigitsConfig` and
+# a `chain`. We keep the config around so it's easy to save out
+# later.
struct DigitsModel chain::Chain config::DigitsConfig end +# Ensure Flux can recurse into our model to find params etc Flux.@functor DigitsModel (chain,) +# Construct the actual model from a config object. This is the only +# constructor that should be used, to ensure the model is created just +# from the config object alone. function DigitsModel(config::DigitsConfig = DigitsConfig()) dropout_rate = config.dropout_rate Random.seed!(config.seed) @@ -42,18 +51,26 @@ function DigitsModel(config::DigitsConfig = DigitsConfig()) return DigitsModel(chain, config) end +# Our model acts on input just by applying the chain. (m::DigitsModel)(x) = m.chain(x) +# Here, we define a schema extension of the `legolas-flux.model` schema. +# We add our `DigitsConfig` object, as well as the epoch and accuracy. const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1", config::DigitsConfig, epoch::Union{Missing, Int}, accuracy::Union{Missing, Float32}) +# Construct a `DigitsRow` from a model by collecting the `weights`. +# This can then be saved with e.g. `LegolasFlux.write_model_row`. function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing) w = collect(weights(model)) return DigitsRow(; weights=w, model.config, epoch, accuracy) end +# Construct a `DigitsModel` from a row satisfying the `DigitsRow` schema, +# i.e. one with a `weights` and `config::DigitsConfig`. +# This could be the result of `LegolasFlux.read_model_row`. function DigitsModel(row) m = DigitsModel(row.config) loadweights!(m, collect(row.weights)) @@ -106,6 +123,8 @@ m = DigitsModel() # increase N to actually train more than a tiny amount acc = train_model!(m; N = 10) +# Let's serialize out the weights into a `DigitsRow`. +# We could save this here with `write_model_row`. row = DigitsRow(m; epoch=1, accuracy=acc) testmode!(m) @@ -113,6 +132,8 @@ input = tX[:, :, :, 1:1] output = m(input) label = tY[:, 1] +# Let's now reconstruct the model from the `row` and check that we get +# the same outputs. m2 = DigitsModel(row) testmode!(m2) output2 = m2(input) From bab45de162a2275dec7d162bcb64dedcb5839686 Mon Sep 17 00:00:00 2001 From: Hannah Robertson Date: Tue, 6 Jul 2021 16:30:51 -0400 Subject: [PATCH 06/20] Add Random dependency to test --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 4e0fdf1..310ad35 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test using Flux, LegolasFlux using LegolasFlux: Weights, FlatArray, ModelRow using Arrow +using Random function make_my_model() return Chain(Dense(1, 10), Dense(10, 10), Dense(10, 1)) From 11d6fecb59e5f2e2eaa210864abe74ca072457a3 Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 13:07:25 +0200 Subject: [PATCH 07/20] add comment to address review --- test/runtests.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 4e0fdf1..c1d54b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,6 +71,10 @@ end Flux.loadparams!(r2, p) testmode!(r2) + # If this test *fails*, meaning `output ≈ r2(x)`, + # then perhaps Flux#1027 has been fixed and we can + # remove `flux_workarounds.jl`. + # See https://github.com/beacon-biosignals/LegolasFlux.jl/pull/4#issuecomment-875010030. 
@test_broken output ≈ r2(x) end end From 3c1677fc715978d84f92730532da058a5b4d9821 Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 13:39:19 +0200 Subject: [PATCH 08/20] switch to `fcollect`-based implementation --- Project.toml | 2 +- src/LegolasFlux.jl | 3 ++ src/flux_workarounds.jl | 62 +++++++++-------------------------------- 3 files changed, 17 insertions(+), 50 deletions(-) diff --git a/Project.toml b/Project.toml index bb34f93..58e2204 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,7 @@ version = "0.1.1" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" diff --git a/src/LegolasFlux.jl b/src/LegolasFlux.jl index 483203c..a101528 100644 --- a/src/LegolasFlux.jl +++ b/src/LegolasFlux.jl @@ -1,11 +1,14 @@ module LegolasFlux export write_model_row, read_model_row +export weights, loadweights! using Legolas using Arrow using Arrow.ArrowTypes using Tables +using Functors +using Base: IdSet const LEGOLAS_SCHEMA = Legolas.Schema("legolas-flux.model@1") diff --git a/src/flux_workarounds.jl b/src/flux_workarounds.jl index 30f4c71..e90a280 100644 --- a/src/flux_workarounds.jl +++ b/src/flux_workarounds.jl @@ -1,57 +1,21 @@ -using Flux: BatchNorm, InstanceNorm, GroupNorm, Params, trainable -using Base: IdSet -export weights, loadweights! - -""" - LegolasFlux.other_weights(layer) -> Vararg{Array} - -Given a layer with params that are not captured by `Flux.trainable`, produce -a tuple of arrays corresponding to these parameters (analogous to `Flux.trainable`). -""" -function other_weights end - -other_weights(layer) = () -other_weights(layer::BatchNorm) = (layer.μ, layer.σ²) -other_weights(layer::InstanceNorm) = (layer.μ, layer.σ²) -other_weights(layer::GroupNorm) = (layer.μ, layer.σ²) - -##### -##### `weights` -##### - -# The following is a copy of -# with `params` changed to `weights` and the addition of the lines -# ```julia -# for child in other_weights(x) -# weights!(p, child, seen) -# end -# ``` -# to `weights!(p::Params, x, seen = IdSet())`. - -weights!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) - -function weights!(p::Params, x, seen = IdSet()) - x in seen && return - push!(seen, x) - for child in trainable(x) - weights!(p, child, seen) - end - - for child in other_weights(x) - weights!(p, child, seen) - end +# Modified version of `fcollect` to use an `IdSet` cache so that +# distinct arrays whose values happen to be duplicates are each kept. +function fcollect2(x; output = [], cache = Base.IdSet(), exclude = v -> false) + x in cache && return output + if !exclude(x) + push!(cache, x) + push!(output, x) + foreach(y -> fcollect2(y; cache = cache, output=output, exclude = exclude), Functors.children(x)) + end + return output end -function weights(m...) 
- ps = Params() - weights!(ps, m) - return ps -end +weights(m) = filter(x -> x isa Array, fcollect2(m)) function loadweights!(m, xs) - for (p, x) in zip(weights(m), xs) + for (i, (p, x)) in enumerate(zip(weights(m), xs)) size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x))") + error("Expected param size $(size(p)), got $(size(x)) for the $(i)th weight") copyto!(p, x) end end From 2ca8704617dfcdcb65d3c8fff80ed2137101b382 Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 13:43:19 +0200 Subject: [PATCH 09/20] rename file, clean up a little --- src/LegolasFlux.jl | 2 +- src/flux_workarounds.jl | 21 --------------------- src/functors.jl | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 22 deletions(-) delete mode 100644 src/flux_workarounds.jl create mode 100644 src/functors.jl diff --git a/src/LegolasFlux.jl b/src/LegolasFlux.jl index a101528..336fe5f 100644 --- a/src/LegolasFlux.jl +++ b/src/LegolasFlux.jl @@ -113,6 +113,6 @@ function read_model_row(io_or_path) return only(rows) end -include("flux_workarounds.jl") +include("functors.jl") end # module diff --git a/src/flux_workarounds.jl b/src/flux_workarounds.jl deleted file mode 100644 index e90a280..0000000 --- a/src/flux_workarounds.jl +++ /dev/null @@ -1,21 +0,0 @@ -# Modified version of `fcollect` to use an `IdSet` cache so that -# distinct arrays whose values happen to be duplicates are each kept. -function fcollect2(x; output = [], cache = Base.IdSet(), exclude = v -> false) - x in cache && return output - if !exclude(x) - push!(cache, x) - push!(output, x) - foreach(y -> fcollect2(y; cache = cache, output=output, exclude = exclude), Functors.children(x)) - end - return output -end - -weights(m) = filter(x -> x isa Array, fcollect2(m)) - -function loadweights!(m, xs) - for (i, (p, x)) in enumerate(zip(weights(m), xs)) - size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x)) for the $(i)th weight") - copyto!(p, x) - end -end diff --git a/src/functors.jl b/src/functors.jl new file mode 100644 index 0000000..316f09f --- /dev/null +++ b/src/functors.jl @@ -0,0 +1,36 @@ +# Modified version of `fcollect` to use an `IdSet` cache so that +# distinct arrays whose values happen to be duplicates are each kept. +function fcollect2(x; output = [], cache = Base.IdSet(), exclude = v -> false) + x in cache && return output + if !exclude(x) + push!(cache, x) + push!(output, x) + foreach(y -> fcollect2(y; cache = cache, output=output, exclude = exclude), Functors.children(x)) + end + return output +end + +""" + weights(m) -> Vector{Array} + +Returns the weights of a model by using `Functors.children` to recurse +through the model, keeping any arrays found. The `@functor` macro defines +`Functors.children` automatically so that should be sufficient to support +custom types. +""" +weights(m) = filter(x -> x isa Array, fcollect2(m)) + +""" + loadweights!(m, xs) + +Load weights `xs` into the model `m`, using [`weights`](@ref). 
+"""
+function loadweights!(m, xs)
+    for (i, (p, x)) in enumerate(zip(weights(m), xs))
+        if size(p) != size(x)
+            error("For the $(i)th weight expected param size $(size(p)), got $(size(x))")
+        end
+        copyto!(p, x)
+    end
+    return nothing
+end

From a8547e637af15304a8abfeb2845a88bc34b66cfc Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Wed, 7 Jul 2021 13:44:20 +0200
Subject: [PATCH 10/20] update comment

---
 test/runtests.jl | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/test/runtests.jl b/test/runtests.jl
index eb6d78e..acdb913 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -73,9 +73,10 @@ end
             testmode!(r2)
 
             # If this test *fails*, meaning `output ≈ r2(x)`,
-            # then perhaps Flux#1027 has been fixed and we can
-            # remove `flux_workarounds.jl`.
-            # See https://github.com/beacon-biosignals/LegolasFlux.jl/pull/4#issuecomment-875010030.
+            # then perhaps we should revisit `loadweights!`
+            # and could consider switching to `Flux.loadparams!`.
+            # See https://github.com/beacon-biosignals/LegolasFlux.jl/pull/4
+            # for more.
             @test_broken output ≈ r2(x)
         end
     end

From 368c78ff92e5e1c069287ae3644523762e6bfaf9 Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Wed, 7 Jul 2021 14:34:14 +0200
Subject: [PATCH 11/20] simplify API: `collect` not needed

---
 README.md          | 8 ++++----
 examples/digits.jl | 5 ++---
 src/functors.jl    | 2 ++
 test/runtests.jl   | 2 +-
 4 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index de1b2de..1f6eeaa 100644
--- a/README.md
+++ b/README.md
@@ -29,16 +29,16 @@ my_model = make_my_model()
 using LegolasFlux
 
 # We can save whatever other columns we'd like to as well as the `weights`.
-model_row = ModelRow(; weights = collect(weights(cpu(my_model))), architecture_version = 1, loss = 0.5)
+model_row = ModelRow(; weights = weights(cpu(my_model)),
+                     architecture_version = 1, loss = 0.5)
 write_model_row("my_model.model.arrow", model_row)
 
 # Great! Later on, we want to re-load our model weights.
 fresh_model = make_my_model()
 model_row = read_model_row("my_model.model.arrow")
-loadweights!(fresh_model, collect(model_row.weights))
-# Now our params have been loaded back into `fresh_model`.
-# Note we needed to `collect` the weights before we use them.
+loadweights!(fresh_model, model_row.weights)
+# Now our weights have been loaded back into `fresh_model`.
 
 # We can also check out our other columns:
 model_row.loss # 0.5
diff --git a/examples/digits.jl b/examples/digits.jl
index 90385bc..bcac0cc 100644
--- a/examples/digits.jl
+++ b/examples/digits.jl
@@ -64,8 +64,7 @@ const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1",
 # Construct a `DigitsRow` from a model by collecting the `weights`.
 # This can then be saved with e.g. `LegolasFlux.write_model_row`.
 function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing)
-    w = collect(weights(model))
-    return DigitsRow(; weights=w, model.config, epoch, accuracy)
+    return DigitsRow(; weights=weights(model), model.config, epoch, accuracy)
 end
 
 # Construct a `DigitsModel` from a row satisfying the `DigitsRow` schema,
 # i.e. one with a `weights` and `config::DigitsConfig`.
 # This could be the result of `LegolasFlux.read_model_row`.
function DigitsModel(row) m = DigitsModel(row.config) - loadweights!(m, collect(row.weights)) + loadweights!(m, row.weights) return m end diff --git a/src/functors.jl b/src/functors.jl index 316f09f..30978a3 100644 --- a/src/functors.jl +++ b/src/functors.jl @@ -34,3 +34,5 @@ function loadweights!(m, xs) end return nothing end + +loadweights!(m, xs::Weights) = loadweights!(m, collect(xs)) diff --git a/test/runtests.jl b/test/runtests.jl index acdb913..09cdadc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,7 +57,7 @@ end x = model(x) end testmode!(model) - w = collect(weights(model)) + w = weights(model) p = collect(params(model)) output = model(x) From 543fd74b8de9001b468244523606fb0684094808 Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 14:59:23 +0200 Subject: [PATCH 12/20] update compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 58e2204..2358ceb 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Arrow = "1" -Flux = "0.12" +Functors = "0.2.1" Legolas = "0.1, 0.2" Tables = "1" julia = "1.5" From 58ecd47c4e30f5654abcb4346b44b5f612cba11f Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 16:19:37 +0200 Subject: [PATCH 13/20] add Flux compat --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 2358ceb..4cb6429 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Arrow = "1" +Flux = "0.12" Functors = "0.2.1" Legolas = "0.1, 0.2" Tables = "1" From 1095d6116d22d29cee31dd674299c1a2de252e81 Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 19:14:25 +0200 Subject: [PATCH 14/20] add link to issue --- src/functors.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/functors.jl b/src/functors.jl index 30978a3..7430e35 100644 --- a/src/functors.jl +++ b/src/functors.jl @@ -1,5 +1,6 @@ # Modified version of `fcollect` to use an `IdSet` cache so that # distinct arrays whose values happen to be duplicates are each kept. +# function fcollect2(x; output = [], cache = Base.IdSet(), exclude = v -> false) x in cache && return output if !exclude(x) From e1f09d2aa1067ef5dfef068e1bceab5737f2a513 Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 19:29:54 +0200 Subject: [PATCH 15/20] Apply suggestions from code review Co-authored-by: Alex Arslan --- README.md | 2 +- examples/digits.jl | 36 ++++++++++++++++++------------------ src/functors.jl | 8 ++++---- test/runtests.jl | 2 +- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 1f6eeaa..f2076c2 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ using LegolasFlux # We can save whatever other columns we'd like to as well as the `weights`. model_row = ModelRow(; weights = weights(cpu(my_model)), - architecture_version = 1, loss = 0.5) + architecture_version=1, loss=0.5) write_model_row("my_model.model.arrow", model_row) # Great! Later on, we want to re-load our model weights. 
diff --git a/examples/digits.jl b/examples/digits.jl index bcac0cc..f0c401d 100644 --- a/examples/digits.jl +++ b/examples/digits.jl @@ -30,24 +30,24 @@ Flux.@functor DigitsModel (chain,) # Construct the actual model from a config object. This is the only # constructor that should be used, to ensure the model is created just # from the config object alone. -function DigitsModel(config::DigitsConfig = DigitsConfig()) +function DigitsModel(config::DigitsConfig=DigitsConfig()) dropout_rate = config.dropout_rate Random.seed!(config.seed) - chain = Chain( - Dropout(dropout_rate), - Conv((3, 3), 1=>32, relu), - BatchNorm(32, relu), - MaxPool((2,2)), - Dropout(dropout_rate), - Conv((3, 3), 32=>16, relu), - Dropout(dropout_rate), - MaxPool((2,2)), - Dropout(dropout_rate), - Conv((3, 3), 16=>10, relu), - Dropout(dropout_rate), - x -> reshape(x, :, size(x, 4)), - Dropout(dropout_rate), - Dense(90, 10), softmax) + chain = Chain(Dropout(dropout_rate), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((2, 2)), + Dropout(dropout_rate), + Conv((3, 3), 32 => 16, relu), + Dropout(dropout_rate), + MaxPool((2, 2)), + Dropout(dropout_rate), + Conv((3, 3), 16 => 10, relu), + Dropout(dropout_rate), + x -> reshape(x, :, size(x, 4)), + Dropout(dropout_rate), + Dense(90, 10), + softmax) return DigitsModel(chain, config) end @@ -113,14 +113,14 @@ function train_model!(m; N = N_train) loss = (x, y) -> crossentropy(m(x), y) opt = ADAM() evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5) - Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt, cb = evalcb) + Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt; cb=evalcb) return accuracy(m, tX, tY) end m = DigitsModel() # increase N to actually train more than a tiny amount -acc = train_model!(m; N = 10) +acc = train_model!(m; N=10) # Let's serialize out the weights into a `DigitsRow`. # We could save this here with `write_model_row`. diff --git a/src/functors.jl b/src/functors.jl index 7430e35..ce379cb 100644 --- a/src/functors.jl +++ b/src/functors.jl @@ -1,12 +1,12 @@ # Modified version of `fcollect` to use an `IdSet` cache so that # distinct arrays whose values happen to be duplicates are each kept. 
# -function fcollect2(x; output = [], cache = Base.IdSet(), exclude = v -> false) +function fcollect2(x; output=[], cache=IdSet(), exclude=_ -> false) x in cache && return output if !exclude(x) - push!(cache, x) - push!(output, x) - foreach(y -> fcollect2(y; cache = cache, output=output, exclude = exclude), Functors.children(x)) + push!(cache, x) + push!(output, x) + foreach(y -> fcollect2(y; cache=cache, output=output, exclude=exclude), Functors.children(x)) end return output end diff --git a/test/runtests.jl b/test/runtests.jl index 09cdadc..308e337 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,7 +53,7 @@ end model = mk_model() trainmode!(model) x = reshape([1f0], 1, 1, 1) - for i = 1:10 + for i in 1:10 x = model(x) end testmode!(model) From e4ae7751170d7a8e7c12bf58494862cdcac278bb Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 19:43:47 +0200 Subject: [PATCH 16/20] better errors --- src/functors.jl | 8 ++++++-- test/runtests.jl | 13 ++++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/functors.jl b/src/functors.jl index ce379cb..dfbacc1 100644 --- a/src/functors.jl +++ b/src/functors.jl @@ -27,9 +27,13 @@ weights(m) = filter(x -> x isa Array, fcollect2(m)) Load weights `xs` into the model `m`, using [`weights`](@ref). """ function loadweights!(m, xs) - for (i, (p, x)) in enumerate(zip(weights(m), xs)) + model_weights = weights(m) + if length(model_weights) != length(xs) + throw(ArgumentError("Number of weights given ($(length(xs))) does not match number of weights model expects ($(length(model_weights)))")) + end + for (i, (p, x)) in enumerate(zip(model_weights, xs)) if size(p) != size(x) - error("For the $(i)th weight expected param size $(size(p)), got $(size(x))") + throw(ArgumentError("For the $(i)th weight expected param size $(size(p)), got $(size(x))")) end copyto!(p, x) end diff --git a/test/runtests.jl b/test/runtests.jl index 308e337..6382136 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,7 +18,7 @@ end # our `weights/loadweights!`. The only difference is in layers with `!isempty(other_weights(layer))`. 
 @testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(weights, loadweights!), (params, Flux.loadparams!)]
     my_model = make_my_model()
-    Flux.loadparams!(my_model, test_weights())
+    load_weights(my_model, test_weights())
     model_row = ModelRow(; weights=collect(get_weights(my_model)))
     write_model_row("my_model.model.arrow", model_row)
@@ -36,6 +36,17 @@ end
     rm("my_model.model.arrow")
 end
 
+@testset "Errors" begin
+    my_model = make_my_model()
+    w = test_weights()
+    w[end] = []
+    @test_throws ArgumentError loadweights!(my_model, w)
+
+    w = test_weights()
+    push!(w, [])
+    @test_throws ArgumentError loadweights!(my_model, w)
+end
+
 @testset "`Weights`" begin

From 402b90062e95f4faf387e7f8ce9d4580773fcf67 Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Wed, 7 Jul 2021 19:44:16 +0200
Subject: [PATCH 17/20] rename `loadweights!` to `load_weights!`

---
 README.md          |  4 ++--
 examples/digits.jl |  2 +-
 src/LegolasFlux.jl |  2 +-
 src/functors.jl    |  6 +++---
 test/runtests.jl   | 12 ++++++------
 5 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index f2076c2..c028701 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 LegolasFlux provides some simple functionality to use [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/)'s
 extensible Arrow schemas as means to serialize Flux models similarly to using Flux's `params` and `loadparams!`
-(instead, we export similar functions `weights` and `loadweights!` which handle layers like `BatchNorm` correctly for this purpose).
+(instead, we export similar functions `weights` and `load_weights!` which handle layers like `BatchNorm` correctly for this purpose).
 The aim is to serialize only the numeric weights, *not* the code defining the model. This is a very different approach
 from e.g. BSON.jl, and hopefully much more robust.
@@ -37,7 +37,7 @@ write_model_row("my_model.model.arrow", model_row)
 
 fresh_model = make_my_model()
 model_row = read_model_row("my_model.model.arrow")
-loadweights!(fresh_model, model_row.weights)
+load_weights!(fresh_model, model_row.weights)
 # Now our weights have been loaded back into `fresh_model`.
 
 # We can also check out our other columns:
diff --git a/examples/digits.jl b/examples/digits.jl
index f0c401d..09a2204 100644
--- a/examples/digits.jl
+++ b/examples/digits.jl
@@ -72,7 +72,7 @@ end
 # This could be the result of `LegolasFlux.read_model_row`.
 function DigitsModel(row)
     m = DigitsModel(row.config)
-    loadweights!(m, row.weights)
+    load_weights!(m, row.weights)
     return m
 end
diff --git a/src/LegolasFlux.jl b/src/LegolasFlux.jl
index 336fe5f..ed278d4 100644
--- a/src/LegolasFlux.jl
+++ b/src/LegolasFlux.jl
@@ -1,7 +1,7 @@
 module LegolasFlux
 
 export write_model_row, read_model_row
-export weights, loadweights!
+export weights, load_weights!
 
 using Legolas
 using Arrow
diff --git a/src/functors.jl b/src/functors.jl
index dfbacc1..8d96ce6 100644
--- a/src/functors.jl
+++ b/src/functors.jl
@@ -22,11 +22,11 @@ custom types.
 weights(m) = filter(x -> x isa Array, fcollect2(m))
 
 """
-    loadweights!(m, xs)
+    load_weights!(m, xs)
 
 Load weights `xs` into the model `m`, using [`weights`](@ref).
 """
-function loadweights!(m, xs)
+function load_weights!(m, xs)
     model_weights = weights(m)
     if length(model_weights) != length(xs)
         throw(ArgumentError("Number of weights given ($(length(xs))) does not match number of weights model expects ($(length(model_weights)))"))
     end
@@ -40,4 +40,4 @@ function load_weights!(m, xs)
     return nothing
 end
 
-loadweights!(m, xs::Weights) = loadweights!(m, collect(xs))
+load_weights!(m, xs::Weights) = load_weights!(m, collect(xs))
diff --git a/test/runtests.jl b/test/runtests.jl
index 6382136..b6a2555 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,8 +15,8 @@ end
 
 # This simple model should work with both Flux's `params/loadparams!` and
-# our `weights/loadweights!`. The only difference is in layers with `!isempty(other_weights(layer))`.
-@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(weights, loadweights!), (params, Flux.loadparams!)]
+# our `weights/load_weights!`. The only difference is in layers with `!isempty(other_weights(layer))`.
+@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(weights, load_weights!), (params, Flux.loadparams!)]
     my_model = make_my_model()
     load_weights(my_model, test_weights())
@@ -40,11 +40,11 @@ end
     my_model = make_my_model()
     w = test_weights()
     w[end] = []
-    @test_throws ArgumentError loadweights!(my_model, w)
+    @test_throws ArgumentError load_weights!(my_model, w)
 
     w = test_weights()
     push!(w, [])
-    @test_throws ArgumentError loadweights!(my_model, w)
+    @test_throws ArgumentError load_weights!(my_model, w)
 end
 
 @testset "`Weights`" begin
@@ -73,7 +73,7 @@ end
         output = model(x)
 
         r1 = mk_model()
-        loadweights!(r1, w)
+        load_weights!(r1, w)
         testmode!(r1)
 
         @test output ≈ r1(x)
@@ -84,7 +84,7 @@ end
             testmode!(r2)
 
             # If this test *fails*, meaning `output ≈ r2(x)`,
-            # then perhaps we should revisit `loadweights!`
+            # then perhaps we should revisit `load_weights!`
             # and could consider switching to `Flux.loadparams!`.
             # See https://github.com/beacon-biosignals/LegolasFlux.jl/pull/4
             # for more.
From f6feaf187752f9dc7910252b996723815fdaea3d Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Wed, 7 Jul 2021 19:53:40 +0200 Subject: [PATCH 18/20] use stablerng --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b6a2555..c476be3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Flux, LegolasFlux using LegolasFlux: Weights, FlatArray, ModelRow using Arrow using Random +using StableRNGs function make_my_model() return Chain(Dense(1, 10), Dense(10, 10), Dense(10, 1)) @@ -48,7 +49,8 @@ end end @testset "`Weights`" begin - v = [rand(Int8, 5), rand(Float32, 5, 5)] + rng = StableRNG(245) + v = [rand(rng, Int8, 5), rand(rng, Float32, 5, 5)] @test Weights(v) isa Weights{Float32} @test Weights(FlatArray{Float32}.(v)) isa Weights{Float32} @test Weights(FlatArray{Float64}.(v)) isa Weights{Float64} From 093b59e68fbf172631444cc8e41fe694d59b4e20 Mon Sep 17 00:00:00 2001 From: Eric Hanson <5846501+ericphanson@users.noreply.github.com> Date: Thu, 8 Jul 2021 14:22:48 +0200 Subject: [PATCH 19/20] rename `weights` function to `fetch_weights` --- README.md | 4 ++-- examples/digits.jl | 4 ++-- src/LegolasFlux.jl | 2 +- src/functors.jl | 10 ++++++---- test/runtests.jl | 4 ++-- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index c028701..6ef141c 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ LegolasFlux provides some simple functionality to use [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/)'s extensible Arrow schemas as means to serialize Flux models similarly to using Flux's `params` and `loadparams!` -(instead, we export similar functions `weights` and `load_weights!` which handle layers like `BatchNorm` correctly for this purpose). +(instead, we export similar functions `fetch_weights` and `load_weights!` which handle layers like `BatchNorm` correctly for this purpose). The aim is to serialize only the numeric weights, *not* the code defining the model. This is a very different approach from e.g. BSON.jl, and hopefully much more robust. @@ -29,7 +29,7 @@ my_model = make_my_model() using LegolasFlux # We can save whatever other columns we'd like to as well as the `weights`. -model_row = ModelRow(; weights = weights(cpu(my_model)), +model_row = ModelRow(; weights = fetch_weights(cpu(my_model)), architecture_version=1, loss=0.5) write_model_row("my_model.model.arrow", model_row) diff --git a/examples/digits.jl b/examples/digits.jl index 09a2204..611813f 100644 --- a/examples/digits.jl +++ b/examples/digits.jl @@ -61,10 +61,10 @@ const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1", epoch::Union{Missing, Int}, accuracy::Union{Missing, Float32}) -# Construct a `DigitsRow` from a model by collecting the `weights`. +# Construct a `DigitsRow` from a model by collecting the weights. # This can then be saved with e.g. `LegolasFlux.write_model_row`. function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing) - return DigitsRow(; weights=weights(model), model.config, epoch, accuracy) + return DigitsRow(; weights=fetch_weights(model), model.config, epoch, accuracy) end # Construct a `DigitsModel` from a row satisfying the `DigitsRow` schema, diff --git a/src/LegolasFlux.jl b/src/LegolasFlux.jl index ed278d4..616d6d2 100644 --- a/src/LegolasFlux.jl +++ b/src/LegolasFlux.jl @@ -1,7 +1,7 @@ module LegolasFlux export write_model_row, read_model_row -export weights, load_weights! 
+export fetch_weights, load_weights!
 
 using Legolas
 using Arrow
diff --git a/src/functors.jl b/src/functors.jl
index 8d96ce6..175d8e3 100644
--- a/src/functors.jl
+++ b/src/functors.jl
@@ -12,22 +12,24 @@ function fcollect2(x; output=[], cache=IdSet(), exclude=_ -> false)
 end
 
 """
-    weights(m) -> Vector{Array}
+    fetch_weights(m) -> Vector{Array}
 
 Returns the weights of a model by using `Functors.children` to recurse
 through the model, keeping any arrays found. The `@functor` macro defines
 `Functors.children` automatically so that should be sufficient to support
 custom types.
+
+Note that this function does not copy the results, so that e.g. mutating `fetch_weights(m)[1]` modifies the model.
 """
-weights(m) = filter(x -> x isa Array, fcollect2(m))
+fetch_weights(m) = filter(x -> x isa Array, fcollect2(m))
 
 """
     load_weights!(m, xs)
 
-Load weights `xs` into the model `m`, using [`weights`](@ref).
+Load weights `xs` into the model `m`, using [`fetch_weights`](@ref).
 """
 function load_weights!(m, xs)
-    model_weights = weights(m)
+    model_weights = fetch_weights(m)
     if length(model_weights) != length(xs)
         throw(ArgumentError("Number of weights given ($(length(xs))) does not match number of weights model expects ($(length(model_weights)))"))
     end
diff --git a/test/runtests.jl b/test/runtests.jl
index c476be3..82ae7c3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -17,7 +17,7 @@ end
 
 # This simple model should work with both Flux's `params/loadparams!` and
 # our `weights/load_weights!`. The only difference is in layers with `!isempty(other_weights(layer))`.
-@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(weights, load_weights!), (params, Flux.loadparams!)]
+@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(fetch_weights, load_weights!), (params, Flux.loadparams!)]
     my_model = make_my_model()
     load_weights(my_model, test_weights())
@@ -70,7 +70,7 @@ end
             x = model(x)
         end
         testmode!(model)
-        w = weights(model)
+        w = fetch_weights(model)
         p = collect(params(model))
         output = model(x)

From c9f5b6147ec176b564b87ff46f63cc9fddda4d1e Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Thu, 8 Jul 2021 14:25:42 +0200
Subject: [PATCH 20/20] add terminology note

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 6ef141c..630cc19 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ extensible Arrow schemas as means to serialize Flux models similarly to using Fl
 (instead, we export similar functions `fetch_weights` and `load_weights!` which handle layers like `BatchNorm` correctly for this purpose).
 The aim is to serialize only the numeric weights, *not* the code defining the model. This is a very different approach
-from e.g. BSON.jl, and hopefully much more robust.
+from e.g. BSON.jl, and hopefully much more robust. Note that in this package, we use `weights` to refer to the numeric arrays that are modified over the course of training a model; that includes biases as well as means and variances in e.g. BatchNorms (but not e.g. configuration settings).
 
 With this approach, however, if you change the code such that the weights are no longer valid (e.g. add a layer), you will not be able to load back the same model.
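The series ends with the `fetch_weights`/`load_weights!` pair described in the final README. An end-to-end sketch of the resulting workflow, assuming Flux 0.12 (the file name and model are illustrative, not taken from the patches):

```julia
using Flux, LegolasFlux

make_model() = Chain(Dense(1, 10), BatchNorm(10), Dense(10, 1))

m = make_model()
# Unlike `params`, `fetch_weights` also captures BatchNorm's μ and σ².
row = LegolasFlux.ModelRow(; weights=fetch_weights(cpu(m)))
write_model_row("demo.model.arrow", row)

m2 = make_model()
load_weights!(m2, read_model_row("demo.model.arrow").weights)
testmode!(m2)  # inference now uses the restored running statistics
```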