Stop using implicit style differentiating #221
Relatedly, it would be nice if the MLJFlux models listed here https://github.com/FluxML/model-zoo#examples-elsewhere could be updated to use the latest Flux and avoid implicit gradients. Examples of similar upgrades: https://github.com/FluxML/model-zoo/issues?q=is%3Aclosed+label%3Aupdate+explicit

In the end, Flux 0.14 did not drop support for implicit gradients, but 0.15 should.
@pat-alt Would you have any time and interest in addressing this issue?
That actually syncs well with some of my other outstanding issues, and I think I'll have to address this very same thing in CounterfactualExplanations.jl soon. So yes, please feel free to assign this one to me and I'll look at it in the coming weeks 👍
I have added a draft for this with very minor changes here #230:

```julia
function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    parameters = Flux.params(chain)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            pen = penalty(parameters) / n_batches
            loss(yhat, y[i]) + pen
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end
```

Currently, the following test fails:

```
[ Info: regularization has an effect:
[ Info: acceleration = CPU1{Nothing}(nothing)
regularization has an effect (typename(CPU1)): Test Failed at /Users/patrickaltmeyer/code/MLJFlux.jl/test/integration.jl:25
  Expression: !(loss2 ≈ loss3)
   Evaluated: !(0.8354643267207931 ≈ 0.8354643267207931)
```

I'm not quite sure what's happening. @mcabbott can you spot anything obviously wrong here?
That's because the regularization term is still using implicit params. Something like FluxML/Flux.jl#2040 (comment) will be needed for explicit params.
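For illustration, one hedged sketch of an explicit-params penalty: flatten the model's trainable parameters with `Flux.destructure` (which is designed to be AD-friendly) and penalise the flat vector. The names, coefficients, and loop-free loss below are illustrative assumptions, not MLJFlux's actual code.

```julia
using Flux

# Sketch: explicit-style L2 penalty via Flux.destructure, which returns a flat
# vector of all trainable parameters plus a reconstruction function.
l2_penalty(m) = sum(abs2, Flux.destructure(m)[1])

chain = Chain(Dense(3 => 5, relu), Dense(5 => 1))
x, y = rand(Float32, 3, 8), rand(Float32, 1, 8)

# The penalty now depends on the explicit model argument `m`, so it
# contributes to the explicit-mode gradient:
val, grads = Flux.withgradient(chain) do m
    Flux.mse(m(x), y) + 0.01f0 * l2_penalty(m)
end
```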
What is
Thanks both!
Currently, penalty functions are explicitly defined callable objects in MLJFlux (see here). I saw the note on

In any case, I can't really get either of the approaches you suggest to work in this particular case, so we may indeed want to rethink the implementation of the penalty functions, for example by using
Can you elaborate? I'm not sure I understand why/how they wouldn't work.
Sure! Moving the `Flux.params` call inside the gradient closure as follows,

```julia
function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            pen = penalty(Flux.params(m)) / n_batches
            loss(yhat, y[i]) + pen
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end
```

the tests just seem to get stuck at some point. I may try and commit this now, but at least locally on my machine things get stuck. Alternatively, using the approach in FluxML/Flux.jl#2040 (comment) as follows,

```julia
function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            l = loss(yhat, y[i])
            reg = Functors.fmap(penalty, m; exclude=Flux.trainable)
            l + reg / n_batches
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end
```

I get the following error:

```
[ Info: acceleration = CPU1{Nothing}(nothing)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(5 => 15)  # 90 parameters
│   summary(x) = "5×20 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/n3cOc/src/layers/stateless.jl:60
fit! and dropout (typename(CPU1)): Error During Test at /Users/patrickaltmeyer/code/MLJFlux.jl/test/test_utils.jl:38
  Got exception outside of a @test
  TypeError: non-boolean (NamedTuple{(:layers,), Tuple{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dropout{Float64, Colon, Random.TaskLocalRNG}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}) used in boolean context
```

Perhaps it has to do with the fact that the penalizers aren't
Yeah, I wouldn't try the first version you have there; I was referring to the second one, or @mcabbott's suggestion about moving things to the optimization step.
Pretty sure that's due to a typo in the original example code snippet. See FluxML/Flux.jl#2040 (comment)
Hmm, in that case I get the following error:
Thanks @pat-alt for this work!
I don't know what the source of your current issue is.
@pat-alt I don't think your use of

Your first suggestion (with

@ToucheSir To implement mixed L1/L2 penalties (not just L2 ones) I don't really see how to avoid the
It's arguably better, but it requires some helper functionality that isn't currently nicely packaged up in a library. FluxML/Optimisers.jl#57 is one example of how to do this and how we're thinking about packaging it up going forwards, but the problem with general solutions is that they take time. For this work, you may be better served by implementing a similar but more constrained version on top of Functors.jl and Optimisers.jl which only includes as much as MLJFlux needs for regularization. If you do, feel free to ping me for input.
@ToucheSir Thanks for the prompt response and offer of help. So, with the apparatus you describe (Functors.jl, etc.) what code replaces the following, to avoid the implicit-style `Flux.params` call?

```julia
# function to return penalty on an array:
f(A) = 0.01*sum(abs2, A) + 0.02*sum(abs, A)

f(ones(2,3))
# 0.6000000000000001

chain = Chain(Dense(3=>5), Dense(5=>1, relu))
penalty = sum(f.(Flux.params(chain)))
```
Or if you prefer, how should the regularisation example in the Flux documentation be re-written (without the weight-decay trick, which does not work for an L1 penalty)?
```julia
f(A) = ...
penalty = mytotal(f, chain)
```

Where
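For what it's worth, one hypothetical way to fill in the `mytotal` helper sketched just above (the name and implementation are illustrative, not an existing Flux or Functors API) is to collect the model's parameter arrays with `Functors.fcollect` and sum `f` over them:

```julia
using Flux, Functors

# Hypothetical helper: sum `f` over every numeric array in the model tree.
# Functors.fcollect does a depth-first walk and returns all nodes, from which
# we keep only the parameter arrays. Shared (tied) arrays appear only once.
mytotal(f, model) =
    sum(f, filter(x -> x isa AbstractArray{<:Number}, Functors.fcollect(model)))

f(A) = 0.01 * sum(abs2, A) + 0.02 * sum(abs, A)
chain = Chain(Dense(3 => 5), Dense(5 => 1, relu))
penalty = mytotal(f, chain)
```

Whether Zygote differentiates through `fcollect` efficiently is a separate question; this is only meant to show the shape such a helper could take.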
Side note, but I remembered looking into this a few months back and coming across https://stackoverflow.com/questions/42704283/l1-l2-regularization-in-pytorch/66630301#66630301, which suggests that L1 could be implemented using a similar trick. Whether that would be compatible with MLJFlux's API I'm not sure, but we could consider adding it to Optimisers.jl.
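Following that idea, here is a minimal sketch of what an L1 rule for Optimisers.jl could look like, mirroring how `WeightDecay` handles the L2 case. The name `L1Decay` is hypothetical (not an existing Optimisers.jl rule as of this discussion), and the rule interface shown (`AbstractRule`, `init`, `apply!`) is the documented custom-rule API.

```julia
using Optimisers

# The (sub)gradient of λ * sum(abs, W) is λ .* sign.(W), so an L1 penalty can
# be folded into the optimiser by adding that term to each incoming gradient,
# just as WeightDecay adds λ .* W for the L2 case.
struct L1Decay <: Optimisers.AbstractRule
    lambda::Float64
end

Optimisers.init(::L1Decay, x::AbstractArray) = nothing  # stateless rule

function Optimisers.apply!(o::L1Decay, state, x, dx)
    return state, dx .+ o.lambda .* sign.(x)
end

# Compose with any base optimiser, e.g.:
# opt_state = Optimisers.setup(OptimiserChain(L1Decay(1e-3), Adam()), model)
```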
Thanks for the help, @ToucheSir. Unfortunately, I suggest we wait on the WeightDecay extension referenced above and switch to that approach, which is likely more performant anyhow.
It seems the style used here is being deprecated and won't work with Flux 0.14:

MLJFlux.jl/src/core.jl, line 37 in 452c09d

edit: After the discussion below, I suggest we wait on

and refactor to use an optimiser-based solution to weight regularisation, which will avoid the current limitations of explicit differentiation outlined in the discussion. Note, this will likely mean the reported `training_loss` must change, as it will no longer include the weight penalty. So this will be breaking.
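For the L2 case, the optimiser-based refactor suggested above can already be sketched with stock Optimisers.jl by composing `WeightDecay` with the base rule; the training-loop shape below is an illustrative assumption, not MLJFlux's actual code.

```julia
using Flux, Optimisers

chain = Chain(Dense(3 => 5, relu), Dense(5 => 1))

# WeightDecay(λ) adds λ .* W to each parameter's gradient — i.e. the gradient
# of an L2 penalty — so no penalty term appears in the loss itself.
opt_state = Optimisers.setup(OptimiserChain(WeightDecay(1f-4), Adam(1f-3)), chain)

x, y = rand(Float32, 3, 16), rand(Float32, 1, 16)
val, grads = Flux.withgradient(m -> Flux.mse(m(x), y), chain)
opt_state, chain = Optimisers.update!(opt_state, chain, grads[1])
```

Consistent with the note above, the reported training loss here no longer includes the weight penalty.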