Skip to content

Commit

Permalink
use _old_to_new in Optimisers.setup too
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 18, 2022
1 parent 5d74b04 commit 7eaf3ea
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,34 @@ Base.@deprecate_binding ADADelta AdaDelta
#=
# Valid method in Optimise, old implicit style, is:
train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
# Valid methods in Train, new explict style, are:
train!(loss, model, data, opt)
train!(loss, model, data, opt::Optimisers.AbstractRule)
train!(loss, model, data, opt) # preferred
train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup
# Provide friendly errors for what happens if you mix these up:
=#
import .Optimise: train!
train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state")

train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")
train!(loss, ps::Params, data, opt) = error(
"""can't mix implict Params with explict state!
To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
But better to use the new explicit style, in which `m` itself is the 2nd argument.
""")

train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error(
"""can't mix implict Params with explict rule from Optimisers.jl
To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-mo$
But better to use the new explicit style, in which `m` itself is the 2nd argument.
""")

# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
# where `loss_mxy` accepts the model as its first argument.
# """
# ))
train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))

# Next, to use the new `setup` with the still-exported old-style Adam etc:
# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
import .Train: setup
setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
# ... and allow accidental use of `Optimisers.setup` to do the same:
Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)

for T in [:Descent, :Adam, :Momentum, :Nesterov,
:AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
Expand All @@ -129,10 +134,16 @@ _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon

_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")

Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = error("please use Flux.setup not Optimisers.setup, it may be able to translate this rule")

# v0.14 deprecations

# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
# Base.@deprecate_binding Optimiser OptimiserChain
# Base.@deprecate_binding ClipValue ClipGrad

# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
# where `loss_mxy` accepts the model as its first argument.
# """
# ))

0 comments on commit 7eaf3ea

Please sign in to comment.