
Pull request #2007 causes Flux.params() calls to not get cached #2040

Closed
christiangnrd opened this issue Aug 15, 2022 · 11 comments

@christiangnrd
Contributor

christiangnrd commented Aug 15, 2022

I'm really not sure where the breakage happens, but I'm more than happy to run tests if you need me to. I isolated the issue to #2007 with a bisect.

Ever since I upgraded to 0.13.5, my variational convolutional autoencoder is not running on the gpu and does not display any error messages. I can see using nvidia-smi that things are properly being transferred to gpu memory, but when it comes to the actual computations, the gpu usage fluctuates between 0% and 2% (which I believe is the arrays being moved to and from main memory) instead of consistently going up to ~60% usage.

I tried a non-convolutional variational autoencoder (with the rest mostly the same) that worked normally, and I also tried models with convolutional layers but without @functor structs, which also seemed to compute on the gpu. So I believe the combination of convolutional layers inside a struct tagged with @functor is what prevents the math from happening on the gpu.

I'll do my best to be responsive.

christiangnrd changed the title from "Pull request#2007" to "Pull request #2007 silently breaks gpu usage sometimes" on Aug 16, 2022
@mcabbott
Member

It's not so easy to guess what's gone wrong. I presume your model contains transpose or adjoint matrices as parameters, which the new version will recurse into instead of regarding them as leaf nodes. Maybe that's where to look to construct an MWE.
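
A minimal sketch of how one might check that leaf behaviour (assuming a recent Functors version; the array here is just illustrative):

using Functors
using LinearAlgebra

W = transpose(rand(Float32, 3, 3))   # a Transpose wrapper around a plain matrix

# If the diagnosis above is right, isleaf is false here, meaning fmap (and hence
# Flux.params) recurses into the wrapped parent array rather than treating W
# itself as a parameter leaf.
Functors.isleaf(W)

# children shows what the recursion sees for this node.
Functors.children(W)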

@christiangnrd
Contributor Author

Here is an MWE. On my desktop, running it with v0.13.4 takes 2 seconds per epoch while running it with v0.13.5 takes a couple minutes per epoch.

using Flux
using Flux: @functor, flatten
using Flux.Losses: logitbinarycrossentropy
using Flux.Data: DataLoader
using CUDA   # needed below for CUDA.has_cuda() and to run on the GPU
using MLDatasets
using ProgressMeter: ProgressMeter, Progress, next!
ProgressMeter.ijulia_behavior(:clear)
using Random
using MLUtils: unsqueeze

# load MNIST images and return loader
function get_data(batch_size, split=:train)
    xtrain, ytrain = MLDatasets.MNIST(split)[:]
    xtrain = unsqueeze(xtrain, 3)
    DataLoader((xtrain, ytrain), batchsize=batch_size, shuffle=true)
end

struct Encoder
    conv
    μ
    logσ
end
@functor Encoder
    
Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Encoder(
    Chain(
        Conv((3,3),1 => 32,relu,stride = 2, pad = 1),
        Conv((3,3),32 => 64,relu,stride = 2, pad = 1),
        flatten,
        Dense((input_dim ÷ (2*2))^2 * 64,hidden_dim,relu),
        ),
    # identity as activation function
    Dense(hidden_dim,latent_dim), # μ
    Dense(hidden_dim,latent_dim), # logσ
)

function (encoder::Encoder)(x)
    h = encoder.conv(x)
    encoder.μ(h), encoder.logσ(h)
end

Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Chain(
    Dense(latent_dim,(input_dim ÷ (2*2))^2 * 64,relu),
    x -> reshape(x,(7,7,64,:)),

    # note SamePad() is not possible here
    ConvTranspose((3,3),64 => 64,relu,stride=2, pad = SamePad()),
    ConvTranspose((3,3),64 => 32,relu,stride=2, pad = SamePad()),
    ConvTranspose((3,3),32 => 1, pad = 1)
)

function reconstruct(encoder, decoder, x, device)
    μ, logσ = encoder(x)
    z = μ + device(randn(Float32, size(logσ))) .* exp.(logσ)
    μ, logσ, decoder(z)
end

function model_loss(encoder, decoder, λ, x, device)
    μ, logσ, decoder_z = reconstruct(encoder, decoder, x, device)
    len = size(x)[end]
    # KL-divergence
    kl_q_p = 0.5f0 * sum(@. (exp(2f0 * logσ) + μ^2 -1f0 - 2f0 * logσ)) / len

    logp_x_z = logitbinarycrossentropy(decoder_z, x, agg=sum) / len
    # regularization
    reg = λ * sum(x->sum(x.^2), Flux.params(decoder))
    
    logp_x_z + kl_q_p + reg
end

η = 1e-3                # learning rate
λ = 0.01f0              # regularization paramater
epochs = 20             # number of epochs
cuda = true             # use GPU
input_dim = 28        # image size
latent_dim = 2          # latent dimension
hidden_dim = 16        # hidden dimension

# GPU config
if cuda && CUDA.has_cuda()
    device = gpu
    @info "Training on GPU"
else
    device = cpu
    @info "Training on CPU"
end

# load MNIST images
loader = get_data(1024)

# initialize encoder and decoder
encoder = Encoder(input_dim, latent_dim, hidden_dim) |> device
decoder = Decoder(input_dim, latent_dim, hidden_dim) |> device

input, _ = first(loader)

# ADAM optimizer
opt = ADAM(η)

# parameters
ps = Flux.params(encoder.conv,encoder.μ, encoder.logσ, decoder)

# training
# train_steps = 0
@info "Start Training, total $(epochs) epochs"
for epoch = 1:epochs
    progress = Progress(length(loader))

    for (x, _) in loader
        loss, back = Flux.pullback(() -> model_loss(encoder, decoder, λ, x |> device, device), ps)
        grad = back(1f0)
        Flux.Optimise.update!(opt, ps, grad)

        # train_steps += 1
        next!(progress; showvalues=[(:epoch, epoch), (:loss, loss)])

    end

end

@ToucheSir
Member

I can confirm 0.13.5 is slower (around 2-3x per step) than 0.13.4. However, both do make use of the GPU on my machine, and neither runs an entire epoch in only 2s!

One suspicious finding on 0.13.5 which may be contributing: when I @time each step, I see a significant amount of compilation time even after the first step. Here's a snippet from the beginning:

118.301900 seconds (260.05 M allocations: 13.222 GiB, 2.63% gc time, 67.30% compilation time: 0% of which was recompilation)
  0.602355 seconds (155.17 k allocations: 10.094 MiB, 11.42% gc time, 22.53% compilation time)
  1.051859 seconds (265.48 k allocations: 14.941 MiB, 34.92% gc time, 18.84% compilation time)
  0.672971 seconds (325.69 k allocations: 17.240 MiB, 4.61% gc time, 23.60% compilation time)
  0.932188 seconds (411.05 k allocations: 20.821 MiB, 30.46% gc time, 18.00% compilation time)
  0.777822 seconds (496.12 k allocations: 24.397 MiB, 2.35% gc time, 29.21% compilation time)
  0.769096 seconds (581.19 k allocations: 27.964 MiB, 2.32% gc time, 36.30% compilation time)
  0.758847 seconds (666.48 k allocations: 31.538 MiB, 4.06% gc time, 26.60% compilation time)
  1.085506 seconds (751.89 k allocations: 35.120 MiB, 28.85% gc time, 19.79% compilation time)
  0.722599 seconds (836.76 k allocations: 38.708 MiB, 1.73% gc time, 30.56% compilation time)
  0.809376 seconds (922.12 k allocations: 42.270 MiB, 2.30% gc time, 29.58% compilation time)

0.13.4 does not exhibit this continual compilation, but does still incur some compilation in steps 2-3:

103.966379 seconds (257.45 M allocations: 13.102 GiB, 2.80% gc time, 74.92% compilation time: 0% of which was recompilation)
  0.463840 seconds (14.39 k allocations: 4.098 MiB, 10.50% gc time, 1.62% compilation time)
  0.501627 seconds (39.38 k allocations: 5.366 MiB, 5.73% gc time, 4.95% compilation time)

@ToucheSir
Member

I reduced the repeated compilation to taking the gradient of just the regularization term λ * sum(x->sum(x.^2), Flux.params(decoder)), even for a simple model (a single Dense layer). The culprit is 0b62a91#diff-fb7b52bcd5616e0bebd43199ba13ba86729cd6a0ea17598ec355c3b3fe47c521L39-R48, but I don't understand why pullbacks aren't being cached after the first compilation (in both 0.13.5, and for the first couple of iterations on 0.13.4) 😕.
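
For reference, a minimal sketch of that reduction (the layer and λ below are placeholders):

using Flux

decoder = Dense(2, 2)            # "even a simple model (1 Dense layer)"
λ = 0.01f0
ps = Flux.params(decoder)

# On 0.13.5 each of these gradient calls reports fresh compilation time under @time,
# matching the per-step timings above; on 0.13.4 the compilation stops after the
# first couple of iterations.
for _ in 1:5
    @time gradient(() -> λ * sum(x -> sum(x.^2), Flux.params(decoder)), ps)
end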

@christiangnrd
Contributor Author

I can confirm that when I make no changes other than removing the regularization line in my original code, the performance is back to what it was in 0.13.4.

Is this something that can be fixed in this repo, or does the issue lie somewhere else?

Also, are there any workarounds that keep the regularization but avoid the constant recompilation while waiting for a fix?

@MariusDrulea

MariusDrulea commented Dec 27, 2022

Unfortunately, the problem is present again in Flux 0.13.10, so we have to reopen this task.

In the following MWE, loss_slow compiles at every iteration. Additionally, the runtime and the memory usage at each iteration are increasing. It looks like loss_slow causes Zygote to continuously accumulate some data. In the vae_mnist example, the runtime starts at 4 minutes/epoch and reaches 1.5 hours per epoch.

The equivalent loss_explicit function behaves as expected.

using Flux
using LinearAlgebra: norm   # norm comes from LinearAlgebra

model = Dense(2, 2)

loss_slow(m) = sum(p->norm(p), Flux.params(m))
loss_explicit(m) = norm(m.weight) + norm(m.bias)

for i in 1:10
    @time ∇m_slow = gradient(m->loss_slow(m), model)    
end

for i in 1:10
    @time ∇m_explicit = gradient(m->loss_explicit(m), model)    
end

Here is the output:

loss_slow:
 23.518778 seconds (62.17 M allocations: 3.153 GiB, 3.73% gc time, 99.94% compilation time)
  0.018303 seconds (4.03 k allocations: 183.281 KiB, 93.40% compilation time)
  0.018860 seconds (5.14 k allocations: 231.125 KiB, 93.63% compilation time)
  0.019585 seconds (6.24 k allocations: 281.562 KiB, 91.42% compilation time)
  0.019242 seconds (7.33 k allocations: 324.969 KiB, 92.79% compilation time)
  0.019103 seconds (8.44 k allocations: 376.188 KiB, 90.87% compilation time)
  0.019514 seconds (9.53 k allocations: 419.500 KiB, 91.37% compilation time)
  0.019786 seconds (10.63 k allocations: 467.250 KiB, 90.60% compilation time)
  0.022090 seconds (11.73 k allocations: 514.031 KiB, 91.70% compilation time)
  0.019207 seconds (12.83 k allocations: 561.297 KiB, 90.98% compilation time)
  0.038078 seconds (73.32 k allocations: 3.669 MiB, 99.70% compilation time)

loss_explicit:
  0.000017 seconds (29 allocations: 1.766 KiB)
  0.000015 seconds (29 allocations: 1.766 KiB)
  0.000006 seconds (29 allocations: 1.766 KiB)
  0.000005 seconds (29 allocations: 1.766 KiB)
  0.000005 seconds (29 allocations: 1.766 KiB)
  0.000006 seconds (29 allocations: 1.766 KiB)
  0.000006 seconds (29 allocations: 1.766 KiB)
  0.000004 seconds (29 allocations: 1.766 KiB)

ToucheSir reopened this Dec 27, 2022
@svilupp
Contributor

svilupp commented Jan 28, 2023

I was wondering if there are any workarounds one can use to have regularization in the loss and avoid this issue? E.g., some Zygote tricks in how the loss is constructed?

It seems that the only solution right now is to pin to 0.13.4, right?

@ToucheSir
Member

The best and most future-proof solution is to use explicit params for the regularization term, as shown above, but we don't currently have nice helper functionality for that. If you're using implicit params, you can call params outside the loss (a good idea either way) and iterate over it inside.
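
A minimal sketch of that implicit-params workaround (the layer, data, and λ below are placeholders):

using Flux
using Flux.Losses: logitbinarycrossentropy

decoder = Dense(2, 2)                      # stand-in for the real model
λ = 0.01f0
x, y = rand(Float32, 2, 8), rand(Float32, 2, 8)

ps = Flux.params(decoder)                  # collected once, outside the loss
penalty(ps) = sum(p -> sum(abs2, p), ps)   # only iterate the pre-collected Params inside

loss() = logitbinarycrossentropy(decoder(x), y) + λ * penalty(ps)
gs = gradient(loss, ps)                    # no Flux.params call inside the differentiated code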

@darsnack
Member

darsnack commented May 4, 2023

we don't currently have nice helper functionality for that

The issue to track for the helper is FluxML/Optimisers.jl#57. Until then, here is a snippet that should do the same for simple penalties like L2.

using Flux
using Functors

penalty(x::AbstractArray) = sum(x.^2) # example penalty

# further down
grads = Flux.gradient(model) do m
    loss = # ...
    reg = Functors.fmap(penalty, m; exclude = Flux.trainable)
    return loss + lambda * reg
end

@mcabbott
Member

mcabbott commented Aug 1, 2023

exclude = Flux.trainable can't be right here; it doesn't return a Bool.

It really needs some trainable-aware walk. exclude = Optimisers.isnumeric ought to run, but will include any non-trainable parameter arrays.
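
A quick illustration of the two points above (sketch only; Optimisers.jl ships as a dependency of Flux but may need to be added to the environment to load directly):

using Flux
import Optimisers

m = Dense(2, 2)

Flux.trainable(m)               # a NamedTuple of child fields (not a Bool)
Optimisers.isnumeric(m.weight)  # true: a Bool-returning predicate usable as fmap's exclude
Optimisers.isnumeric(m)         # false for the layer struct itself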

@CarloLucibello
Member

Now we can apply regularization using Flux.trainables(model), as shown in https://fluxml.ai/Flux.jl/stable/guide/training/training/#Regularisation

Closing, as we don't support params anymore.
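
A minimal sketch of that trainables-based regularisation, along the lines of the linked docs section (model, data, and λ are placeholders; requires a Flux version that provides trainables):

using Flux

model = Dense(2, 2)
λ = 0.01f0
x, y = rand(Float32, 2, 8), rand(Float32, 2, 8)

grads = Flux.gradient(model) do m
    # explicit-style gradient; the penalty sums over the trainable arrays only
    reg = sum(p -> sum(abs2, p), Flux.trainables(m))
    Flux.mse(m(x), y) + λ * reg
end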
