From b8c6192bd4155d35a5d98b89752bb6429eb35557 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 10 Nov 2022 08:22:39 -0500
Subject: [PATCH] use _old_to_new in Optimisers.setup too

---
 src/deprecations.jl | 41 ++++++++++++++++++++++++++---------------
 1 file changed, 26 insertions(+), 15 deletions(-)

diff --git a/src/deprecations.jl b/src/deprecations.jl
index 782efde473..b686e68f99 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -86,29 +86,34 @@ Base.@deprecate_binding ADADelta AdaDelta
 #=
   # Valid method in Optimise, old implicit style, is:
   train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
+
   # Valid methods in Train, new explict style, are:
-  train!(loss, model, data, opt)
-  train!(loss, model, data, opt::Optimisers.AbstractRule)
+  train!(loss, model, data, opt)  # preferred
+  train!(loss, model, data, opt::Optimisers.AbstractRule)  # if you forget setup
+
   # Provide friendly errors for what happens if you mix these up:
 =#
 import .Optimise: train!
 
-train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state")
-train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")
+train!(loss, ps::Params, data, opt) = error(
+  """can't mix implicit Params with explicit state!
+  To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
+  But better to use the new explicit style, in which `m` itself is the 2nd argument.
+  """)
 
-train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
+train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error(
+  """can't mix implicit Params with explicit rule from Optimisers.jl
+  To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
+  But better to use the new explicit style, in which `m` itself is the 2nd argument.
+  """)
 
-# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
-#   """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
-#   Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
-#   it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
-#   where `loss_mxy` accepts the model as its first argument.
-#   """
-# ))
+train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
 
-# Next, to use the new `setup` with the still-exported old-style Adam etc:
+# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
 import .Train: setup
 setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
+# ... and allow accidental use of `Optimisers.setup` to do the same:
+Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
 
 for T in [:Descent, :Adam, :Momentum, :Nesterov,
           :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
@@ -129,10 +134,16 @@ _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon
 
 _old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")
 
-Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = error("please use Flux.setup not Optimisers.setup, it may be able to translate this rule")
-
 # v0.14 deprecations
 # Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
 # Base.@deprecate_binding Optimiser OptimiserChain
 # Base.@deprecate_binding ClipValue ClipGrad
+
+# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
+#   """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
+#   Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
+#   it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
+#   where `loss_mxy` accepts the model as its first argument.
+#   """
+# ))
 
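For context, a minimal sketch of what the patched methods allow, assuming the
Flux 0.13.9-era API in which `Flux.setup` exists; the model and rule below are
illustrative, not taken from the patch:

    using Flux, Optimisers

    model = Dense(2 => 1)                # hypothetical toy model
    old_rule = Flux.Optimise.Adam(0.01)  # old-style rule, still exported

    # Both calls now translate the old rule via _old_to_new, mapping e.g.
    # Flux.Optimise.Adam to the equivalent Optimisers.Adam:
    state_a = Flux.setup(old_rule, model)
    state_b = Optimisers.setup(old_rule, model)  # before this patch: error("please use Flux.setup ...")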
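Likewise, a sketch of the friendly error for mixing the two styles, with
hypothetical loss and data:

    using Flux

    model = Dense(2 => 1)                                   # hypothetical toy model
    data  = [(rand(Float32, 2, 10), rand(Float32, 1, 10))]  # hypothetical toy data
    state = Flux.setup(Adam(0.01), model)                   # new-style explicit state

    # Implicit Params together with explicit state hits the new error method:
    try
        Flux.train!((x, y) -> Flux.mse(model(x), y), Flux.params(model), data, state)
    catch err
        println(sprint(showerror, err))  # "can't mix implicit Params with explicit state! ..."
    end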