diff --git a/src/core.jl b/src/core.jl
index 32ec150c..dbb450fa 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -44,13 +44,17 @@ vector of arrays where the last dimension is the batch size.
 `y` is the target observation vector.
 """
 function train!(loss_func, parameters, optimiser, X, y)
-    for i=1:length(X)
+    n_batches = length(y)
+    training_loss = zero(Float32)
+    for i in 1:n_batches
         gs = Flux.gradient(parameters) do
-            training_loss = loss_func(X[i], y[i])
-            return training_loss
+            batch_loss = loss_func(X[i], y[i])
+            training_loss += batch_loss
+            return batch_loss
         end
         Flux.update!(optimiser, parameters, gs)
     end
+    return training_loss/n_batches
 end
@@ -124,15 +128,15 @@ function fit!(chain, optimiser, loss, epochs,
     loss_func(x, y) = loss(chain(x), y)
 
     # initiate history:
-    prev_loss = mean(loss_func(X[i], y[i]) for i=1:length(X))
-    history = [prev_loss,]
+    n_batches = length(y)
+
+    training_loss = mean(loss_func(X[i], y[i]) for i in 1:n_batches)
+    history = [training_loss,]
 
     for i in 1:epochs
 
         # We're taking data in a Flux-fashion.
         # @show i rand()
-        train!(loss_func, Flux.params(chain), optimiser, X, y)
-        current_loss =
-            mean(loss_func(X[i], y[i]) for i=1:length(X))
+        current_loss = train!(loss_func, Flux.params(chain), optimiser, X, y)
         verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))"
         push!(history, current_loss)
diff --git a/test/builders.jl b/test/builders.jl
index bddcdf61..2892323e 100644
--- a/test/builders.jl
+++ b/test/builders.jl
@@ -1,5 +1,5 @@
 # to control chain initialization:
-myinit(n, m) = reshape(float(1:n*m), n , m)
+myinit(n, m) = reshape(convert(Vector{Float32}, (1:n*m)), n , m)
 
 mutable struct TESTBuilder <: MLJFlux.Builder end
 MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) =
@@ -10,7 +10,8 @@ MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) =
     # data:
     n = 100
     d = 5
-    Xmat = rand(Float64, n, d)
+    Xmat = rand(Float32, n, d)
+# Xmat = fill(one(Float32), n, d)
     X = MLJBase.table(Xmat);
     y = X.x1 .^2 + X.x2 .* X.x3 - 4 * X.x4
 
@@ -31,6 +32,7 @@ MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) =
    pretraining_yhat = Xmat*chain0' |> vec
    @test y isa Vector && pretraining_yhat isa Vector
    pretraining_loss_by_hand = MLJBase.l2(pretraining_yhat, y) |> mean
+   mean(((pretraining_yhat - y).^2)[1:2])
 
    # compare:
    @test pretraining_loss ≈ pretraining_loss_by_hand
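
A minimal usage sketch (not part of the patch) to illustrate the contract change: `train!` now accumulates the per-batch losses as it trains and returns their mean, so `fit!` records the epoch loss from the return value instead of making a second pass over the data. The toy model, batches, and optimiser below are illustrative assumptions, and the sketch assumes the patched `train!` is reachable as `MLJFlux.train!`.

    using Flux
    import MLJFlux

    # hypothetical toy model and per-batch loss, for illustration only:
    chain = Dense(5, 1)
    loss_func(x, y) = Flux.mse(vec(chain(x)), y)

    # two toy batches; as in the docstring, the last dimension is the batch size:
    X = [rand(Float32, 5, 10) for _ in 1:2]
    y = [rand(Float32, 10) for _ in 1:2]

    optimiser = Flux.Descent(0.01)

    # the return value is the mean of the per-batch losses for this epoch,
    # which `fit!` now pushes onto `history` directly:
    epoch_loss = MLJFlux.train!(loss_func, Flux.params(chain), optimiser, X, y)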