Dictionary indexing failure inside closure and structs #717
Comments
@DhairyaLGandhi do you have any thoughts on what might be causing this?
Here's another MWE. This one is a little more complex, because it matches a use case that I have.
module GradsMVP
using Zygote

mutable struct Foo
    store::Dict{Symbol, Float64}
    score::Float64
end

function (f::Foo)(acc::Symbol, fn::Function, args...)
    val = getindex(f.store, acc)
    ret = fn(val)
    f.score += ret
    fn(args...)
end

function get_grads(store, acc, ret_grad, call, args...)
    fn = (args, store) -> begin
        f = Foo(store, 0.0)
        ret = f(acc, call, args...)
        (f.score, ret)
    end
    _, back = Zygote.pullback(fn, args, store)
    arg_grads, store_grads = back((1.0, ret_grad))
    return arg_grads, store_grads
end

function foo(a::Float64)
    return a
end

ags, gs = get_grads(Dict(:x => 1.0), :x, 1.0, foo, 1.0)
println(ags)
println(gs) # = nothing
end # module
whereas this code works fine
module GradsMVP
using Zygote

mutable struct Foo
    store::Float64
    score::Float64
end

# No `acc` argument here: the store is a plain Float64, and the call site
# below passes only `call` and `args...`.
function (f::Foo)(fn::Function, args...)
    val = f.store
    ret = fn(val)
    f.score += ret
    fn(args...)
end

function get_grads(store, ret_grad, call, args...)
    fn = (args, store) -> begin
        f = Foo(store, 0.0)
        ret = f(call, args...)
        (f.score, ret)
    end
    _, back = Zygote.pullback(fn, args, store)
    arg_grads, store_grads = back((1.0, ret_grad))
    return arg_grads, store_grads
end

function foo(a::Float64)
    return a
end

ags, gs = get_grads(1.0, 1.0, foo, 1.0)
println(ags)
println(gs) # = 1.0
end # module
To fix this MWE, it suffices to define the adjoint for
Zygote.@adjoint getindex(d::Dict, acc) = getindex(d, acc), retgrad -> (retgrad, nothing)
I'm unsure if this will break something fundamental.
Edit: sorry, this is supposed to be
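As a rough illustration of what a pullback for Dict indexing could return, here is a plain-Julia sketch that does not use Zygote at all. The helper name getindex_with_pullback is hypothetical, and returning a Dict{Any, Any} keyed by the accessed entry is an assumption modeled on the gradient shape shown in the REPL output further down this thread. It is not a claim about how Zygote implements this internally.

```julia
# Hypothetical helper (not part of Zygote): returns the looked-up value
# together with a hand-written pullback for `d[key]`.
function getindex_with_pullback(d::AbstractDict, key)
    value = d[key]
    function pullback(cotangent)
        # Assumption: the gradient w.r.t. the dictionary is itself a
        # dictionary holding the incoming cotangent at the accessed key.
        grad = Dict{Any, Any}(key => cotangent)
        # The key is not differentiable, so its gradient is `nothing`.
        return (grad, nothing)
    end
    return value, pullback
end

store = Dict(:x => 1.0)
val, back = getindex_with_pullback(store, :x)
store_grad, key_grad = back(2.5)

println(val)
println(store_grad)
println(key_grad)
```

The point of the sketch is only that the dictionary's gradient slot carries a keyed structure, rather than the bare scalar returned by the one-line adjoint above.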
@DhairyaLGandhi it's not Zygote's version of Something else is happening in the pipeline.
PS This is fixed on
Are you suggesting that the gradient is correctly calculated but isn't actually returned to the user properly?
What's happening is entirely unclear to me. Since it's
@DhairyaLGandhi when I print out
@DhairyaLGandhi @willtebbutt any update on this? This is highly frustrating to me. I can't update to the latest version of Zygote, so I can't use the latest version of IRTools, so I can't use the latest version of Flux, which means I can't use neural networks in my PPs. I have no idea where this bug is occurring, but I'm motivated to find it and fix it - especially since it was fixed before in
Set up a PR. I don't know what I'm doing, so I don't know if this fix breaks many other things - please inform.
@willtebbutt @DhairyaLGandhi did this happen to get squashed in recent tags/PRs?
Hmmm I'm not sure. @DhairyaLGandhi is more likely to know.
This problem is still present
julia> d = Dict("x"=>rand(2))
Dict{String, Vector{Float64}} with 1 entry:
  "x" => [0.626974, 0.519716]

julia> gradient(x -> sum(x["x"]), d) #OK
(Dict{Any, Any}("x" => 2-element Fill{Float64}: entries equal to 1.0),)

julia> nt = (; data=rand(2))
(data = [0.7536687262661153, 0.34819635465370324],)

julia> gradient(x -> sum(x.data), nt) #OK
((data = 2-element Fill{Float64}: entries equal to 1.0,),)

julia> ntd = (; data = Dict("x" => rand(2)))
(data = Dict("x" => [0.6917549230112572, 0.16463696222948876]),)

julia> gradient(x -> sum(x.data["x"]), ntd) #WRONG
(nothing,)
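For reference, the value the failing case should produce can be cross-checked without any AD package. The following is a minimal sketch using central finite differences. The helper fd_gradient and the fixed input [0.3, 0.7] are made up for illustration and are not from the session above. Since the loss is a plain sum, every entry of the gradient should come out close to 1, matching the two cases marked OK.

```julia
# Hypothetical helper: central finite-difference gradient of `f` at `v`.
function fd_gradient(f, v::Vector{Float64}; h = 1e-6)
    g = similar(v)
    for i in eachindex(v)
        vp = copy(v); vp[i] += h   # perturb entry i upward
        vm = copy(v); vm[i] -= h   # perturb entry i downward
        g[i] = (f(vp) - f(vm)) / (2h)
    end
    return g
end

# Same shape as `ntd` in the session above, with fixed values for repeatability.
ntd = (; data = Dict("x" => [0.3, 0.7]))
loss(nt) = sum(nt.data["x"])

# Rebuild the nested container around a perturbed vector before calling `loss`.
wrapped(v) = loss((; data = Dict("x" => v)))

g = fd_gradient(wrapped, ntd.data["x"])
println(g)
```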
Came across this issue and I see all MWEs passing with #1248. If anyone still has a larger example to test, could you confirm it passes as well? Otherwise I'll consider this issue fixed if nothing pops up after a few days.
Closing as all examples are fixed. Will add tests.
the gradient w.r.t. the y element of x should be 1.
This bug doesn't occur with the equivalent closure-free function
and appears to be Dict-specific since
This bug was introduced in 0.4.21 -- the correct result is obtained on 0.4.20. The bug persists on 0.4.22 and 0.5.
This is breaking for Stheno.jl.
@MikeInnes @CarloLucibello any thoughts on what might be causing this?