Skip to content

Commit

Permalink
make (read|write)_model_row support custom schemas (#26)
Browse files Browse the repository at this point in the history
* read/write model row uses declared schema (kwarg or table metadata)

* update example with read/write model row

* further readme updates

* bump patch version (new feature, backwards compatible)

* missed that examples are actually run in tests

* Update README.md
  • Loading branch information
kleinschmidt authored Feb 21, 2023
1 parent 9ee2257 commit 8707055
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LegolasFlux"
uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
authors = ["Beacon Biosignals, Inc."]
version = "0.2.0"
version = "0.2.1"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Expand Down
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ a schema extension of the `legolas-flux.model` schema:
using Legolas, LegolasFlux
using Legolas: @schema, @version
@schema "digits-model" DigitsRow
@version DigitsRowV1 begin
@version DigitsRowV1 > ModelV1 begin
# re-declare this ModelV1 field as parametric for this schema as well
weights::(<:Union{Missing,Weights})
epoch::Union{Missing, Int}
accuracy::Union{Missing, Float32}
commit_sha::Union{Missing, String}
Expand All @@ -81,6 +83,12 @@ end
Now I can use a `DigitsRowV1` much like LegolasFlux's `ModelV1`. It has the same required `weights` column and optional `architecture_version` column, as well as the additional `epoch`, `accuracy`, and `commit_sha` columns. As a naming convention,
one might name files produced by this row as e.g. `training_run.digits.model.arrow`.

When writing out a `DigitsRowV1`, I'll pass the schema version like so
```julia
write_model_row(path, my_digits_row, DigitsRowV1SchemaVersion())
```
so that later, when I call `read_model_row` on this path, I'll get back a `DigitsRowV1` instance.

Note in this example the schema is called `digits.model` instead of just say `digits`, since the package Digits might want to
create other Legolas schemas as well at some point.

Expand Down
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
6 changes: 3 additions & 3 deletions examples/digits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Legolas, LegolasFlux
using Legolas: @schema, @version
using LegolasFlux: Weights
using Tables

# This should store all the information needed
Expand Down Expand Up @@ -163,10 +164,9 @@ output2 = m2(input)

path = joinpath(pkgdir(LegolasFlux), "examples", "test.digits-model.arrow")
# The saved weights in this repo were generated by running the command:
# Legolas.write(path, [row], Legolas.Schema("digits.model@1"))
# write_model_row(path, row, DigitsRowV1SchemaVersion())
# We don't run this every time, since we want to test that we can continue to deserialize previously saved out weights.
table = Legolas.read(path)
roundtripped = DigitsRowV1(only(Tables.rows(table)))
roundtripped = read_model_row(path)
@test roundtripped isa DigitsRowV1
@test roundtripped.config isa DigitsConfigV1

Expand Down
Binary file modified examples/test.digits-model.arrow
Binary file not shown.
23 changes: 15 additions & 8 deletions src/LegolasFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,34 @@ end
#####

"""
write_model_row(io_or_path; kwargs...)
write_model_row(io_or_path, row[, schema=ModelV1SchemaVersion()]; kwargs...)
A light wrapper around `Legolas.write` to write a table
with a single row. `kwargs` are forwarded to an internal
invocation of `Arrow.write`.
"""
function write_model_row(io_or_path, row; kwargs...)
return Legolas.write(io_or_path, [row], ModelV1SchemaVersion(); validate=true, kwargs...)
function write_model_row(io_or_path, row,
schema::Legolas.SchemaVersion=ModelV1SchemaVersion();
kwargs...)
return Legolas.write(io_or_path, [row], schema; validate=true, kwargs...)
end

"""
read_model_row(io_or_path) -> ModelV1
read_model_row(io_or_path) -> Legolas.AbstractRecord
A light wrapper around `Legolas.read` to retrieve
a `ModelV1` from a table with a single row, such
as the output of [`write_model_row`](@ref)`.
an `AbstractRecord` from a table with a single row, such
as the output of [`write_model_row`](@ref)`. The schema version
is inferred from the `Arrow.Table` (as with `Legolas.read`).
"""
function read_model_row(io_or_path)
table = Legolas.read(io_or_path; validate=true)
rows = ModelV1.(Tables.rows(table))
return only(rows)
# because we used validate=true above, we know that we can extract a usable
# SchemaVersion from the table.
sv = Legolas.extract_schema_version(table)
RecordType = Legolas.record_type(sv)
row = only(Tables.rows(table))
return RecordType(row)
end

include("functors.jl")
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
using LegolasFlux
using Test
using Flux, LegolasFlux
using Flux
using LegolasFlux: Weights, FlatArray, ModelV1
using Flux: params
using Arrow
using Random
using StableRNGs
using Legolas: @version, @schema
using Legolas: Legolas, @version, @schema
function make_my_model()
return Chain(Dense(1, 10), Dense(10, 10), Dense(10, 1))
end
Expand Down

2 comments on commit 8707055

@kleinschmidt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/78217

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" 870705520691f599d8e70b1c46eb21998641eede
git push origin v0.2.1

Please sign in to comment.