Skip to content

Commit

Permalink
Merge pull request #609 from SciML/ChrisRackauckas-patch-1
Browse files Browse the repository at this point in the history
Make AutoZygote robust to zero gradients
  • Loading branch information
Vaibhavdixit02 authored Oct 12, 2023
2 parents 18d4468 + 72f75b6 commit f0c4446
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions ext/OptimizationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ function Optimization.instantiate_function(f, x, adtype::AutoZygote, p,
num_cons = 0)
_f = (θ, args...) -> f(θ, p, args...)[1]
if f.grad === nothing
grad = (res, θ, args...) -> false ?
false :
res .= Zygote.gradient(x -> _f(x, args...), θ)[1]
grad = function (res, θ, args...)
val = Zygote.gradient(x -> _f(x, args...), θ)[1]
if val === nothing
res .= 0
else
res .= val
end
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
end
Expand Down Expand Up @@ -83,9 +88,14 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,

_f = (θ, args...) -> f(θ, cache.p, args...)[1]
if f.grad === nothing
grad = (res, θ, args...) -> false ?
false :
res .= Zygote.gradient(x -> _f(x, args...), θ)[1]
grad = function (res, θ, args...)
val = Zygote.gradient(x -> _f(x, args...), θ)[1]
if val === nothing
res .= 0
else
res .= val
end
end
else
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
end
Expand Down

0 comments on commit f0c4446

Please sign in to comment.