From d07855fb1bf5b6cdb7565967cb00139481eed90a Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Fri, 13 May 2022 19:40:47 +0200
Subject: [PATCH 1/2] attempt

---
 examples/Project.toml |  4 ++++
 examples/digits.jl    | 41 ++++++++++++++++++++++++++++++++++++-----
 2 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/examples/Project.toml b/examples/Project.toml
index 142443c..27854b3 100644
--- a/examples/Project.toml
+++ b/examples/Project.toml
@@ -1,7 +1,11 @@
 [deps]
+Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
 LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/examples/digits.jl b/examples/digits.jl
index 611813f..068177d 100644
--- a/examples/digits.jl
+++ b/examples/digits.jl
@@ -8,6 +8,8 @@ using StableRNGs
 using Flux: onehotbatch, onecold, crossentropy, throttle
 using Base.Iterators: repeated, partition
 using Legolas, LegolasFlux
+using Zygote
+using Optimisers: Optimisers
 
 # This should store all the information needed
 # to construct the model.
@@ -110,17 +112,46 @@ function accuracy(m, x, y)
 end
 
 function train_model!(m; N = N_train)
-    loss = (x, y) -> crossentropy(m(x), y)
-    opt = ADAM()
+    state = Optimisers.setup(Optimisers.ADAM(), m)  # just once
     evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5)
-    Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt; cb=evalcb)
-    return accuracy(m, tX, tY)
+    for d in Iterators.take(train, N)
+        m̄, _ = gradient(m, d[1]) do m, x
+            crossentropy(m(x), d[2])
+        end
+        state, m = Optimisers.update(state, m, m̄);
+        evalcb()
+    end
+    return accuracy(m, tX, tY), state
 end
 
 m = DigitsModel()
 
 # increase N to actually train more than a tiny amount
-acc = train_model!(m; N=10)
+acc, state = train_model!(m; N=10)
+
+using Arrow, Test
+
+macro serialize_as_record(T)
+    name = :(Symbol("JuliaLang.", @__MODULE__, ".", string(parentmodule($T), '.', nameof($T))))
+    return quote
+        Arrow.ArrowTypes.arrowname(::Type{$T}) = $name
+        Arrow.ArrowTypes.ArrowType(::Type{$T}) = fieldtypes($T)
+        Arrow.ArrowTypes.toarrow(obj::$T) = ntuple(i -> getfield(obj, i), fieldcount($T))
+        Arrow.ArrowTypes.JuliaType(::Val{$name}, ::Any) = $T
+        Arrow.ArrowTypes.fromarrow(::Type{$T}, args...) = $T(args...)
+    end
+end
+
+@serialize_as_record Optimisers.ADAM
+@serialize_as_record Optimisers.Leaf
+
+Arrow.tobuffer( [(; obj=state)]; maxdepth=50)
+state2 = Arrow.Table(Arrow.tobuffer( [(; obj=state)]; maxdepth=50)).obj[1]
+
+
+
+
+LegolasFlux.load_weights!(state2, state)
 
 # Let's serialize out the weights into a `DigitsRow`.
 # We could save this here with `write_model_row`.
From 3c54b5f37a8e3baf0fc53261585882ba1aa167b2 Mon Sep 17 00:00:00 2001
From: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
Date: Fri, 13 May 2022 19:44:15 +0200
Subject: [PATCH 2/2] wip

---
 examples/digits.jl | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/digits.jl b/examples/digits.jl
index 068177d..abd82bc 100644
--- a/examples/digits.jl
+++ b/examples/digits.jl
@@ -129,6 +129,9 @@ m = DigitsModel()
 # increase N to actually train more than a tiny amount
 acc, state = train_model!(m; N=10)
 
+##
+#
+
 using Arrow, Test
 
 macro serialize_as_record(T)
@@ -148,10 +151,8 @@ end
 Arrow.tobuffer( [(; obj=state)]; maxdepth=50)
 state2 = Arrow.Table(Arrow.tobuffer( [(; obj=state)]; maxdepth=50)).obj[1]
 
-
-
-
-LegolasFlux.load_weights!(state2, state)
+#
+##
 
 # Let's serialize out the weights into a `DigitsRow`.
 # We could save this here with `write_model_row`.