diff --git a/examples/Project.toml b/examples/Project.toml index 142443c..27854b3 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,7 +1,11 @@ [deps] +Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/digits.jl b/examples/digits.jl index 611813f..abd82bc 100644 --- a/examples/digits.jl +++ b/examples/digits.jl @@ -8,6 +8,8 @@ using StableRNGs using Flux: onehotbatch, onecold, crossentropy, throttle using Base.Iterators: repeated, partition using Legolas, LegolasFlux +using Zygote +using Optimisers: Optimisers # This should store all the information needed # to construct the model. @@ -110,17 +112,47 @@ function accuracy(m, x, y) end function train_model!(m; N = N_train) - loss = (x, y) -> crossentropy(m(x), y) - opt = ADAM() + state = Optimisers.setup(Optimisers.ADAM(), m) # just once evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5) - Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt; cb=evalcb) - return accuracy(m, tX, tY) + for d in Iterators.take(train, N) + m̄, _ = gradient(m, d[1]) do m, x + crossentropy(m(x), d[2]) + end + state, m = Optimisers.update(state, m, m̄); + evalcb() + end + return accuracy(m, tX, tY), state end m = DigitsModel() # increase N to actually train more than a tiny amount -acc = train_model!(m; N=10) +acc, state = train_model!(m; N=10) + +## +# + +using Arrow, Test + +macro serialize_as_record(T) + name = :(Symbol("JuliaLang.", @__MODULE__, ".", string(parentmodule($T), '.', nameof($T)))) + return quote + Arrow.ArrowTypes.arrowname(::Type{$T}) = $name + Arrow.ArrowTypes.ArrowType(::Type{$T}) = fieldtypes($T) + Arrow.ArrowTypes.toarrow(obj::$T) = ntuple(i -> getfield(obj, i), fieldcount($T)) + Arrow.ArrowTypes.JuliaType(::Val{$name}, ::Any) = $T + Arrow.ArrowTypes.fromarrow(::Type{$T}, args...) = $T(args...) + end +end + +@serialize_as_record Optimisers.ADAM +@serialize_as_record Optimisers.Leaf + +Arrow.tobuffer( [(; obj=state)]; maxdepth=50) +state2 = Arrow.Table(Arrow.tobuffer( [(; obj=state)]; maxdepth=50)).obj[1] + +# +## # Let's serialize out the weights into a `DigitsRow`. # We could save this here with `write_model_row`.