Wrong model update for BatchNorm for some specific syntax #123

Closed
jeremiedb opened this issue Dec 2, 2022 · 5 comments
Labels
bug Something isn't working

Comments

@jeremiedb

Package Version

v0.2.13

Julia Version

1.8.2

OS / Environment

Windows

Describe the bug

If the syntax opts, m = Optimisers.update!(opts, m, grads) is used inside a loop that runs for more than one iteration, and the model m is not returned once the loop completes, the model ends up incorrectly updated.
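
For context, my understanding of the pattern recommended by the Optimisers.jl docs is to re-bind both values returned by update! and to keep using the returned model afterwards. A minimal sketch of that pattern on a toy NamedTuple "model" (the names here are illustrative, not from this issue):

using Optimisers

function toy_fit(m, opts; steps = 2)
    for i in 1:steps
        grads = (w = 2 .* m.w,)                       # hand-written gradient of sum(abs2, m.w)
        opts, m = Optimisers.update!(opts, m, grads)  # re-bind both returned values
    end
    return m, opts                                    # the returned model is the one to keep
end

m = (w = ones(Float32, 3),)                           # trivial "model": a NamedTuple with one array
opts = Optimisers.setup(Optimisers.Adam(), m)
m, opts = toy_fit(m, opts)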

Steps to Reproduce

Define a minimal custom model that contains a BatchNorm, along with toy data and an MSE loss:

using Optimisers
using Flux
using Flux: @functor
using Statistics
using Random: seed!

struct MyModel{N}
    bn::N
end
@functor MyModel
Flux.trainable(m::MyModel) = (bn = m.bn,)

function (m::MyModel)(x)
    fp = m.bn(x)
    return dropdims(mean(fp, dims = 1), dims = 1)
end

seed!(123)
x1, x2 = rand(Float32, 3, 5), rand(Float32, 3, 5)
y1, y2 = rand(Float32, 5), rand(Float32, 5)
loss(m, x, y) = mean((vec(m(x)) .- y) .^ 2)

Two variations of a training loop (with and without the opts, m assignment):

function fit_1!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        opts, m = Optimisers.update!(opts, m, grads) # re-binds the local m to the returned model
    end
    return nothing
end

function fit_2!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        Optimisers.update!(opts, m, grads) # return value discarded; relies on in-place mutation only
    end
    return nothing
end

m1 = MyModel(BatchNorm(3))
opt1 = Optimisers.Adam()
opts1 = Optimisers.setup(opt1, m1)

m2 = MyModel(BatchNorm(3))
opt = Optimisers.Adam()
opts2 = Optimisers.setup(opt, m2)

Expected Results

Each loop would be expected to result in identical models. However, that is not the case:

# first loop
julia> loss(m1, x1, y1)
0.07888316f0
julia> fit_1!(m1, loss, opts1)
julia> loss(m1, x1, y1)
0.08968687f0

# second loop
julia> loss(m2, x1, y1)
0.07888316f0
julia> fit_2!(m2, loss, opts2)
julia> loss(m2, x1, y1)
0.10831509f0

Note that if the loop had only a single iteration (for i = 1:1), the models would then be identical.

Also, if the model is returned after the loop and assigned, then the two loops behave the same:

function fit_1!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        opts, m = Optimisers.update!(opts, m, grads)
    end
    return m
end

m1 = MyModel(BatchNorm(3))
opt1 = Optimisers.Adam()
opts1 = Optimisers.setup(opt1, m1)

julia> loss(m1, x1, y1)
0.07888316f0
julia> m1 = fit_1!(m1, loss, opts1);
julia> loss(m1, x1, y1)
0.10831509f0

The 0.10831509f0 is the same value as the one obtained with fit_2!.
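
One way to narrow down which part of the state diverges (a hypothetical check, not something run for this report) would be to compare the BatchNorm fields of the two fitted models directly:

m1.bn.γ ≈ m2.bn.γ   # trainable scale
m1.bn.β ≈ m2.bn.β   # trainable bias
m1.bn.μ ≈ m2.bn.μ   # tracked mean (non-trainable)
m1.bn.σ² ≈ m2.bn.σ² # tracked variance (non-trainable)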

Observed Results

none

Relevant log output

none

jeremiedb added the bug label Dec 2, 2022
@jeremiedb
Author

Really not sure if #2122 might be related (considering BatchNorm is also the operator causing the issue)

@mcabbott
Member

mcabbott commented Dec 3, 2022

Haven't had a chance to look. My guess is that somehow the mutation of the BatchNorm struct is lost by Optimisers. Would be interesting to know if FluxML/Flux.jl#2127 changes this, as it mutates the tracking arrays directly instead.
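
A quick way to probe whether that mutation is lost (an illustrative check reusing the reproduction above, not something already run here) would be to see whether the model returned by update! still aliases the original model and its arrays:

grads = gradient(model -> loss(model, x1, y1), m1)[1]
opts1, m_new = Optimisers.update!(opts1, m1, grads)

m_new === m1            # is the returned struct the very same object?
m_new.bn.γ === m1.bn.γ  # do the trainable parameters still share storage?
m_new.bn.μ === m1.bn.μ  # do the tracked statistics still share storage?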

@skyleaworlder
Contributor

> Haven't had a chance to look. My guess is that somehow the mutation of the BatchNorm struct is lost by Optimisers. Would be interesting to know if FluxML/Flux.jl#2127 changes this, as it mutates the tracking arrays directly instead.

To add a data point: I used the latest versions of Flux and Optimisers, and it seems this bug has been fixed.

@CarloLucibello
Member

We should add a test to make sure we don't have regressions

@skyleaworlder
Contributor

> We should add a test to make sure we don't have regressions

I see the related test was added in Flux. Does Optimisers need a test for update! as well?
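
A rough sketch of what such a regression test might look like (names and structure are illustrative, not taken from any existing test suite): both update! styles should leave a BatchNorm model in the same state after several steps.

using Optimisers, Flux, Test
using Statistics, Random

@testset "BatchNorm update! with and without reassignment" begin
    Random.seed!(123)
    x, y = rand(Float32, 3, 5), rand(Float32, 5)
    loss(m, x, y) = mean((vec(mean(m(x), dims = 1)) .- y) .^ 2)

    m1, m2 = BatchNorm(3), BatchNorm(3)
    s1 = Optimisers.setup(Optimisers.Adam(), m1)
    s2 = Optimisers.setup(Optimisers.Adam(), m2)

    for i in 1:2
        g1 = gradient(m -> loss(m, x, y), m1)[1]
        s1, m1 = Optimisers.update!(s1, m1, g1)   # re-assign the returned model
        g2 = gradient(m -> loss(m, x, y), m2)[1]
        Optimisers.update!(s2, m2, g2)            # rely on in-place mutation only
    end

    @test loss(m1, x, y) ≈ loss(m2, x, y)
end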
