From 72f75b698d6f575f60b89768abcf1902bcaddfd6 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 12 Oct 2023 15:33:12 +0200 Subject: [PATCH] Make AutoZygote robust to zero gradients Zygote returns `nothing` for a true zero. --- ext/OptimizationZygoteExt.jl | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index f6e5cdd0d..250145f0d 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -10,9 +10,14 @@ function Optimization.instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0) _f = (θ, args...) -> f(θ, p, args...)[1] if f.grad === nothing - grad = (res, θ, args...) -> false ? - false : - res .= Zygote.gradient(x -> _f(x, args...), θ)[1] + grad = function (res, θ, args...) + val = Zygote.gradient(x -> _f(x, args...), θ)[1] + if val === nothing + res .= 0 + else + res .= val + end + end else grad = (G, θ, args...) -> f.grad(G, θ, p, args...) end @@ -83,9 +88,14 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, _f = (θ, args...) -> f(θ, cache.p, args...)[1] if f.grad === nothing - grad = (res, θ, args...) -> false ? - false : - res .= Zygote.gradient(x -> _f(x, args...), θ)[1] + grad = function (res, θ, args...) + val = Zygote.gradient(x -> _f(x, args...), θ)[1] + if val === nothing + res .= 0 + else + res .= val + end + end else grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) end