diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 0000000..857c3ae --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "yas" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 07a2eec..f655384 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,18 +23,12 @@ jobs: strategy: fail-fast: false matrix: - legolas-version: - - '0.2' - - '0.3' flux-version: - '0.12' - '0.13' version: - '1' # current release - include: - - version: '1.5' # earliest supported version - legolas-version: '0.2' - flux-version: '0.12' + - '1.6' # earliest supported version steps: - uses: actions/checkout@v2 with: @@ -43,12 +37,11 @@ jobs: with: version: ${{ matrix.version }} arch: x64 - - name: "Install Legolas and Flux" + - name: "Install Flux" shell: julia --color=yes --project=. {0} run: | using Pkg - Pkg.add([Pkg.PackageSpec(; name="Legolas", version="${{ matrix.legolas-version }}"), - Pkg.PackageSpec(; name="Flux", version="${{ matrix.flux-version }}")]) + Pkg.add([Pkg.PackageSpec(; name="Flux", version="${{ matrix.flux-version }}")]) - uses: actions/cache@v2 with: path: ~/.julia/artifacts diff --git a/Project.toml b/Project.toml index 7291ace..ee58684 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LegolasFlux" uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" authors = ["Beacon Biosignals, Inc."] -version = "0.1.7" +version = "0.1.8" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" @@ -13,9 +13,9 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Arrow = "1, 2" Flux = "0.12, 0.13" Functors = "0.2.6, 0.3" -Legolas = "0.1, 0.2, 0.3" +Legolas = "0.4" Tables = "1" -julia = "1.5" +julia = "1.6" [extras] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/examples/digits.jl b/examples/digits.jl index 0bbc08d..5387a81 100644 --- a/examples/digits.jl +++ b/examples/digits.jl @@ -8,13 +8,13 @@ using StableRNGs using Flux: onehotbatch, onecold, crossentropy, throttle using Base.Iterators: repeated, partition using Legolas, LegolasFlux +using Tables # This should store all the information needed # to construct the model. -Base.@kwdef struct DigitsConfig - seed::Int = 5 - dropout_rate::Float32 = 0f1 -end +const DigitsConfig = Legolas.@row("digits-config@1", + seed::Int = 5, + dropout_rate::Float32 = 0.0f1,) # Here's our model object itself, just a `DigitsConfig` and # a `chain`. We keep the config around so it's easy to save out @@ -57,9 +57,9 @@ end # Here, we define a schema extension of the `legolas-flux.model` schema. # We add our `DigitsConfig` object, as well as the epoch and accuracy. const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1", - config::DigitsConfig, - epoch::Union{Missing, Int}, - accuracy::Union{Missing, Float32}) + config::DigitsConfig = DigitsConfig(config), + epoch::Union{Missing,Int}, + accuracy::Union{Missing,Float32}) # Construct a `DigitsRow` from a model by collecting the weights. # This can then be saved with e.g. `LegolasFlux.write_model_row`. @@ -76,7 +76,6 @@ function DigitsModel(row) return m end - # Increase to get more training/test data N_train = 1_000 N_test = 50 @@ -109,11 +108,12 @@ function accuracy(m, x, y) return val end -function train_model!(m; N = N_train) +function train_model!(m; N=N_train) loss = (x, y) -> crossentropy(m(x), y) opt = ADAM() evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5) - Flux.@epochs 1 Flux.train!(loss, Flux.params(m), Iterators.take(train, N), opt; cb=evalcb) + Flux.@epochs 1 Flux.train!(loss, Flux.params(m), Iterators.take(train, N), opt; + cb=evalcb) return accuracy(m, tX, tY) end @@ -138,3 +138,22 @@ testmode!(m2) output2 = m2(input) @test output ≈ output2 + +path = joinpath(pkgdir(LegolasFlux), "examples", "test.digits-model.arrow") +# The saved weights in this repo were generated by running the command: +# Legolas.write(path, [row], Legolas.Schema("digits.model@1")) +# We don't run this every time, since we want to test that we can continue to deserialize previously saved out weights. +table = Legolas.read(path) +roundtripped = DigitsRow(only(Tables.rows(table))) +@test roundtripped isa DigitsRow +@test roundtripped.config isa DigitsConfig + +roundtripped_model = DigitsModel(roundtripped) +output3 = roundtripped_model(input) +@test output3 isa Matrix{Float32} + +# Here, we've hardcoded the results at the time of serialization. +# This lets us check that the model we've saved gives the same answers now as it did then. +# It is OK to update this test w/ a new reference if the answers are *supposed* to change for some reason. Just make sure that is the case. +@test output3 ≈ + Float32[0.09915658; 0.100575574; 0.101189725; 0.10078623; 0.09939819; 0.099650174; 0.1013182; 0.09952383; 0.0991391; 0.09926238;;] diff --git a/examples/test.digits-model.arrow b/examples/test.digits-model.arrow new file mode 100644 index 0000000..43cc445 Binary files /dev/null and b/examples/test.digits-model.arrow differ