diff --git a/src/optimise/gradients.jl b/src/optimise/gradients.jl
index 9a3dba8f12..d92e55dc93 100644
--- a/src/optimise/gradients.jl
+++ b/src/optimise/gradients.jl
@@ -9,6 +9,8 @@ AD.@primitive pullback_function(ad::ZygoteImplicitBackend, f, x::Zygote.Params)
 # this is a hack to get around
 # https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
 AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
+AD.value_and_gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) =
+  Zygote.withgradient(f, x)
 
 struct ZygoteExplicitBackend{T} <: AD.AbstractReverseMode
   core_backend::T
@@ -21,3 +23,5 @@ AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
 # this is a hack to get around
 # https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
 AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
+AD.value_and_gradient(::ZygoteExplicitBackend, f, xs...) =
+  Zygote.withgradient(f, xs...)
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index a422fa4f80..549d0fbe39 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -94,7 +94,7 @@ _gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.G
 """
     train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
-    
+
 Uses a `loss` function and training `data` to improve the model's
 parameters according to a particular optimisation rule `opt`.
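
The two added methods give AbstractDifferentiation's `value_and_gradient` entry point a direct route through `Zygote.withgradient`, so a single reverse pass yields both the loss value and the gradients instead of going through the generic pullback fallback. A minimal usage sketch follows (not part of the patch; the model, data, and loss are made up for illustration, and `backend` stands for an already-constructed `ZygoteImplicitBackend`):

# Hypothetical sketch, not part of the patch above.
using Flux, Zygote

m = Dense(2 => 1)
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
ps = Flux.params(m)                  # a Zygote.Params collection
loss() = Flux.Losses.mse(m(x), y)

# The new method forwards AD.value_and_gradient(backend, loss, ps) to the
# equivalent Zygote call, so one pass returns both the value and the gradients:
res = Zygote.withgradient(loss, ps)
res.val     # forward value of the loss
res.grad    # gradients with respect to the parameters in `ps`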