diff --git a/Project.toml b/Project.toml index ee58684..0608290 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LegolasFlux" uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" authors = ["Beacon Biosignals, Inc."] -version = "0.1.8" +version = "0.1.9" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" @@ -12,7 +12,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Arrow = "1, 2" Flux = "0.12, 0.13" -Functors = "0.2.6, 0.3" +Functors = "0.2.6, 0.3, 0.4" Legolas = "0.4" Tables = "1" julia = "1.6" diff --git a/examples/digits.jl b/examples/digits.jl index 5387a81..a39dd2e 100644 --- a/examples/digits.jl +++ b/examples/digits.jl @@ -33,6 +33,7 @@ Flux.@functor DigitsModel (chain,) function DigitsModel(config::DigitsConfig=DigitsConfig()) dropout_rate = config.dropout_rate Random.seed!(config.seed) + D = Dense(10, 10) chain = Chain(Dropout(dropout_rate), Conv((3, 3), 1 => 32, relu), BatchNorm(32, relu), @@ -47,6 +48,8 @@ function DigitsModel(config::DigitsConfig=DigitsConfig()) x -> reshape(x, :, size(x, 4)), Dropout(dropout_rate), Dense(90, 10), + D, + D, # test weight-sharing softmax) return DigitsModel(chain, config) end @@ -118,6 +121,7 @@ function train_model!(m; N=N_train) end m = DigitsModel() +@test m.chain[end-2] === m.chain[end-1] # test weight-sharing # increase N to actually train more than a tiny amount acc = train_model!(m; N=10) @@ -152,8 +156,10 @@ roundtripped_model = DigitsModel(roundtripped) output3 = roundtripped_model(input) @test output3 isa Matrix{Float32} +@test roundtripped_model.chain[end-2] === roundtripped_model.chain[end-1] # test weight-sharing + # Here, we've hardcoded the results at the time of serialization. # This lets us check that the model we've saved gives the same answers now as it did then. # It is OK to update this test w/ a new reference if the answers are *supposed* to change for some reason. Just make sure that is the case. @test output3 ≈ - Float32[0.09915658; 0.100575574; 0.101189725; 0.10078623; 0.09939819; 0.099650174; 0.1013182; 0.09952383; 0.0991391; 0.09926238;;] + Float32[0.096030906; 0.105671346; 0.09510324; 0.117868274; 0.112540945; 0.08980863; 0.062402092; 0.09776583; 0.11317684; 0.109631866;;] diff --git a/examples/test.digits-model.arrow b/examples/test.digits-model.arrow index 43cc445..49e2b7b 100644 Binary files a/examples/test.digits-model.arrow and b/examples/test.digits-model.arrow differ