Skip to content

Commit

Permalink
Add serialization reference test & bump Legolas compat (#22)
Browse files Browse the repository at this point in the history
* add serialization test, bump compat

* format example

* Update examples/digits.jl

* Update examples/digits.jl

Co-authored-by: Dave Kleinschmidt <[email protected]>

* Bump to Legolas v0.4

Co-authored-by: Dave Kleinschmidt <[email protected]>
  • Loading branch information
ericphanson and kleinschmidt authored Nov 18, 2022
1 parent c76ddb0 commit e3fd92e
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 23 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
style = "yas"
13 changes: 3 additions & 10 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,12 @@ jobs:
strategy:
fail-fast: false
matrix:
legolas-version:
- '0.2'
- '0.3'
flux-version:
- '0.12'
- '0.13'
version:
- '1' # current release
include:
- version: '1.5' # earliest supported version
legolas-version: '0.2'
flux-version: '0.12'
- '1.6' # earliest supported version
steps:
- uses: actions/checkout@v2
with:
Expand All @@ -43,12 +37,11 @@ jobs:
with:
version: ${{ matrix.version }}
arch: x64
- name: "Install Legolas and Flux"
- name: "Install Flux"
shell: julia --color=yes --project=. {0}
run: |
using Pkg
Pkg.add([Pkg.PackageSpec(; name="Legolas", version="${{ matrix.legolas-version }}"),
Pkg.PackageSpec(; name="Flux", version="${{ matrix.flux-version }}")])
Pkg.add([Pkg.PackageSpec(; name="Flux", version="${{ matrix.flux-version }}")])
- uses: actions/cache@v2
with:
path: ~/.julia/artifacts
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LegolasFlux"
uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
authors = ["Beacon Biosignals, Inc."]
version = "0.1.7"
version = "0.1.8"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Expand All @@ -13,9 +13,9 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Arrow = "1, 2"
Flux = "0.12, 0.13"
Functors = "0.2.6, 0.3"
Legolas = "0.1, 0.2, 0.3"
Legolas = "0.4"
Tables = "1"
julia = "1.5"
julia = "1.6"

[extras]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
39 changes: 29 additions & 10 deletions examples/digits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ using StableRNGs
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Legolas, LegolasFlux
using Tables

# This should store all the information needed
# to construct the model.
Base.@kwdef struct DigitsConfig
seed::Int = 5
dropout_rate::Float32 = 0f1
end
const DigitsConfig = Legolas.@row("digits-config@1",
seed::Int = 5,
dropout_rate::Float32 = 0.0f1,)

# Here's our model object itself, just a `DigitsConfig` and
# a `chain`. We keep the config around so it's easy to save out
Expand Down Expand Up @@ -57,9 +57,9 @@ end
# Here, we define a schema extension of the `legolas-flux.model` schema.
# We add our `DigitsConfig` object, as well as the epoch and accuracy.
const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1",
config::DigitsConfig,
epoch::Union{Missing, Int},
accuracy::Union{Missing, Float32})
config::DigitsConfig = DigitsConfig(config),
epoch::Union{Missing,Int},
accuracy::Union{Missing,Float32})

# Construct a `DigitsRow` from a model by collecting the weights.
# This can then be saved with e.g. `LegolasFlux.write_model_row`.
Expand All @@ -76,7 +76,6 @@ function DigitsModel(row)
return m
end


# Increase to get more training/test data
N_train = 1_000
N_test = 50
Expand Down Expand Up @@ -109,11 +108,12 @@ function accuracy(m, x, y)
return val
end

function train_model!(m; N = N_train)
function train_model!(m; N=N_train)
loss = (x, y) -> crossentropy(m(x), y)
opt = ADAM()
evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5)
Flux.@epochs 1 Flux.train!(loss, Flux.params(m), Iterators.take(train, N), opt; cb=evalcb)
Flux.@epochs 1 Flux.train!(loss, Flux.params(m), Iterators.take(train, N), opt;
cb=evalcb)
return accuracy(m, tX, tY)
end

Expand All @@ -138,3 +138,22 @@ testmode!(m2)
output2 = m2(input)

@test output output2

path = joinpath(pkgdir(LegolasFlux), "examples", "test.digits-model.arrow")
# The saved weights in this repo were generated by running the command:
# Legolas.write(path, [row], Legolas.Schema("digits.model@1"))
# We don't run this every time, since we want to test that we can continue to deserialize previously saved out weights.
table = Legolas.read(path)
roundtripped = DigitsRow(only(Tables.rows(table)))
@test roundtripped isa DigitsRow
@test roundtripped.config isa DigitsConfig

roundtripped_model = DigitsModel(roundtripped)
output3 = roundtripped_model(input)
@test output3 isa Matrix{Float32}

# Here, we've hardcoded the results at the time of serialization.
# This lets us check that the model we've saved gives the same answers now as it did then.
# It is OK to update this test w/ a new reference if the answers are *supposed* to change for some reason. Just make sure that is the case.
@test output3
Float32[0.09915658; 0.100575574; 0.101189725; 0.10078623; 0.09939819; 0.099650174; 0.1013182; 0.09952383; 0.0991391; 0.09926238;;]
Binary file added examples/test.digits-model.arrow
Binary file not shown.

2 comments on commit e3fd92e

@kleinschmidt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/72481

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.8 -m "<description of version>" e3fd92ec4f4ebfafe13e4c94495f9d555e33b2f8
git push origin v0.1.8

Please sign in to comment.