Regularisation looks to slow down gradient function by factor 500 #2253

Closed
AndyAndDevid opened this issue May 4, 2023 · 3 comments

AndyAndDevid commented May 4, 2023

I'm computing a gradient with and without regularization on an AMD Ryzen 7 with an RTX 3060 GPU, using Flux v0.13.15, CUDA v4.2.0, and Julia v1.9.0-rc3 (April 26, 2023). Regardless of whether I use the GPU, adding regularization increases the time taken by a factor of between 400 (GPU) and 20000 (CPU).
In a more complex model, the gradient's processing time also seems to increase after some iterations.

An example follows. Both Flux.gradient and Flux.withgradient show similar performance under the same conditions.

using Flux
using CUDA

function regGrad()
  ni=20
  no=4
  model = Chain(Dense(ni, 50), Dense(50, 8), Dense(8, no))
  model = model #|> gpu

  input = rand(ni) #|> gpu
  label = rand(no) #|> gpu

  pen_l2(x::AbstractArray) = sum(abs2, x) / 2

  for i in 1:10
    startTime = time_ns()
    grads = Flux.gradient(model) do m
      result = m(input)
      #penalty = sum(pen_l2, Flux.params(m))
      Flux.Losses.mse(result, label) #+ 0.42 * penalty
    end
    Dtime_grad = time_ns() - startTime
    println("without regularization: ", Dtime_grad/1000000)
  end

  for i in 1:10
    startTime = time_ns()
    loss, grads = Flux.withgradient(model) do m
      result = m(input)
      penalty = sum(pen_l2, Flux.params(m))
      Flux.Losses.mse(result, label) + 0.42 * penalty
    end
    Dtime_wgrad = time_ns() - startTime
    println("with regularization: ", Dtime_wgrad/1000000)
  end
end


Results in ms:
without regularization: 0.0762
without regularization: 0.0275
without regularization: 0.0287
without regularization: 0.0325
without regularization: 0.0252
without regularization: 0.0215
without regularization: 0.0254
without regularization: 0.029
without regularization: 0.0215
without regularization: 0.024
with regularization: 503.5599
with regularization: 513.9299
with regularization: 512.6677
with regularization: 517.6622
with regularization: 519.182
with regularization: 527.7295
with regularization: 513.0565
with regularization: 529.4732
with regularization: 541.2436
with regularization: 549.6674

Am I doing something wrong?

@christiangnrd
Contributor

Take a look at #2211 and #2040 to see if it's the same issue. #2211 has some troubleshooting steps you might want to follow.

Also, you should surround your code in "```" so that it gets displayed properly.

Like this:
```
println("Hello, World!")
```
will show up as:

println("Hello, World!")

@AndyAndDevid
Author

Yes, it definitely looks to be the same issue as #2211 and #2040.
Thanks! Hopefully it will be fixed soon!

Member

darsnack commented May 4, 2023

I have posted a workaround until the fix: #2040 (comment)
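
For readers who only need the L2 term from the example above, here is a minimal sketch of one possible alternative. It is an assumption on my part, not the workaround linked in #2040. Because `pen_l2(x) = sum(abs2, x) / 2`, the gradient of `0.42 * pen_l2(x)` with respect to `x` is simply `0.42 .* x`, so in principle the penalty could be left out of the differentiated closure and its gradient added afterwards. The helper names `pen_l2_grad` and `finite_diff_grad` are hypothetical and exist only for this illustration; the check uses plain Julia, no Flux.

```julia
# Plain-Julia check (no Flux needed) that the L2 penalty has a closed-form gradient.
pen_l2(x::AbstractArray) = sum(abs2, x) / 2

# Hypothetical helper: analytic gradient of λ * pen_l2(x) with respect to x.
pen_l2_grad(x::AbstractArray, λ::Real) = λ .* x

# Hypothetical helper: central finite differences, used only to verify the formula.
function finite_diff_grad(f, x::AbstractVector; h = 1e-6)
    g = similar(x, Float64)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

λ = 0.42               # same coefficient as in the example above
w = [0.5, -1.0, 2.0]   # stand-in for a single weight array
analytic = pen_l2_grad(w, λ)
numeric = finite_diff_grad(x -> λ * pen_l2(x), w)
println(maximum(abs.(analytic .- numeric)))  # should be very close to zero
```

Under this assumption, adding `λ .* w` to the gradient of the unregularized loss for each weight array would give the same update as including the penalty in the loss. Optimisers.jl provides a `WeightDecay` rule based on the same idea, which may be worth a look if it fits your setup.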
