Skip to content

Commit

Permalink
Merge pull request #214 from mohamed82008/mt/minibatch
Browse files Browse the repository at this point in the history
Divide penalty by n_batches
  • Loading branch information
ablaom authored Sep 22, 2022
2 parents 61c3801 + 98394e1 commit b39d5ae
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function train!(loss, penalty, chain, optimiser, X, y)
parameters = Flux.params(chain)
gs = Flux.gradient(parameters) do
yhat = chain(X[i])
batch_loss = loss(yhat, y[i]) + penalty(parameters)
batch_loss = loss(yhat, y[i]) + penalty(parameters)/n_batches
training_loss += batch_loss
return batch_loss
end
Expand Down Expand Up @@ -96,7 +96,7 @@ function fit!(loss, penalty, chain, optimiser, epochs, verbosity, X, y)

parameters = Flux.params(chain)
losses = (loss(chain(X[i]), y[i]) +
penalty(parameters) for i in 1:n_batches)
penalty(parameters)/n_batches for i in 1:n_batches)
history = [mean(losses),]

for i in 1:epochs
Expand Down

0 comments on commit b39d5ae

Please sign in to comment.