Zygote hangs when taking explicit gradients of NaiveGAFlux model #1243
Comments
Unfortunately I can't seem to terminate the program gracefully enough to get a stacktrace. Here is a WE (notice the absence of M) where I manually created the model layer by layer until the gradient calculation stalled. I haven't run this one for 8 hours though, so maybe that is not enough.

Packages:

```
]add NaiveNASflux#942bc90, [email protected], NaiveNASlib, ChainRulesCore, Flux, Functors
```

That NaiveNASflux commit is from this PR, where I have removed the Zygote adjoint, which accidentally also masked the problem due to usage of rrule definitions:

```julia
using NaiveNASflux, NaiveNASlib.Extend, Flux
import Functors
import ChainRulesCore
import ChainRulesCore: RuleConfig, HasReverseMode, rrule, rrule_via_ad, NoTangent

# Need to pirate NaiveNASlib's forward pass because it is not Zygote compatible (it uses get!)
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(NaiveNASlib.output!), memo, v)
    rrule_via_ad(config, output_rrule!, memo, v)
end

# This is just for logging and so we can return NoTangent instead of the computed gradient
function output_rrule!(args...) end
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(output_rrule!), memo, v)
    res, back = rrule_via_ad(config, _output_rrule!, memo, v)
    @info "Forward $(name(v))"
    return res, function (d)
        @info "Backward $(name(v))"
        back(d)
        # Uncomment the line below to prevent the stall
        #return NoTangent(), NoTangent(), NoTangent()
    end
end

# This is the actual Zygote compatible forward pass
function _output_rrule!(memo, v::AbstractVertex)
    v in keys(memo) && return memo[v]
    inpt = map(iv -> output_rrule!(memo, iv), inputs(v))
    memo[v] = v(inpt...)
end
```

Logging of the forward and backward passes is not needed to trigger the stall; just remove the `@info` lines if they bother you.

Model definition:

```julia
function makemodel(;layerfun=identity)
    iv = conv2dinputvertex("in", 3)
    v1 = convvertex("v1", (5,5), iv, 8; layerfun)
    v2 = bnvertex("v2", v1, selu; layerfun)
    v3 = fluxvertex("v3", MeanPool((2,2)), v2; layerfun)

    # Fork with 3 paths a, b and c
    v4a1 = bnvertex("v4a1", v3, relu; layerfun)
    v4a2 = convvertex("v4a2", (1, 7), v4a1, 8, selu; layerfun)

    v4b1 = convvertex("v4b1", (3,3), v3, 8, relu; layerfun)
    v4b2 = bnvertex("v4b2", v4b1; layerfun)
    v4b3 = convvertex("v4b3", (7,3), v4b2, 256; layerfun)

    v4c1 = convvertex("v4c1", (3,3), v3, 8, relu; layerfun)
    v4c2 = bnvertex("v4c2", v4c1; layerfun)
    v4c3 = convvertex("v4c3", (1,7), v4c2, 16, relu; layerfun)
    v4c4 = bnvertex("v4c4", v4c3, relu; layerfun)

    v5 = concat("v5", v4a2, v4b3, v4c4; layerfun)
    v6 = fluxvertex("v6", MaxPool((2,2)), v5; layerfun)
    v7 = convvertex("v7", (3,3), v6, 32; layerfun)
    v8 = bnvertex("v8", v7, selu; layerfun)
    v9 = convvertex("v9", (5,3), v8, 512, selu; layerfun)
    v10 = bnvertex("v10", v9; layerfun)
    v11 = fluxvertex("v11", Conv((2,2), nout(v10) => 512, relu; stride=2), v10; layerfun)
    CompGraph(iv, v11)
end

function convvertex(name, ks, in, outsize, act=identity; layerfun)
    fluxvertex(name, Conv(ks, nout(in) => outsize, act; pad=SamePad()), in; layerfun)
end

bnvertex(name, in, act=identity; layerfun) = fluxvertex(name, BatchNorm(nout(in), act), in; layerfun)
```

I made some attempts at removing parts of the model to make it simpler, but nothing exhaustive. Here is one example of how to generate an overview of the model if needed:

```julia
[name.(vertices(model)) layer.(vertices(model)) map(v -> name.(inputs(v)), vertices(model))]
```

Some utilities for experiments:

```julia
# Triggers issue #1111
mutable struct MutableWrapper{T}
    wrapped::T
end
(m::MutableWrapper)(x...) = m.wrapped(x...)
NaiveNASflux.layertype(g::MutableWrapper) = NaiveNASflux.layertype(g.wrapped)
NaiveNASlib.nout(g::MutableWrapper) = nout(g.wrapped)

# Removes output vertices, making the CompGraph structure non-cyclic
stripoutputs(g::CompGraph) = Functors.fmap(identity, g; walk=stripoutputs)
stripoutputs(f, x) = Functors._default_walk(f, x)
stripoutputs(f, v::InputVertex) = Functors._default_walk(f, v)
stripoutputs(f, v::AbstractVertex) = stripoutputs(f, base(v))
stripoutputs(f, v::CompVertex) = Functors._default_walk(f, v)
```

Phew, here is the experiment code. Logging output is omitted for brevity. Forward and Backward are always printed for each vertex.

```julia
# This terminates despite BatchNorm having NoTangent. Gradients are given for all layers except BatchNorms
gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel());

# This also terminates, despite the MutableWrapper erasing all gradients
gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel(layerfun=MutableWrapper) |> stripoutputs);

# But this stalls (unless you uncomment the return statement in the rrule definition)
gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel(layerfun=MutableWrapper));
```

It could be that the outputs are a red herring, and that the middle example only works because the structure is simpler overall. This works, however:

```julia
gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel(layerfun=reduce(∘, Iterators.repeated(MutableWrapper,10))) |> stripoutputs);
```

Note that the result is the same with implicit gradients for all examples above.
This mutable example FluxML/Flux.jl#1986 (comment) can, I think, be simplified to this, and has been broken at least since Zygote v0.6.0:

```julia
julia> gradient(x -> x[], Ref(1.0))
(Base.RefValue{Any}((x = 1.0,)),) # v0.6.0
((x = 1.0,),) # v0.6.41

julia> gradient(x -> x[1][], (Ref(1.0),))
(nothing,)
```

That Line 108 in cb59b6c
Line 230 in cb59b6c
The I am not sure why it was ever there. I worry a little about introducing double-counting between the updated Ref version and the returned version. But in the examples I can invent, one seems to matter for implicit mode, the other for explicit.
Nice bisect. With #1248:

```julia
julia> gradient(x -> x[1][], (Ref(1.0),))
(((x = 1.0,),),)
```

Edit: #1243 (comment) appears to terminate locally as well with the aforementioned PR. @DrChainsaw do you mind checking on your side?
I see. That PR sounds like a safer way to get the same effect, as this line will (I think) mean it never returns nothing in

Perhaps tests from my attempt might be worth borrowing: mcabbott@927ee27
@ToucheSir Will do, hopefully later tonight when I get some time off.

Bit of an unrelated shower thought: Zygote ranks pretty high on my list of "I give up" codebases and I'm sure many others feel the same (I think you guys are heroic for making the effort to maintain it). Anyways, the concrete proposal is to add a strong wish for more comments to the contributor guidelines. Something along the lines of:

Zygote is a very complex project maintained by the Julia community, as its original creator is no longer maintaining it. Please help us make it more maintainable by adding comments whenever you have figured out the purpose of some (undocumented) part. Even if you are not certain, an "I think the purpose of this code is..." type of comment can often help immensely.

For example, I suspect that you two have some insights into the purpose of
I think the big problem is that the current group of maintainers also falls into this bucket. Speaking for myself, I don't want to even think about AD, let alone touch it. Unfortunately, I have to because it ends up causing issues further up the stack. For Zygote specifically, it's basically on life support unless we can get someone who understands the compiler well enough and is also motivated to revive it. Presently, bugs like #1236 just pile up and we're kind of helpless to do much about them.

That's why, though I try to add some comments while creating new PRs, I am hesitant to go on a docs spree across the codebase. When there appears to be zero appetite from the broader community for helping with Zygote maintenance, every non-bugfix feels like another sunk cost. Perhaps others have a different perspective, but from my POV we've somehow arrived at an XKCD #2347 type scenario where the de facto reverse mode AD is a sinking ship and there are no ready alternatives to pick up the slack.

Stepping off my soapbox, internal documentation is now tracked in #1274. If anyone wants to take a shot, I am more than happy to prioritize reviewing PRs for this. Otherwise, at least there's a list now.
This is my view as well. I was perhaps just a bit too careful when wording it to the point that the message became unclear.
Fully understandable. I tried to propose a somewhat milder and more distributed approach: whenever someone spends more than five minutes figuring out some undocumented part of Zygote, instead of just fixing the issue and moving on, they add some comments describing their understanding of it. Over time the codebase might then become a bit more approachable.

About the actual issue: strangely enough, the WE does not seem to hang anymore, despite using the exact same project and manifest. There is, however, this, which may be a hint as to what the hang is about:

```julia
# "Hanging" example with all gradients being nothing
julia> gg = gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel(layerfun=MutableWrapper));

julia> @time @show gg[1];
# Takes about 373 seconds, then prints about 20 lines of output
373.533915 seconds (180.54 k allocations: 11.777 MiB, 99.99% compilation time)

# Non hanging example where we get gradients in a fresh session
julia> ggg = gradient(m -> sum(m(ones(Float32,32,32, 3,1))), makemodel() |> stripoutputs);

julia> @time @show ggg[1];
# Starts printing right away, but spends about 37 seconds printing numbers
37.013652 seconds (14.55 M allocations: 578.515 MiB, 0.34% gc time, 0.35% compilation time)
```

Note that I started from one of the innermost gradients ( Perhaps it is just the horrible nested I might have started working on the example when printing gradients and then added output suppression without checking carefully that this didn't change the outcome. Anyways, with #1248 I get the exact same behaviour between the two gradients, so from what I can tell it solves the issue.
Continuing from FluxML/Flux.jl#1986 (comment). @DrChainsaw are you able to capture a profile or at least a stacktrace mid-hang? I think that would be the easiest way to get started troubleshooting this and trying to put together a MWE.