rrule(broadcasted, ...) with BroadcastStyle skips many rules #663
My understanding of what ought to happen is that, in the absence of a This is how Zygote's own rules work at present. One reason to like this setup for Besides 3rd party rules like NNlib, at present Yota also doesn't call the simple rules for +, -, *, / etc. These are lazy, to fuse the forward pass, but instead the generic split (i.e. unfused) rules are used:
julia> using Yota, ChainRules
julia> ENV["JULIA_DEBUG"] = ChainRules;
julia> grad(x -> sum(sin.(x .- 1)), [1,2,3.0])
┌ Debug: split broadcasting derivatives
│ f = - (generic function with 253 methods)
│ N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:72
┌ Debug: split broadcasting forwards
│ frule_fun = frule (generic function with 773 methods)
│ f = sin (generic function with 14 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:107
(1.7507684116335782, (ChainRulesCore.ZeroTangent(), [1.0, 0.5403023058681398, -0.4161468365471424]))
pkg> st Yota ChainRules
Status `~/.julia/environments/v1.9/Project.toml`
[082447d4] ChainRules v1.44.2
[cd998857] Yota v0.7.4 |
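As a sanity check on the numbers in the transcript above, the gradient can be verified by hand regardless of which broadcasting rule fires: the derivative of sum(sin.(x .- 1)) with respect to each x_i is cos(x_i - 1). The following is a minimal sketch in plain Julia with no AD packages; the variable names are illustrative and not taken from Yota or ChainRules.

```julia
# Hand-written derivative check for x -> sum(sin.(x .- 1)); no AD packages involved.
x = [1, 2, 3.0]

value = sum(sin.(x .- 1))    # forward pass
gradient = cos.(x .- 1)      # d/dx_i of sum(sin.(x .- 1)) is cos(x_i - 1)

println(value)
println(gradient)
```

Both the fused and the split rules should agree with these values; only the debug messages differ.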
It makes sense, thank you. The reason Yota doesn't trace |
Ok, sorry about the extra work, I never looked closely at how the tracing works. I guess the method is one of these Maybe standalone |
One more case which may need to keep going is a call to
julia> using Yota, ChainRules
julia> ENV["JULIA_DEBUG"] = ChainRules;
julia> Yota.grad(xs -> sum(abs.(xs)), [1,2,3.0])
┌ Debug: split broadcasting derivative
│ f = abs (generic function with 11 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:59
(6.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0]))
julia> Yota.grad(xs -> sum(broadcast(abs, xs)), [1,2,3.0])
ERROR: No deriative rule found for op %3 = broadcast(abs, %2)::Vector{Float64} , ... |
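The difference between the two calls above can be seen in plain Julia. Dot syntax lowers to materialize(broadcasted(...)), whereas broadcast(f, xs) is an ordinary function call that reaches the same result internally. A tracer that only intercepts broadcasted would presumably see broadcast as a separate primitive, which is one possible reading of the error above (an assumption, not something confirmed in this thread). A minimal sketch using only Base:

```julia
using Base.Broadcast: Broadcasted, broadcasted, materialize

xs = [1, 2, 3.0]

# Dot syntax `abs.(xs)` lowers to materialize(broadcasted(abs, xs)).
lazy = broadcasted(abs, xs)    # lazy Broadcasted object, nothing computed yet
dotted = materialize(lazy)     # same result as abs.(xs)

# `broadcast(abs, xs)` is a regular function call that produces the same array.
called = broadcast(abs, xs)

println(dotted == called)
```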
Marking Just for reference, here's what I get now for the examples in this thread:
julia> using Yota, ChainRules
[ Info: Precompiling Yota [cd998857-8626-517d-b929-70ad188a48f0]
julia> ENV["JULIA_DEBUG"] = ChainRules;
julia> grad(x -> sum(sin.(x .- 1)), [1,2,3.0])
┌ Debug: broadcasting: minus 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:170
┌ Debug: split broadcasting forwards
│ frule_fun = frule (generic function with 773 methods)
│ f = sin (generic function with 14 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:107
(1.7507684116335782, (ChainRulesCore.ZeroTangent(), [1.0, 0.5403023058681398, -0.4161468365471424]))
julia> Yota.grad(xs -> sum(abs.(xs)), [1,2,3.0])
┌ Debug: split broadcasting derivative
│ f = abs (generic function with 11 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:59
(6.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0]))
julia> Yota.grad(xs -> sum(broadcast(abs, xs)), [1,2,3.0])
┌ Debug: split broadcasting derivative
│ f = abs (generic function with 11 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:59
(6.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0])) |
To account for #644 in Yota, I change all calls like this:
to this:
However, this way many rules are not triggered anymore. A particular example I encountered is activation functions from NNlib, e.g.:
I think I can check whether rrule() without the BroadcastStyle exists before adding it to the call, but I was wondering if somebody else encountered this problem and how they solved it.
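The check described above could look roughly like the sketch below. It relies on the ChainRulesCore convention that rrule returns nothing when no rule matches. The names myrelu and broadcast_rrule are hypothetical stand-ins (for a third-party activation such as those in NNlib, and for whatever lookup Yota performs), and the sketch does not pass a RuleConfig, which the style-aware rules in ChainRules may require.

```julia
using ChainRulesCore
using Base.Broadcast: broadcasted, combine_styles

# Hypothetical activation standing in for a third-party function.
myrelu(x) = max(x, zero(x))

# Hypothetical third-party rule written WITHOUT a BroadcastStyle argument.
function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(myrelu), x::AbstractArray)
    y = myrelu.(x)
    myrelu_pullback(dy) = (NoTangent(), NoTangent(), unthunk(dy) .* (x .> 0))
    return y, myrelu_pullback
end

# Hypothetical helper: try the style-less rule first, and only add the
# BroadcastStyle argument when no such rule exists.
function broadcast_rrule(f, args...)
    res = rrule(broadcasted, f, args...)
    res === nothing || return res
    style = combine_styles(args...)
    return rrule(broadcasted, style, f, args...)
end

y, back = broadcast_rrule(myrelu, [-1.0, 2.0])
println(y)
println(back([1.0, 1.0])[3])
```

With this ordering, specialised rules keep priority over the generic style-aware ones. Whether that is the right priority in every case is a design choice left open here.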