From 98394e1b5d10b44612b8a2d1dc5dba4bf90c7245 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Thu, 22 Sep 2022 22:09:36 +1000 Subject: [PATCH] divide penalty by n_batches --- src/core.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core.jl b/src/core.jl index d94d9f22..3dcb3fbe 100644 --- a/src/core.jl +++ b/src/core.jl @@ -36,7 +36,7 @@ function train!(loss, penalty, chain, optimiser, X, y) parameters = Flux.params(chain) gs = Flux.gradient(parameters) do yhat = chain(X[i]) - batch_loss = loss(yhat, y[i]) + penalty(parameters) + batch_loss = loss(yhat, y[i]) + penalty(parameters)/n_batches training_loss += batch_loss return batch_loss end @@ -96,7 +96,7 @@ function fit!(loss, penalty, chain, optimiser, epochs, verbosity, X, y) parameters = Flux.params(chain) losses = (loss(chain(X[i]), y[i]) + - penalty(parameters) for i in 1:n_batches) + penalty(parameters)/n_batches for i in 1:n_batches) history = [mean(losses),] for i in 1:epochs