diff --git a/Project.toml b/Project.toml
index 95084c47..0eb2de18 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,8 +8,10 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
 ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
diff --git a/src/core.jl b/src/core.jl
index cca5a145..f29f0e4c 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -1,5 +1,8 @@
 ## EXPOSE OPTIMISERS TO MLJ (for eg, tuning)
 
+using Functors
+using Optimisers
+
 # make the optimiser structs "transparent" so that their field values
 # are exposed by calls to MLJ.params:
 MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true
@@ -31,18 +34,19 @@
 Update the parameters of a Flux `chain`, where:
 """
 function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
+    opt_state = Flux.setup(optimiser, chain)
     loss = model.loss
     n_batches = length(y)
     training_loss = zero(Float32)
     for i in 1:n_batches
-        parameters = Flux.params(chain)
-        gs = Flux.gradient(parameters) do
-            yhat = chain(X[i])
-            batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches
-            training_loss += batch_loss
-            return batch_loss
+        batch_loss, gs = Flux.withgradient(chain) do m
+            yhat = m(X[i])
+            l = loss(yhat, y[i])
+            reg = sum(penalty, Optimisers.trainables(m))  # reduce to a scalar penalty
+            l + reg / n_batches
         end
-        Flux.update!(optimiser, parameters, gs)
+        training_loss += batch_loss
+        Flux.update!(opt_state, chain, gs[1])
     end
     return training_loss / n_batches
 end
diff --git a/src/penalizers.jl b/src/penalizers.jl
index ccea2d2d..3101d31c 100644
--- a/src/penalizers.jl
+++ b/src/penalizers.jl
@@ -1,7 +1,6 @@
 # Note (1). See
 # https://discourse.julialang.org/t/weight-regularisation-which-iterates-params-m-in-flux-mutating-arrays-is-not-supported/64314
-
 """
     Penalizer(λ, α)
 
 Returns a callable object `penalizer` for evaluating regularization
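
For context, the `src/core.jl` hunk above moves `train!` from Flux's implicit `params`/`gradient` style to the explicit style backed by Optimisers.jl: optimiser state is created once with `Flux.setup`, the gradient is taken with respect to the model itself via `Flux.withgradient`, and `Flux.update!` consumes the state, the model, and the first entry of the gradient tuple. Below is a minimal self-contained sketch of the same pattern; the two-layer chain, learning rate, batch data, and elastic-net `penalty` function are hypothetical stand-ins for MLJFlux's own model and `Penalizer(λ, α)` callable, and `Optimisers.trainables` assumes a recent Optimisers.jl release that provides it:

```julia
using Flux, Optimisers

# Hypothetical per-array elastic-net penalty, standing in for the
# Penalizer(λ, α) callable documented in src/penalizers.jl:
penalty(A) = 0.01f0 * (0.5f0 * sum(abs, A) + 0.5f0 * sum(abs2, A))

chain = Flux.Chain(Flux.Dense(4 => 8, Flux.relu), Flux.Dense(8 => 1))
opt_state = Flux.setup(Flux.Adam(0.001), chain)  # explicit optimiser state, set up once

x = rand(Float32, 4, 16)   # hypothetical input batch
y = rand(Float32, 1, 16)   # hypothetical target batch

# Differentiate with respect to the model `m` itself, not Flux.params:
batch_loss, gs = Flux.withgradient(chain) do m
    l = Flux.mse(m(x), y)
    l + sum(penalty, Optimisers.trainables(m))   # scalar regularizer
end

# gs is a one-tuple: gs[1] is the gradient with respect to `chain`.
Flux.update!(opt_state, chain, gs[1])
```

Note that the regularizer must reduce to a scalar inside the gradient closure: summing `penalty` over `Optimisers.trainables(m)` does this, whereas mapping `penalty` over the model with `Functors.fmap(...; exclude=Optimisers.isnumeric)` returns a penalized copy of the model's structure rather than a number, which cannot be added to the scalar loss.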