diff --git a/src/core.jl b/src/core.jl index d94d9f22..3dcb3fbe 100644 --- a/src/core.jl +++ b/src/core.jl @@ -36,7 +36,7 @@ function train!(loss, penalty, chain, optimiser, X, y) parameters = Flux.params(chain) gs = Flux.gradient(parameters) do yhat = chain(X[i]) - batch_loss = loss(yhat, y[i]) + penalty(parameters) + batch_loss = loss(yhat, y[i]) + penalty(parameters)/n_batches training_loss += batch_loss return batch_loss end @@ -96,7 +96,7 @@ function fit!(loss, penalty, chain, optimiser, epochs, verbosity, X, y) parameters = Flux.params(chain) losses = (loss(chain(X[i]), y[i]) + - penalty(parameters) for i in 1:n_batches) + penalty(parameters)/n_batches for i in 1:n_batches) history = [mean(losses),] for i in 1:epochs