From 29a09f8a86086455943099c8e87524e2f34cbba3 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 1 Aug 2023 16:55:41 +0200 Subject: [PATCH 1/4] first go at this --- src/core.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/core.jl b/src/core.jl index cca5a145..ee362abc 100644 --- a/src/core.jl +++ b/src/core.jl @@ -31,18 +31,18 @@ Update the parameters of a Flux `chain`, where: """ function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) + opt_state = Flux.setup(optimiser, chain) loss = model.loss n_batches = length(y) training_loss = zero(Float32) + parameters = Flux.params(chain) for i in 1:n_batches - parameters = Flux.params(chain) - gs = Flux.gradient(parameters) do - yhat = chain(X[i]) - batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches - training_loss += batch_loss - return batch_loss + batch_loss, gs = Flux.withgradient(chain) do m + yhat = m(X[i]) + loss(yhat, y[i]) + penalty(parameters) / n_batches end - Flux.update!(optimiser, parameters, gs) + training_loss += batch_loss + Flux.update!(opt_state, chain, gs[1]) end return training_loss / n_batches end From 8eaec2c4194af585dae5f642cb240c14a06df085 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 1 Aug 2023 17:02:04 +0200 Subject: [PATCH 2/4] first go at this --- src/core.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/core.jl b/src/core.jl index ee362abc..9d1a3876 100644 --- a/src/core.jl +++ b/src/core.jl @@ -39,7 +39,8 @@ function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) for i in 1:n_batches batch_loss, gs = Flux.withgradient(chain) do m yhat = m(X[i]) - loss(yhat, y[i]) + penalty(parameters) / n_batches + pen = penalty(parameters) / n_batches + loss(yhat, y[i]) + pen end training_loss += batch_loss Flux.update!(opt_state, chain, gs[1]) From 7a960ef45a7d4dd6177f47e878169dedc5602086 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 1 Aug 2023 19:49:13 +0200 Subject: [PATCH 3/4] let's see if my device is just acting up --- src/core.jl | 5 ++--- src/penalizers.jl | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/core.jl b/src/core.jl index 9d1a3876..8e2d3eca 100644 --- a/src/core.jl +++ b/src/core.jl @@ -35,12 +35,11 @@ function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) loss = model.loss n_batches = length(y) training_loss = zero(Float32) - parameters = Flux.params(chain) for i in 1:n_batches batch_loss, gs = Flux.withgradient(chain) do m yhat = m(X[i]) - pen = penalty(parameters) / n_batches - loss(yhat, y[i]) + pen + reg = penalty(Flux.params(chain)) / n_batches + loss(yhat, y[i]) + reg end training_loss += batch_loss Flux.update!(opt_state, chain, gs[1]) diff --git a/src/penalizers.jl b/src/penalizers.jl index ccea2d2d..3101d31c 100644 --- a/src/penalizers.jl +++ b/src/penalizers.jl @@ -1,7 +1,6 @@ # Note (1). 
See # https://discourse.julialang.org/t/weight-regularisation-which-iterates-params-m-in-flux-mutating-arrays-is-not-supported/64314 - """ Penalizer(λ, α) Returns a callable object `penalizer` for evaluating regularization From c534477e2e46d9a36245b525e5f3d48190af28f8 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Fri, 11 Aug 2023 11:29:42 +0200 Subject: [PATCH 4/4] still erroring --- Project.toml | 2 ++ src/core.jl | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 95084c47..0eb2de18 100644 --- a/Project.toml +++ b/Project.toml @@ -8,8 +8,10 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/core.jl b/src/core.jl index 8e2d3eca..f29f0e4c 100644 --- a/src/core.jl +++ b/src/core.jl @@ -1,5 +1,8 @@ ## EXPOSE OPTIMISERS TO MLJ (for eg, tuning) +using Functors +using Optimisers + # make the optimiser structs "transparent" so that their field values # are exposed by calls to MLJ.params: MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true @@ -38,8 +41,9 @@ function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) for i in 1:n_batches batch_loss, gs = Flux.withgradient(chain) do m yhat = m(X[i]) - reg = penalty(Flux.params(chain)) / n_batches - loss(yhat, y[i]) + reg + l = loss(yhat, y[i]) + reg = Functors.fmap(penalty, m; exclude=Optimisers.isnumeric) + l + reg / n_batches end training_loss += batch_loss Flux.update!(opt_state, chain, gs[1])
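
For reference, the pattern these patches are converging on is Flux's explicit-gradient ("new-style") training loop: build the optimiser state once with `Flux.setup`, differentiate the model itself with `Flux.withgradient`, and apply the structural gradient with `Flux.update!`. The remaining sticking point is the penalty: `Flux.params` belongs to the implicit-parameters interface, and mixing it into an explicit `withgradient` closure is the likely source of the errors. Below is a minimal sketch of the loop under those constraints, not the package's final code. It assumes the optimiser is an Optimisers.jl rule (e.g. `Optimisers.Adam()`; `Flux.setup` does not accept the legacy `Flux.Optimise` optimisers) and uses `Flux.destructure` to obtain a flat, differentiable view of the trainable parameters for an L2 penalty. The names `sketch_train!`, `rule` and `lambda` are illustrative only:

    using Flux, Optimisers

    # Explicit-mode training loop: no Flux.params anywhere.
    # X and y are vectors of pre-batched inputs/targets, as in train! above.
    function sketch_train!(loss, chain, rule, X, y; lambda = 0.0f0)
        opt_state = Flux.setup(rule, chain)           # optimiser state matching the model tree
        n_batches = length(y)
        training_loss = zero(Float32)
        for i in 1:n_batches
            batch_loss, grads = Flux.withgradient(chain) do m
                yhat = m(X[i])
                # Differentiable L2 penalty on the model's own parameters:
                # destructure returns (flat_parameter_vector, reconstructor).
                flat, _ = Flux.destructure(m)
                loss(yhat, y[i]) + lambda * sum(abs2, flat) / n_batches
            end
            training_loss += batch_loss
            Flux.update!(opt_state, chain, grads[1])  # grads[1] has the same structure as chain
        end
        return training_loss / n_batches
    end

One further observation on PATCH 4/4: `Functors.fmap(penalty, m; exclude = Optimisers.isnumeric)` rebuilds a model-shaped structure of per-leaf penalty values rather than reducing them to a scalar, so `l + reg / n_batches` cannot work as written; the leaf penalties need to be summed (or the penalty computed from a flat parameter vector as above) before being added to the loss.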