diff --git a/Project.toml b/Project.toml
index a494b960fc..213cdefed7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,8 +1,9 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.7"
+version = "0.13.8-DEV"
 
 [deps]
+AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -26,6 +27,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+AbstractDifferentiation = "0.4.3"
 Adapt = "3.0"
 ArrayInterface = "3.1, 4, 5, 6"
 CUDA = "3"
@@ -51,6 +53,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker"]
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index e691ce0170..53d2792c8f 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,7 +1,15 @@
 module Optimise
 
+using Flux
+using MacroTools: @forward
+import Zygote
+import Zygote: Params, gradient
+using AbstractDifferentiation
+import Optimisers
+import Optimisers: update!
 using LinearAlgebra
 import ArrayInterface
+using ProgressLogging: @progress, @withprogress, @logprogress
 
 export train!, update!,
   Descent, Adam, Momentum, Nesterov, RMSProp,
@@ -10,6 +18,7 @@ export train!, update!,
   ClipValue, ClipNorm
 
 include("optimisers.jl")
+include("gradients.jl")
 include("train.jl")
 
 end
diff --git a/src/optimise/gradients.jl b/src/optimise/gradients.jl
new file mode 100644
index 0000000000..d92e55dc93
--- /dev/null
+++ b/src/optimise/gradients.jl
@@ -0,0 +1,27 @@
+struct ZygoteImplicitBackend{T} <: AD.AbstractReverseMode
+    core_backend::T
+end
+ZygoteImplicitBackend() = ZygoteImplicitBackend(AD.ZygoteBackend())
+
+AD.@primitive pullback_function(ad::ZygoteImplicitBackend, f, x::Zygote.Params) =
+    AD.pullback_function(ad.core_backend, f, x)
+
+# this is a hack to get around
+# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
+AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
+AD.value_and_gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) =
+    Zygote.withgradient(f, x)
+
+struct ZygoteExplicitBackend{T} <: AD.AbstractReverseMode
+    core_backend::T
+end
+ZygoteExplicitBackend() = ZygoteExplicitBackend(AD.ZygoteBackend())
+
+AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
+    AD.pullback_function(ad.core_backend, f, xs...)
+
+# this is a hack to get around
+# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
+AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
+AD.value_and_gradient(::ZygoteExplicitBackend, f, xs...) =
+    Zygote.withgradient(f, xs...)
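The two wrappers in `gradients.jl` only re-route `AD.gradient` / `AD.value_and_gradient` to Zygote's implicit (`Params`) and explicit (structural) modes. A minimal usage sketch follows; it is not part of this diff, and it assumes the unexported names are reached through `Flux.Optimise` as the tests below do.

```julia
# Hedged sketch of driving the two wrappers added in src/optimise/gradients.jl.
# The model and input are invented for illustration.
using Flux
import AbstractDifferentiation as AD

model = Dense(Flux.ones32(2, 3), false)
x = Flux.ones32(3)

# Explicit mode: differentiate w.r.t. the model itself; the gradient comes back
# as a tuple with one (nested) entry per differentiated argument.
explicit = Flux.Optimise.ZygoteExplicitBackend()
val, grads = AD.value_and_gradient(explicit, m -> sum(m(x)), model)

# Implicit mode: differentiate w.r.t. a Zygote.Params collection; the gradient
# comes back as a Zygote.Grads keyed by the parameter arrays.
implicit = Flux.Optimise.ZygoteImplicitBackend()
ps = Flux.params(model)
val2, gs = AD.value_and_gradient(implicit, () -> sum(model(x)), ps)
```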
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index ce72a4b0ce..bdb50fd6d5 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -1,6 +1,3 @@
-using Flux
-using MacroTools: @forward
-
 abstract type AbstractOptimiser end
 
 const EPS = 1e-8
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index a1c3e9a7aa..2ea422e035 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,7 +1,3 @@
-using ProgressLogging: @progress, @withprogress, @logprogress
-import Zygote: Params, gradient, withgradient
-
-
 """
     update!(opt, p, g)
     update!(opt, ps::Params, gs)
@@ -12,17 +8,21 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`).
 As a result, the parameters are mutated and the optimizer's internal state may change.
 The gradient could be mutated as well.
 """
-function update!(opt::AbstractOptimiser, x, x̄)
+function Optimisers.update!(opt::AbstractOptimiser, x, x̄)
   x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
                                          # output are not mutable, see #1510
   x .-= apply!(opt, x, x̄r)
+
+  return opt, x
 end
 
-function update!(opt::AbstractOptimiser, xs::Params, gs)
+function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)
   for x in xs
     isnothing(gs[x]) && continue
     update!(opt, x, gs[x])
   end
+
+  return opt, xs
 end
 
 # Callback niceties
@@ -82,9 +82,19 @@ end
 
 batchmemaybe(x) = tuple(x)
 batchmemaybe(x::Tuple) = x
 
+_build_loss(::AD.AbstractBackend, loss, data) = function _loss(m)
+  return loss(m, data...)
+end
+_build_loss(::ZygoteImplicitBackend, loss, data) = function _loss()
+  return loss(data...)
+end
+_gradient_only(x::Zygote.Grads) = x
+_gradient_only(x::NTuple{1}) = x[1]
+_gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.Grads) but got $x")
+
 """
     train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
-    
+
 Uses a `loss` function and training `data` to improve the
 model's parameters according to a particular optimisation rule `opt`.
@@ -122,19 +132,18 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
 
 Multiple callbacks can be passed to `cb` as array.
 """
-function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
+function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () -> ())
   cb = runall(cb)
   itrsz = Base.IteratorSize(typeof(data))
   n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
   @withprogress for (i, d) in enumerate(data)
     try
-      l, gs = withgradient(ps) do
-        loss(batchmemaybe(d)...)
-      end
+      _loss = _build_loss(ad, loss, batchmemaybe(d))
+      l, gs = AD.value_and_gradient(ad, _loss, model)
       if !isfinite(l)
         throw(DomainError("Loss is $l on data item $i, stopping training"))
       end
-      update!(opt, ps, gs)
+      optstate, model = update!(optstate, model, _gradient_only(gs))
       cb()
     catch ex
       if ex isa StopException
@@ -147,7 +156,11 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
     end
     @logprogress iszero(n) ? nothing : i / n
   end
+
+  return optstate, model
 end
+train!(loss, model, data, optstate; kwargs...) =
+  train!(loss, ZygoteImplicitBackend(), model, data, optstate; kwargs...)
 
 """
     @epochs N body
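With these changes `train!` can be driven in either style and now returns the optimiser state and model. A hedged sketch of both call patterns follows; it is not part of the diff, and the model, data, and learning rates are invented for illustration.

```julia
# Hedged sketch of calling the reworked train! under the two supported styles.
using Flux
import Optimisers

model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
data  = [(rand(Float32, 4, 16), rand(Float32, 1, 16)) for _ in 1:10]

# Implicit style (default backend is ZygoteImplicitBackend): the loss closes over
# the model, the `model` slot takes Flux.params(model), and `optstate` is a legacy
# Flux.Optimise rule. train! now returns the (state, params) pair.
ps  = Flux.params(model)
opt = Descent(0.1)
loss_implicit(x, y) = Flux.Losses.mse(model(x), y)
opt, ps = Flux.train!(loss_implicit, ps, data, opt)

# Explicit style: the loss takes the model as its first argument, the AD backend
# is passed explicitly, and `optstate` comes from Optimisers.setup.
loss_explicit(m, x, y) = Flux.Losses.mse(m(x), y)
state = Optimisers.setup(Optimisers.Descent(0.1), model)
state, model = Flux.train!(loss_explicit, Flux.Optimise.ZygoteExplicitBackend(),
                           model, data, state)
```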
""" @epochs N body diff --git a/test/optimise.jl b/test/optimise.jl index 41de5a4a10..b59f67fd13 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -1,5 +1,5 @@ using Flux.Optimise -using Flux.Optimise: runall +using Flux.Optimise: runall, ZygoteImplicitBackend, ZygoteExplicitBackend using Flux: Params, gradient import FillArrays, ComponentArrays using Test @@ -45,6 +45,36 @@ end end end +@testset "AD backends" begin + # this is hack to make Tracker work + AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad + AD.value_and_gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...) + + function _loss_and_model(::ZygoteImplicitBackend, loss, model) + return () -> loss(model), Flux.params(model) + end + _loss_and_model(ad, loss, model) = loss, model + + function _check_gradient(::ZygoteImplicitBackend, model, grad) + return grad[model[1].weight] == 2 .* Flux.ones32(5, 10) && + grad[model[2].weight] == 10 .* Flux.ones32(2, 5) + end + function _check_gradient(ad, model, grad) + return grad[1].layers[1].weight == 2 .* Flux.ones32(5, 10) && + grad[1].layers[2].weight == 10 .* Flux.ones32(2, 5) + end + + @testset for ad in [ZygoteImplicitBackend(), ZygoteExplicitBackend(), AD.TrackerBackend()] + model = Chain(Dense(Flux.ones32(5, 10), false), Dense(Flux.ones32(2, 5), false)) + x = Flux.ones32(10) + _loss, _model = _loss_and_model(ad, m -> sum(m(x)), model) + val, grad = AD.value_and_gradient(ad, _loss, _model) + @test val == sum(model(x)) + @test _check_gradient(ad, model, grad) + @test _check_gradient(ad, model, AD.gradient(ad, _loss, _model)) + end +end + @testset "Training Loop" begin i = 0 l = 1 diff --git a/test/runtests.jl b/test/runtests.jl index 9027b114fc..ed04582b32 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,9 @@ using Flux: params using Test using Random, Statistics, LinearAlgebra using IterTools: ncycle +import Tracker using Zygote +using AbstractDifferentiation using CUDA Random.seed!(0)