From 037d6ef1c0ed7fd8a14bdd11e6b825041c8defe6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 23 Nov 2024 12:12:51 -0500 Subject: [PATCH] giant post-rebase fixup after everything was moved around... all earlier commits are a mess now, probably --- docs/src/reference/training/enzyme.md | 14 ++- ext/FluxEnzymeExt/FluxEnzymeExt.jl | 4 +- src/Flux.jl | 49 ++------- src/deprecations.jl | 148 +++----------------------- src/losses/Losses.jl | 4 +- src/optimise/train.jl | 3 - src/train.jl | 10 +- test/ext_enzyme/enzyme.jl | 6 +- 8 files changed, 46 insertions(+), 192 deletions(-) diff --git a/docs/src/reference/training/enzyme.md b/docs/src/reference/training/enzyme.md index 68ba3d5e8a..dc0879967f 100644 --- a/docs/src/reference/training/enzyme.md +++ b/docs/src/reference/training/enzyme.md @@ -35,7 +35,8 @@ julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const -0.0014538406], σ = nothing), nothing),), nothing, nothing) ``` -The gradient returned here is also stored within `dup_model`, it shares the same arrays. +The gradient returned here is also stored within `dup_model`. +Both share the same arrays -- what is returned is not a copy, just a view of the same memory (wrapped in `NamedTuple`s instead of `struct`s). They will all be set to zero when you call `gradient` again, then replaced with the new values. Alternatively, `gradient(f, args...; zero=false)` will add the new gradient to what's already stored. @@ -81,8 +82,19 @@ julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_s ## Listing +Flux functions: + ```@docs Flux.gradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...) Flux.withgradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...) Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt) ``` + +EnzymeCore types: + +```@docs +Flux.EnzymeCore.Duplicated +Flux.EnzymeCore.Const +``` + +Enzyme.jl has [its own extensive documentation](https://enzymead.github.io/Enzyme.jl/stable/). diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl index cf1bb9da3d..093557a254 100644 --- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl +++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl @@ -2,9 +2,7 @@ module FluxEnzymeExt using Flux using Flux: _make_zero! - -import Flux.Train: _enzyme_train!, _rule_to_state, _grad_or_nothing -# import Flux.Optimise +import Flux.Train: _enzyme_train! import Optimisers import Functors diff --git a/src/Flux.jl b/src/Flux.jl index a9531440c8..f6c93100b0 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -9,6 +9,7 @@ using MacroTools: @forward @reexport using NNlib using MLUtils +using Adapt, Functors, OneHotArrays using Optimisers: Optimisers, destructure, freeze!, thaw!, adjust!, trainables, update! import Optimisers: trainable @@ -60,43 +61,9 @@ export Chain, Dense, Embedding, EmbeddingBag, destructure, freeze!, thaw!, adjust!, trainables, update!, trainable, # from Functors.jl functor, @functor, KeyPath, haskeypath, getkeypath, - # from Optimise/Train/Optimisers.jl - setup, update!, destructure, freeze!, adjust!, params, trainable, trainables -)) - -# Pirate error to catch a common mistake. 
-Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.") - -include("layers/show.jl") -include("layers/macro.jl") - -include("layers/stateless.jl") -include("layers/basic.jl") -include("layers/conv.jl") -include("layers/recurrent.jl") -include("layers/normalise.jl") -include("layers/upsample.jl") -include("layers/attention.jl") - -include("loading.jl") - -include("outputsize.jl") -export @autosize - -include("deprecations.jl") - -include("losses/Losses.jl") -using .Losses - -include("devices.jl") -export get_device, gpu_backend! - -# Distributed Training -include("distributed/backend.jl") -include("distributed/public_api.jl") -export MPIBackend, NCCLBackend, DistributedUtils - -@compat(public, ( + # from Train/Optimisers.jl + setup, update!, destructure, freeze!, adjust!, params, trainable, trainables, + withgradient, # init glorot_uniform, glorot_normal, @@ -128,15 +95,13 @@ export MPIBackend, NCCLBackend, DistributedUtils tversky_loss, )) +include("gradient.jl") +export gradient + include("train.jl") using .Train using .Train: setup -include("gradient.jl") -export gradient -@compat(public, (withgradient,)) - -using Adapt, Functors, OneHotArrays include("utils.jl") include("functor.jl") diff --git a/src/deprecations.jl b/src/deprecations.jl index 52a5a1a41e..e726ea4286 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -18,125 +18,6 @@ GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...) #### v0.14 deprecations ########################### -<<<<<<< HEAD -======= - # Valid methods in Train, new explict style, are: - train!(loss, model, data, opt) # preferred - train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup - - # Provide friendly errors for what happens if you mix these up: -=# -import .Optimise: train! - -train!(loss, ps::Params, data, opt; cb=nothing) = error( - """can't mix implict Params with explict state! - To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. - But better to use the new explicit style, in which `m` itself is the 2nd argument. - """) - -train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error( - """can't mix implict Params with explict rule from Optimisers.jl - To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. - But better to use the new explicit style, in which `m` itself is the 2nd argument. - """) - -train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = - train!(loss, model, data, __old_to_new(opt); cb) - -# Next, to use the new `setup` with the still-exported old-style `Adam` etc: -import .Train: setup -setup(rule::Optimise.AbstractOptimiser, model) = setup(__old_to_new(rule), model) -# ... and allow accidental use of `Optimisers.setup` to do the same: -Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(__old_to_new(rule), model) - - -function __old_to_new(rule) - Base.depwarn("""Optimisers from Flux.Optimise module are deprecated. - Use optimisers from Optimisers.jl instead.""", :__old_to_new) - return _old_to_new(rule) -end - -for T in [:Descent, :Adam, :Momentum, :Nesterov, - :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief, - # :InvDecay, :ExpDecay, - :SignDecay, - ] - @eval function _old_to_new(rule::Optimise.$T) - args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T)) - Optimisers.$T(args...) 
- end -end - -_old_to_new(rule::Optimise.Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...) -# const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too. -const Optimiser = Optimisers.OptimiserChain -_old_to_new(rule::Optimise.WeightDecay) = Optimisers.WeightDecay(rule.wd) # called lambda now -_old_to_new(rule::Optimise.ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields -_old_to_new(rule::Optimise.ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs -# const ClipGrad = Optimise.ClipValue -const ClipValue = Optimisers.ClipGrad -_old_to_new(rule::Optimise.RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred - -_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule") - -# This allows you to mix and match, like Flux.setup(OptimiserChain(Optimisers.SignDecay(), Flux.Descent()), [1,2,3.]) -Optimisers.OptimiserChain(rules::Union{Optimisers.AbstractRule, Optimise.AbstractOptimiser}...) = - Optimisers.OptimiserChain(map(_old_to_new, rules)) -_old_to_new(rule::Optimisers.AbstractRule) = rule - -# Since `update!` should be called in a loop, it makes less sense to call `setup` for you if you forgot. -# But let's make sure that such uses give a helpful error: -import .Optimise: update! - -function update!(opt::Optimise.AbstractOptimiser, model, grad) - # This error method requires narrowing the main worker method of Flux.Optimise - # to accept only arrays. Remove if this causes problems! - # update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, x̄) - error("""Invalid input to `update!`. - * For the implicit style, this needs `update!(::AbstractOptimiser, ::Params, ::Grads)` - * For the explicit style, `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`. - """) -end - -# TODO this friendly error should go in Optimisers.jl. -# remove after https://github.com/FluxML/Optimisers.jl/pull/181 -function update!(opt::Optimisers.AbstractRule, model, grad) - error("""Invalid input to `update!`. - `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`. - """) -end -function update!(opt::Optimisers.AbstractRule, model::Chain, grad::Tuple) - error("""Invalid input to `update!`. - `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`. - """) -end - -# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1] -# Can't catch every case, but can catch many simple Flux models: - -function update!(opt, model::Chain, grads::Tuple) - # Zygote will make a NamedTuple{(:layers,)} for the gradient of Chain, Diffractor a Tangent - @warn """explicit `update!(opt, model, grad)` wants the gradient for the model alone, - not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.""" - update!(opt, model, grads[1]) -end - -function update!(opt::Optimise.AbstractOptimiser, model::Chain, grads::Tuple) # ambiguity - update!(opt, model, grads[1]) # calls error case "Invalid input" just above -end - -# One more easy error to catch is using explicit gradient with `params(m)`: - -function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple, NamedTuple}) - error("""can't mix implicit Params with explicit gradients! - * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)` with implicit gradient. 
- * For the explicit style, `update(state, model, grad)` needs the model itself, and `state = Flux.setup(opt, model)`. - """) -end - - -# v0.14 deprecations ->>>>>>> 9576dba8 (add more Duplicated methods) @deprecate default_rng_value() Random.default_rng() @@ -184,21 +65,6 @@ const FluxMetalAdaptor = MetalDevice ######## v0.15 deprecations ######################### # Enable these when 0.16 is released, and delete const ClipGrad = Optimise.ClipValue etc: -function gradient(f, p::Zygote.Params) - Base.depwarn("""Implicit gradients such as `gradient(f, ::Params)` are deprecated in Flux! - Please see the docs for new explicit form.""", :gradient; force=true) - Zygote.gradient(f, p) -end - -function withgradient(f, p::Zygote.Params) - Base.depwarn("""Implicit gradients such as `withgradient(f, ::Params)` are deprecated in Flux! - Please see the docs for new explicit form.""", :withgradient; force=true) - Zygote.withgradient(f, p) -end - -# v0.15 deprecations - -# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc: # Base.@deprecate_binding Optimiser OptimiserChain # Base.@deprecate_binding ClipValue ClipGrad @@ -255,8 +121,22 @@ function Optimisers.update!(opt::Optimisers.AbstractRule, model, grad) `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`. """) end + +# This exists to solve an ambiguity between the method above & one in layers/basic.jl function Optimisers.update!(opt::Optimisers.AbstractRule, model::Chain, grad::Tuple) error("""Invalid input to `update!`. `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`. """) end + +# From 0.15, Flux.gradient is not Zygote.gradient, but we can add a deprecation path: +function gradient(f, p::Zygote.Params) + Base.depwarn("""Implicit gradients such as `gradient(f, ::Params)` are deprecated in Flux! + Please see the docs for new explicit form.""", :gradient; force=true) + Zygote.gradient(f, p) +end +function withgradient(f, p::Zygote.Params) + Base.depwarn("""Implicit gradients such as `withgradient(f, ::Params)` are deprecated in Flux! + Please see the docs for new explicit form.""", :withgradient; force=true) + Zygote.withgradient(f, p) +end diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 5b4a1d697b..ec5f7ae360 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -4,7 +4,9 @@ using Statistics using Zygote using Zygote: @adjoint using ChainRulesCore -using ..Flux: ofeltype, epseltype +# using ..Flux: ofeltype, epseltype +ofeltype(x, y) = convert(float(eltype(x)), y) +epseltype(x) = eps(float(eltype(x))) using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss import Base.Broadcast: broadcasted diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 58f1098ad3..af1ac86032 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,5 +1,3 @@ -<<<<<<< HEAD -======= using ProgressLogging: @progress, @withprogress, @logprogress import Zygote: Params, gradient, withgradient @@ -21,7 +19,6 @@ The gradient could be mutated as well. This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.15. The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain. """ ->>>>>>> 1466ba36 (let Flux own the function update! to avoid piracy) function update!(opt::AbstractOptimiser, x::AbstractArray, x̄) x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not # safe due to aliasing, nor guaranteed to be possible, e.g. Fill. 
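
For reference, the explicit-style calls that the friendly `update!` errors and the `_old_to_new` translation above steer users toward look roughly like the sketch below. This is a minimal sketch, not part of the patch; the model, the data, and the 0.01 learning rate are invented for illustration.

```julia
# Minimal sketch (not part of the patch) of the explicit workflow the friendly
# errors above point users toward.  Model, data, and learning rate are invented.
using Flux

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)

opt_state = Flux.setup(Adam(0.01), model)   # build optimiser state once, before the loop

grads = Flux.gradient(m -> sum(abs2, m(x) .- y), model)   # returns a 1-tuple
Flux.update!(opt_state, model, grads[1])    # pass grads[1], not the whole tuple
```

The same `opt_state` is what `train!(loss, model, data, opt_state)` expects as its fourth argument in the preferred explicit style.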
diff --git a/src/train.jl b/src/train.jl index 3b9992e983..930e1b0d34 100644 --- a/src/train.jl +++ b/src/train.jl @@ -8,7 +8,7 @@ using ..Flux: Flux using ProgressLogging: @progress, @withprogress, @logprogress using Zygote: Zygote -import ..Flux.Optimise: train!, update!, Optimise # during 0.13, we add methods to the old functions +# import ..Flux.Optimise: train!, update!, Optimise # during 0.13, we add methods to the old functions export setup, train! @@ -163,10 +163,10 @@ train!(loss, model::Duplicated, data, opt; cb = nothing) = _enzyme_train!(loss, # FluxEnzymeExt defines more specific _enzyme_train!(loss, model::Duplicated, data, opt; cb) _enzyme_train!(loss, model, data, opt; cb = nothing) = throw(ArgumentError("The method `train!(loss, Duplicated(model), data, opt_state)` is only available when Enzyme.jl is loaded")) -# Following src/deprecations.jl -function train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) - train!(loss, model, data, _old_to_new(opt); cb) -end +# # Following src/deprecations.jl +# function train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) +# train!(loss, model, data, _old_to_new(opt); cb) +# end # This method let you use Optimisers.Descent() without setup, when there is no state function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb=nothing) diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 6d95fb299d..98c8c7040e 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -1,7 +1,7 @@ using Test using Flux -using Enzyme: Enzyme, make_zero, Active, Duplicated, ReverseWithPrimal +using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal using Functors using FiniteDifferences @@ -112,8 +112,8 @@ end ] for (model, x, name) in models_xs - @testset "check grad $name" begin - println("testing $name") + @testset "Enzyme grad check $name" begin + println("testing $name with Enzyme") test_enzyme_grad(loss, model, x) end end
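
Putting the pieces of the patch together, the Enzyme path described in `docs/src/reference/training/enzyme.md` can be exercised roughly as follows. This is a minimal sketch, not part of the patch: it assumes Enzyme.jl is loaded so that `FluxEnzymeExt` is active, and the model plus the `x1`/`y1` data are invented for illustration.

```julia
# Minimal sketch (not part of the patch) of the Duplicated-based Enzyme path that
# docs/src/reference/training/enzyme.md describes.  Assumes Enzyme.jl is loaded so
# that FluxEnzymeExt is active; the model and the x1/y1 data are invented.
using Flux, Enzyme

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
dup_model = Duplicated(model, Enzyme.make_zero(model))   # model plus zeroed space for its gradient

x1, y1 = rand(Float32, 2, 8), rand(Float32, 1, 8)

# Explicit gradient via Enzyme; Const marks arguments that are not differentiated:
grads_f = Flux.gradient((m, x, y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))

# Or drive a loop through train!, with ordinary Optimisers.jl state:
opt_state = Flux.setup(Adam(0.01), model)
Flux.train!((m, x, y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
```

As the documentation change above notes, calling `Flux.gradient` again on `dup_model` zeroes and overwrites the stored gradient, while `gradient(f, args...; zero=false)` accumulates into it.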