diff --git a/Project.toml b/Project.toml index 2d560ad..5b92b46 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LegolasFlux" uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" authors = ["Beacon Biosignals, Inc."] -version = "0.2.0" +version = "0.2.1" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" diff --git a/README.md b/README.md index 9a8b927..c8414fa 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,9 @@ a schema extension of the `legolas-flux.model` schema: using Legolas, LegolasFlux using Legolas: @schema, @version @schema "digits-model" DigitsRow -@version DigitsRowV1 begin +@version DigitsRowV1 > ModelV1 begin + # re-declare this ModelV1 field as parametric for this schema as well + weights::(<:Union{Missing,Weights}) epoch::Union{Missing, Int} accuracy::Union{Missing, Float32} commit_sha::Union{Missing, String} @@ -81,6 +83,12 @@ end Now I can use a `DigitsRowV1` much like LegolasFlux's `ModelV1`. It has the same required `weights` column and optional `architecture_version` column, as well as the additional `epoch`, `accuracy`, and `commit_sha` columns. As a naming convention, one might name files produced by this row as e.g. `training_run.digits.model.arrow`. +When writing out a `DigitsRowV1`, I'll pass the schema version like so +```julia +write_model_row(path, my_digits_row, DigitsRowV1SchemaVersion()) +``` +so that later, when I call `read_model_row` on this path, I'll get back a `DigitsRowV1` instance. + Note in this example the schema is called `digits.model` instead of just say `digits`, since the package Digits might want to create other Legolas schemas as well at some point. 
diff --git a/examples/Project.toml b/examples/Project.toml index 142443c..d6545d5 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,3 +5,4 @@ LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" diff --git a/examples/digits.jl b/examples/digits.jl index d1e4ad9..1bb4ca0 100644 --- a/examples/digits.jl +++ b/examples/digits.jl @@ -9,6 +9,7 @@ using Flux: onehotbatch, onecold, crossentropy, throttle using Base.Iterators: repeated, partition using Legolas, LegolasFlux using Legolas: @schema, @version +using LegolasFlux: Weights using Tables # This should store all the information needed @@ -163,10 +164,9 @@ output2 = m2(input) path = joinpath(pkgdir(LegolasFlux), "examples", "test.digits-model.arrow") # The saved weights in this repo were generated by running the command: -# Legolas.write(path, [row], Legolas.Schema("digits.model@1")) +# write_model_row(path, row, DigitsRowV1SchemaVersion()) # We don't run this every time, since we want to test that we can continue to deserialize previously saved out weights. -table = Legolas.read(path) -roundtripped = DigitsRowV1(only(Tables.rows(table))) +roundtripped = read_model_row(path) @test roundtripped isa DigitsRowV1 @test roundtripped.config isa DigitsConfigV1 diff --git a/examples/test.digits-model.arrow b/examples/test.digits-model.arrow index 43cc445..6916714 100644 Binary files a/examples/test.digits-model.arrow and b/examples/test.digits-model.arrow differ diff --git a/src/LegolasFlux.jl b/src/LegolasFlux.jl index 4075b94..e5ffb82 100644 --- a/src/LegolasFlux.jl +++ b/src/LegolasFlux.jl @@ -95,27 +95,34 @@ end ##### """ - write_model_row(io_or_path; kwargs...) + write_model_row(io_or_path, row[, schema=ModelV1SchemaVersion()]; kwargs...) 
A light wrapper around `Legolas.write` to write a table with a single row. `kwargs` are forwarded to an internal invocation of `Arrow.write`. """ -function write_model_row(io_or_path, row; kwargs...) - return Legolas.write(io_or_path, [row], ModelV1SchemaVersion(); validate=true, kwargs...) +function write_model_row(io_or_path, row, + schema::Legolas.SchemaVersion=ModelV1SchemaVersion(); + kwargs...) + return Legolas.write(io_or_path, [row], schema; validate=true, kwargs...) end """ - read_model_row(io_or_path) -> ModelV1 + read_model_row(io_or_path) -> Legolas.AbstractRecord A light wrapper around `Legolas.read` to retrieve -a `ModelV1` from a table with a single row, such -as the output of [`write_model_row`](@ref)`. +an `AbstractRecord` from a table with a single row, such +as the output of [`write_model_row`](@ref). The schema version +is inferred from the `Arrow.Table` (as with `Legolas.read`). """ function read_model_row(io_or_path) table = Legolas.read(io_or_path; validate=true) - rows = ModelV1.(Tables.rows(table)) - return only(rows) + # because we used validate=true above, we know that we can extract a usable + # SchemaVersion from the table. + sv = Legolas.extract_schema_version(table) + RecordType = Legolas.record_type(sv) + row = only(Tables.rows(table)) + return RecordType(row) end include("functors.jl") diff --git a/test/runtests.jl b/test/runtests.jl index ae90da3..f0b896b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,12 @@ using LegolasFlux using Test -using Flux, LegolasFlux +using Flux using LegolasFlux: Weights, FlatArray, ModelV1 using Flux: params using Arrow using Random using StableRNGs -using Legolas: @version, @schema +using Legolas: Legolas, @version, @schema function make_my_model() return Chain(Dense(1, 10), Dense(10, 10), Dense(10, 1)) end