
Commit

Add AD.value_and_gradient too
darsnack committed Oct 15, 2022
1 parent 8b8eebc commit 3fb5410
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/optimise/gradients.jl
@@ -9,6 +9,8 @@ AD.@primitive pullback_function(ad::ZygoteImplicitBackend, f, x::Zygote.Params)
 # this is a hack to get around
 # https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
 AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
+AD.value_and_gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) =
+    Zygote.withgradient(f, x)
 
 struct ZygoteExplicitBackend{T} <: AD.AbstractReverseMode
     core_backend::T
@@ -21,3 +23,5 @@ AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
 # this is a hack to get around
 # https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
 AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
+AD.value_and_gradient(::ZygoteExplicitBackend, f, xs...) =
+    Zygote.withgradient(f, xs...)
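The new methods forward `AD.value_and_gradient` to `Zygote.withgradient`, which returns the loss value and the gradient from a single forward-plus-backward pass instead of requiring two separate calls. A minimal sketch of that underlying call (plain Zygote, no Flux required; the toy loss `f` here is an illustration, not from the commit):

```julia
# Sketch of the Zygote call the new methods delegate to.
# Assumes the Zygote package is installed.
using Zygote

f(x) = sum(abs2, x)              # simple quadratic "loss"
x = [1.0, 2.0, 3.0]

# One call yields both the value and the gradient, avoiding a
# second evaluation of f compared to f(x) + Zygote.gradient(f, x).
out = Zygote.withgradient(f, x)
out.val     # 14.0
out.grad    # ([2.0, 4.0, 6.0],)
```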
2 changes: 1 addition & 1 deletion src/optimise/train.jl
@@ -94,7 +94,7 @@ _gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.G
 
 """
     train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
 
 Uses a `loss` function and training `data` to improve the
 model's parameters according to a particular optimisation rule `opt`.
