Add support for AD backends and explicit optimizers #2083

Closed

Changes from all commits
7 changes: 5 additions & 2 deletions Project.toml
@@ -1,8 +1,9 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.13.7"
version = "0.13.8-DEV"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -26,6 +27,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractDifferentiation = "0.4.3"
Adapt = "3.0"
ArrayInterface = "3.1, 4, 5, 6"
CUDA = "3"
@@ -51,6 +53,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker"]
9 changes: 9 additions & 0 deletions src/optimise/Optimise.jl
@@ -1,7 +1,15 @@
module Optimise

using Flux
using MacroTools: @forward
import Zygote
import Zygote: Params, gradient
using AbstractDifferentiation
import Optimisers
import Optimisers: update!
using LinearAlgebra
import ArrayInterface
using ProgressLogging: @progress, @withprogress, @logprogress

export train!, update!,
Descent, Adam, Momentum, Nesterov, RMSProp,
@@ -10,6 +18,7 @@ export train!, update!,
ClipValue, ClipNorm

include("optimisers.jl")
include("gradients.jl")
include("train.jl")

end
27 changes: 27 additions & 0 deletions src/optimise/gradients.jl
@@ -0,0 +1,27 @@
struct ZygoteImplicitBackend{T} <: AD.AbstractReverseMode
core_backend::T
end
ZygoteImplicitBackend() = ZygoteImplicitBackend(AD.ZygoteBackend())

AD.@primitive pullback_function(ad::ZygoteImplicitBackend, f, x::Zygote.Params) =
AD.pullback_function(ad.core_backend, f, x)

# this is a hack to get around
# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
Member: Could this be value_and_gradient to support changes like #2070?

Member Author: Not quite, because it runs into the issue you mentioned in the link above the code. I could define both gradient and value_and_gradient to essentially block out AbstractDifferentiation until they sort out the primitives issues.

Member Author: Alternatively, it might make sense to have Flux.gradient and Flux.withgradient that default to AD.gradient and AD.value_and_gradient. Right now, Flux.gradient(f, xs...) wouldn't default to ZygoteImplicitBackend. Defining our own methods would allow us to do this.

AD.value_and_gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) =
Zygote.withgradient(f, x)

struct ZygoteExplicitBackend{T} <: AD.AbstractReverseMode
core_backend::T
end
ZygoteExplicitBackend() = ZygoteExplicitBackend(AD.ZygoteBackend())

AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
AD.pullback_function(ad.core_backend, f, xs...)

# this is a hack to get around
# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
AD.value_and_gradient(::ZygoteExplicitBackend, f, xs...) =
Zygote.withgradient(f, xs...)
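
A rough sketch of the alternative raised in the review thread above (having Flux.gradient and Flux.withgradient route through AbstractDifferentiation so the backend can be swapped) might look as follows; the constant name _default_ad and these method definitions are illustrative assumptions, not part of this diff:

# hypothetical only: default Flux's gradient entry points to the backends defined above
const _default_ad = ZygoteImplicitBackend()

gradient(f, ps::Zygote.Params) = AD.gradient(_default_ad, f, ps)
withgradient(f, ps::Zygote.Params) = AD.value_and_gradient(_default_ad, f, ps)

# explicit arguments could instead go through the explicit backend
gradient(f, xs...) = AD.gradient(ZygoteExplicitBackend(), f, xs...)
withgradient(f, xs...) = AD.value_and_gradient(ZygoteExplicitBackend(), f, xs...)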
3 changes: 0 additions & 3 deletions src/optimise/optimisers.jl
@@ -1,6 +1,3 @@
using Flux
using MacroTools: @forward

abstract type AbstractOptimiser end

const EPS = 1e-8
37 changes: 25 additions & 12 deletions src/optimise/train.jl
@@ -1,7 +1,3 @@
using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient, withgradient


"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
@@ -12,17 +8,21 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.
"""
function update!(opt::AbstractOptimiser, x, x̄)
function Optimisers.update!(opt::AbstractOptimiser, x, x̄)
x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
# output are not mutable, see #1510
x .-= apply!(opt, x, x̄r)

return opt, x
end

function update!(opt::AbstractOptimiser, xs::Params, gs)
function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)
for x in xs
isnothing(gs[x]) && continue
update!(opt, x, gs[x])
end

return opt, xs
end

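For orientation, a minimal sketch of a single step with the array method above; the Descent rule, sizes, and values are illustrative, not taken from this diff:

# illustrative only: one in-place step with a legacy AbstractOptimiser
w   = randn(Float32, 3, 3)
dw  = ones(Float32, 3, 3)        # stand-in gradient with the same shape as w
opt = Descent(0.1)
opt, w = Flux.Optimise.update!(opt, w, dw)   # subtracts 0.1 .* dw from w in place and returns (opt, w)
# note: dw may also be mutated, as the docstring above warns
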
# Callback niceties
@@ -82,9 +82,19 @@ end
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

# Adapt the user's loss to the closure shape each backend expects:
# explicit backends differentiate w.r.t. the model argument,
_build_loss(::AD.AbstractBackend, loss, data) = function _loss(m)
  return loss(m, data...)
end
# while the implicit (Params-based) backend wants a zero-argument closure.
_build_loss(::ZygoteImplicitBackend, loss, data) = function _loss()
  return loss(data...)
end
# Explicit gradients arrive as a 1-tuple (one differentiated argument);
# implicit gradients are already a Zygote.Grads.
_gradient_only(x::Zygote.Grads) = x
_gradient_only(x::NTuple{1}) = x[1]
_gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.Grads) but got $x")

"""
train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])

Uses a `loss` function and training `data` to improve the
model's parameters according to a particular optimisation rule `opt`.

@@ -122,19 +132,18 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop.

Multiple callbacks can be passed to `cb` as an array.
"""
function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () -> ())
cb = runall(cb)
itrsz = Base.IteratorSize(typeof(data))
n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
@withprogress for (i, d) in enumerate(data)
try
l, gs = withgradient(ps) do
loss(batchmemaybe(d)...)
end
_loss = _build_loss(ad, loss, batchmemaybe(d))
l, gs = AD.value_and_gradient(ad, _loss, model)
if !isfinite(l)
throw(DomainError("Loss is $l on data item $i, stopping training"))
end
update!(opt, ps, gs)
optstate, model = update!(optstate, model, _gradient_only(gs))
cb()
catch ex
if ex isa StopException
@@ -147,7 +156,11 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
end
@logprogress iszero(n) ? nothing : i / n
end

return optstate, model
end
train!(loss, model, data, optstate; kwargs...) =
train!(loss, ZygoteImplicitBackend(), model, data, optstate; kwargs...)

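Below is a hedged sketch of calling the reworked train! in both styles; the model, data, loss, and learning rates are illustrative assumptions, not part of this diff:

# sketch only: names and hyperparameters are assumptions
using Flux
import Optimisers
using Flux.Optimise: ZygoteExplicitBackend   # not exported by this diff

model = Chain(Dense(10 => 5, relu), Dense(5 => 2))
data  = [(rand(Float32, 10), rand(Float32, 2)) for _ in 1:8]

# explicit style: the loss takes the model, the state comes from Optimisers.setup
loss(m, x, y) = Flux.Losses.mse(m(x), y)
optstate = Optimisers.setup(Optimisers.Adam(1f-3), model)
optstate, model = train!(loss, ZygoteExplicitBackend(), model, data, optstate)

# implicit style (the ZygoteImplicitBackend default): the loss closes over the model,
# the "model" slot holds Params and the state is a legacy AbstractOptimiser
opt, ps = train!((x, y) -> Flux.Losses.mse(model(x), y), Flux.params(model), data, Flux.Optimise.Adam(1f-3))
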
"""
@epochs N body
32 changes: 31 additions & 1 deletion test/optimise.jl
@@ -1,5 +1,5 @@
using Flux.Optimise
using Flux.Optimise: runall
using Flux.Optimise: runall, ZygoteImplicitBackend, ZygoteExplicitBackend
using Flux: Params, gradient
import FillArrays, ComponentArrays
using Test
@@ -45,6 +45,36 @@ end
end
end

@testset "AD backends" begin
# this is a hack to make Tracker work
AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad
AD.value_and_gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...)

function _loss_and_model(::ZygoteImplicitBackend, loss, model)
return () -> loss(model), Flux.params(model)
end
_loss_and_model(ad, loss, model) = loss, model

function _check_gradient(::ZygoteImplicitBackend, model, grad)
return grad[model[1].weight] == 2 .* Flux.ones32(5, 10) &&
grad[model[2].weight] == 10 .* Flux.ones32(2, 5)
end
function _check_gradient(ad, model, grad)
return grad[1].layers[1].weight == 2 .* Flux.ones32(5, 10) &&
grad[1].layers[2].weight == 10 .* Flux.ones32(2, 5)
end

@testset for ad in [ZygoteImplicitBackend(), ZygoteExplicitBackend(), AD.TrackerBackend()]
model = Chain(Dense(Flux.ones32(5, 10), false), Dense(Flux.ones32(2, 5), false))
x = Flux.ones32(10)
_loss, _model = _loss_and_model(ad, m -> sum(m(x)), model)
val, grad = AD.value_and_gradient(ad, _loss, _model)
@test val == sum(model(x))
@test _check_gradient(ad, model, grad)
@test _check_gradient(ad, model, AD.gradient(ad, _loss, _model))
end
end

@testset "Training Loop" begin
i = 0
l = 1
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -5,7 +5,9 @@ using Flux: params
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
import Tracker
using Zygote
using AbstractDifferentiation
using CUDA

Random.seed!(0)
Expand Down