-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add example (broken) * add Flux workarounds * add tests * oops! actually use BatchNorm in example * add comments to make example more helpful * Add Random dependency to test * add comment to address review * switch to `fcollect`-based implementation * rename file, clean up a little * update comment * simplify API: `collect` not needed * update compat * add Flux compat * add link to issue * Apply suggestions from code review Co-authored-by: Alex Arslan <[email protected]> * better errors * rename `loadweights!` to `load_weights!` * use stablerng * rename `weights` function to `fetch_weights` * add terminology note Co-authored-by: Hannah Robertson <[email protected]> Co-authored-by: Alex Arslan <[email protected]>
- Loading branch information
1 parent
002298b
commit cd965c8
Showing
8 changed files
with
276 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
/Manifest.toml | ||
Manifest.toml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,28 @@ | ||
name = "LegolasFlux" | ||
uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" | ||
authors = ["Beacon Biosignals, Inc."] | ||
version = "0.1.0" | ||
version = "0.1.1" | ||
|
||
[deps] | ||
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" | ||
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" | ||
|
||
[compat] | ||
Arrow = "1" | ||
Flux = "0.12" | ||
Functors = "0.2.1" | ||
Legolas = "0.1, 0.2" | ||
Tables = "1" | ||
julia = "1.5" | ||
|
||
[extras] | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["Test", "Flux"] | ||
test = ["Test", "Flux", "StableRNGs", "Statistics", "Random"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
[deps] | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" | ||
LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" | ||
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" | ||
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Model modified from | ||
# https://discourse.julialang.org/t/how-to-drop-the-dropout-layers-in-flux-jl-when-assessing-model-performance/19924 | ||
|
||
using Flux, Statistics, Random, Test | ||
# Uncomment to use MNIST data | ||
# using MLDatasets: MNIST | ||
using StableRNGs | ||
using Flux: onehotbatch, onecold, crossentropy, throttle | ||
using Base.Iterators: repeated, partition | ||
using Legolas, LegolasFlux | ||
|
||
# Everything needed to construct the model lives in this config object:
# the seed used when the layers are created, and the rate given to the
# `Dropout` layers.
Base.@kwdef struct DigitsConfig
    seed::Int = 5
    # Dropout is disabled by default.
    dropout_rate::Float32 = 0.0f0
end
|
||
# The model object: a Flux `Chain` stored together with the `DigitsConfig`
# it was built from. Keeping the config next to the chain makes it easy to
# save out later.
struct DigitsModel
    chain::Chain
    config::DigitsConfig
end

# Let Flux recurse into the model to find params etc. Only `chain` is
# listed as a child; `config` is not.
Flux.@functor DigitsModel (chain,)
|
||
# Build the model from a config object alone. This is the only constructor
# that should be used, so that a model is always fully determined by its
# config: the RNG is seeded with `config.seed` before any layer is created,
# and every `Dropout` layer uses `config.dropout_rate`.
function DigitsModel(config::DigitsConfig=DigitsConfig())
    rate = config.dropout_rate
    Random.seed!(config.seed)
    layers = (
        Dropout(rate),
        Conv((3, 3), 1 => 32, relu),
        BatchNorm(32, relu),
        MaxPool((2, 2)),
        Dropout(rate),
        Conv((3, 3), 32 => 16, relu),
        Dropout(rate),
        MaxPool((2, 2)),
        Dropout(rate),
        Conv((3, 3), 16 => 10, relu),
        Dropout(rate),
        # Flatten everything but the 4th dimension.
        x -> reshape(x, :, size(x, 4)),
        Dropout(rate),
        Dense(90, 10),
        softmax,
    )
    return DigitsModel(Chain(layers...), config)
end
|
||
# Calling the model on an input simply forwards that input to its chain.
function (m::DigitsModel)(x)
    return m.chain(x)
end
|
||
# `digits.model@1` is a schema extension of `legolas-flux.model@1`. On top of
# the parent schema it adds our `DigitsConfig` object, as well as the epoch
# and the accuracy (each of which may be `missing`).
const DigitsRow = Legolas.@row(
    "digits.model@1" > "legolas-flux.model@1",
    config::DigitsConfig,
    epoch::Union{Missing, Int},
    accuracy::Union{Missing, Float32}
)
|
||
# Build a `DigitsRow` from a model: the weights are collected with
# `fetch_weights` and stored alongside the model's config and the given
# `epoch` and `accuracy`. The result can then be saved with e.g.
# `LegolasFlux.write_model_row`.
function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing)
    weights = fetch_weights(model)
    return DigitsRow(;
        weights=weights,
        config=model.config,
        epoch=epoch,
        accuracy=accuracy,
    )
end
|
||
# Rebuild a `DigitsModel` from a row satisfying the `DigitsRow` schema, i.e.
# one with a `weights` field and a `config::DigitsConfig` field (for example
# the result of `LegolasFlux.read_model_row`). A fresh model is constructed
# from the config, then the stored weights are loaded into it.
function DigitsModel(row)
    model = DigitsModel(row.config)
    load_weights!(model, row.weights)
    return model
end
|
||
|
||
# Dataset sizes; increase to get more training/test data.
N_train = 1_000
N_test = 50

##
# To use MNIST data, uncomment these:
# train_x, train_y = MNIST.traindata(Float32, 1:N_train)
# test_x, test_y = MNIST.testdata(Float32, 1:N_test)

# Random data, drawn from a `StableRNG` with a fixed seed. The draws share one
# RNG, so their order determines the values.
rng = StableRNG(735)
train_x = rand(rng, Float32, 28, 28, N_train)
train_y = rand(rng, 0:9, N_train)
test_x = rand(rng, Float32, 28, 28, N_test)
test_y = rand(rng, 0:9, N_test)
##

# Partition the training data into batches of (at most) 32 samples. Each batch
# is a tuple of the reshaped images and the one-hot encoded labels.
batch_size = 32
train = map(partition(1:N_train, batch_size)) do I
    (reshape(train_x[:, :, I], 28, 28, 1, :), onehotbatch(train_y[I], 0:9))
end

# The whole test set as a single batch.
tX = reshape(test_x, 28, 28, 1, :)
tY = onehotbatch(test_y, 0:9)
|
||
# Fraction of samples for which the class picked by `onecold` from the model
# output `m(x)` equals the class picked by `onecold` from `y`.
#
# The model is switched to test mode for the evaluation and switched to
# training mode afterwards. The `finally` clause makes sure the switch back
# happens even if the evaluation throws, so a failed evaluation does not
# leave the model stuck in test mode.
function accuracy(m, x, y)
    testmode!(m)
    try
        return mean(onecold(m(x)) .== onecold(y))
    finally
        trainmode!(m)
    end
end
|
||
# Train `m` for one epoch on the first `N` batches of `train` (all of them if
# `N` exceeds the number of batches), reporting the test accuracy from a
# callback throttled to once every 5 seconds. Returns the accuracy on the
# test batch `(tX, tY)` after training.
function train_model!(m; N = N_train)
    loss(x, y) = crossentropy(m(x), y)
    optimiser = ADAM()
    report = throttle(() -> @show(accuracy(m, tX, tY)), 5)
    batches = Iterators.take(train, N)
    Flux.@epochs 1 Flux.train!(loss, params(m), batches, optimiser; cb=report)
    return accuracy(m, tX, tY)
end
|
||
m = DigitsModel()

# Increase `N` to actually train more than a tiny amount.
acc = train_model!(m; N=10)

# Serialize the weights out into a `DigitsRow`.
# At this point the row could be saved with `write_model_row`.
row = DigitsRow(m; epoch=1, accuracy=acc)

# Evaluate the trained model, in test mode, on the first test sample.
testmode!(m)
input = tX[:, :, :, 1:1]
label = tY[:, 1]
output = m(input)

# Now reconstruct the model from `row` and check that it produces the same
# output for the same input.
m2 = DigitsModel(row)
testmode!(m2)
output2 = m2(input)

@test output ≈ output2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Modified version of `fcollect` that uses an `IdSet` cache, so that nodes are
# deduplicated by identity: distinct arrays whose values happen to be
# duplicates are each kept.
# <https://github.com/FluxML/Functors.jl/issues/16>
#
# Keyword arguments:
# - `output`:  vector the visited nodes are pushed onto; it is also the return value.
# - `cache`:   set of nodes that have already been visited.
# - `exclude`: predicate; a node for which it returns `true` is neither
#              collected nor recursed into.
#
# `IdSet` is not exported from `Base`, so it is referenced as `Base.IdSet` to
# keep the default argument from depending on an import elsewhere.
function fcollect2(x; output=[], cache=Base.IdSet(), exclude=_ -> false)
    x in cache && return output
    if !exclude(x)
        push!(cache, x)
        push!(output, x)
        foreach(Functors.children(x)) do child
            fcollect2(child; cache=cache, output=output, exclude=exclude)
        end
    end
    return output
end
|
||
"""
    fetch_weights(m)

Return a vector of the weights of the model `m`. The model is traversed
recursively via `Functors.children`, and every `Array` encountered is kept.
The `@functor` macro defines `Functors.children` automatically, so using it
should be sufficient to support custom types.

Note that the arrays are not copied, so that e.g. mutating
`fetch_weights(m)[1]` modifies the model.
"""
function fetch_weights(m)
    nodes = fcollect2(m)
    return filter(node -> node isa Array, nodes)
end
|
||
"""
    load_weights!(m, xs)

Load the weights `xs` into the model `m`. The arrays of `m` are obtained with
[`fetch_weights`](@ref) and paired, in order, with the elements of `xs`; each
element of `xs` is copied into the corresponding array of the model.

Throws an `ArgumentError` if the number of weights in `xs` differs from the
number of weights of the model, or if the size of a weight does not match the
size of the array it is loaded into. Returns `nothing`.
"""
function load_weights!(m, xs)
    model_weights = fetch_weights(m)
    length(model_weights) == length(xs) ||
        throw(ArgumentError("Number of weights given ($(length(xs))) does not match number of weights model expects ($(length(model_weights)))"))
    for (i, (p, x)) in enumerate(zip(model_weights, xs))
        # Sizes are checked pair by pair, right before each copy.
        size(p) == size(x) ||
            throw(ArgumentError("For the $(i)th weight expected param size $(size(p)), got $(size(x))"))
        copyto!(p, x)
    end
    return nothing
end
|
||
# A `Weights` object is first materialized with `collect`; the result is then
# handed to the generic `load_weights!(m, xs)` method.
function load_weights!(m, xs::Weights)
    return load_weights!(m, collect(xs))
end
Oops, something went wrong.
cd965c8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
cd965c8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/40502
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: