rrule(broadcasted, ...) with BroadcastStyle skips many rules #663

Closed
dfdx opened this issue Aug 18, 2022 · 5 comments
@dfdx
Contributor

dfdx commented Aug 18, 2022

To account for #644 in Yota, I changed all calls like this:

rrule(cfg, broadcasted, f, args...)

to this:

rrule(cfg, broadcasted, bcast_style, f, args...)

However, this way many rules are not triggered anymore. A particular example I encountered is activation functions from NNlib, e.g.:

julia> @which rrule(YotaRuleConfig(), broadcasted, leakyrelu, x, 0.2f0)
rrule(::RuleConfig, args...) in ChainRulesCore at /home/azbs/.julia/packages/ChainRulesCore/ctmSK/src/rules.jl:134
# ^ just invokes the same without the config

julia> @which rrule(broadcasted, leakyrelu, x, 0.2f0)
rrule(::typeof(Base.Broadcast.broadcasted), ::typeof(leakyrelu), x1::Union{AbstractArray{<:T}, T} where T<:Number, x2::Number) in NNlib at /home/azbs/.julia/packages/NNlib/0QnJJ/src/activations.jl:909
# ^ correctly points to rrule defined in NNlib

julia> @which rrule(YotaRuleConfig(), broadcasted, Base.Broadcast.DefaultArrayStyle{2}(), leakyrelu, x, 0.2f0)
rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N} in ChainRules at /home/azbs/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:29
# ^ hits generic broadcasting, bypassing NNlib rrules

I think I can check whether an rrule() without the BroadcastStyle exists before adding the style to the call, but I was wondering whether anybody else has encountered this problem and how they solved it.
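A minimal sketch of the check described above, relying on ChainRulesCore's fallback rrule returning nothing when no rule matches. DemoConfig, myrelu and bcast_rrule are hypothetical names made up for illustration; they are not part of Yota, NNlib or ChainRules.

```julia
using ChainRulesCore
using Base.Broadcast: broadcasted, broadcastable, combine_styles

# Hypothetical config standing in for YotaRuleConfig.
struct DemoConfig <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} end

# Hypothetical activation with a "simple" broadcast rule (no config, no style),
# in the same shape as the NNlib rules shown above.
myrelu(x) = max(x, zero(x))
function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(myrelu), x::AbstractArray{<:Real})
    y = myrelu.(x)
    myrelu_pullback(dy) = (NoTangent(), NoTangent(), unthunk(dy) .* (x .> 0))
    return y, myrelu_pullback
end

# Try the style-free signature first; add the style only if nothing matched.
function bcast_rrule(cfg::RuleConfig, f, args...)
    res = rrule(cfg, broadcasted, f, args...)
    res === nothing || return res
    bargs = map(broadcastable, args)
    return rrule(cfg, broadcasted, combine_styles(bargs...), f, bargs...)
end

y, back = bcast_rrule(DemoConfig(), myrelu, [-1.0, 2.0])
```

Note that this probe actually runs the rule's forward pass to find out whether it exists, so in a tracer the result should be reused rather than recomputed.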

@mcabbott
Member

My understanding of what ought to happen is that, in the absence of an rrule whose signature matches, the AD should keep tracing. Base will dispatch from broadcasted(f, x, 0.2, s) down to something like broadcasted(DefaultArrayStyle{2}(), f, x, 0.2, Ref(s)), at which point there is a matching rule.

This is how Zygote's own rules work at present. One reason to like this setup for rrule is that the generic rule needs the first ::RuleConfig argument, while simple rules (like those in NNlib) do not. rrule dispatch is arranged so that methods without the RuleConfig are called only if no method with it matches. So if the generic rule did not take the BroadcastStyle, every simple rule would have to accept a RuleConfig even though it doesn't need one.
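The dispatch order described here can be seen in a small sketch using only ChainRulesCore. DemoConfig, double, triple and config_calls are hypothetical names for illustration: double has only a config-free rule, while triple has both kinds, with a counter recording when the config-aware one runs.

```julia
using ChainRulesCore

# Hypothetical config for a reverse-mode-only AD.
struct DemoConfig <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} end

double(x) = 2x   # has only a config-free rule
triple(x) = 3x   # has both a config-free and a config-aware rule

ChainRulesCore.rrule(::typeof(double), x::Real) = (double(x), dy -> (NoTangent(), 2dy))
ChainRulesCore.rrule(::typeof(triple), x::Real) = (triple(x), dy -> (NoTangent(), 3dy))

# Counts how often the config-aware rule is selected.
const config_calls = Ref(0)
function ChainRulesCore.rrule(::RuleConfig{>:HasReverseMode}, ::typeof(triple), x::Real)
    config_calls[] += 1
    return triple(x), dy -> (NoTangent(), 3dy)
end

# No config-aware method for double: the fallback drops the config.
y2, back2 = rrule(DemoConfig(), double, 1.0)

# A config-aware method exists for triple, so it wins over the config-free one.
y3, back3 = rrule(DemoConfig(), triple, 1.0)
```

After these two calls the counter is 1, since only the triple call reached a config-aware method.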

Besides third-party rules like those in NNlib, Yota at present also doesn't call the simple rules for +, -, *, / etc. Those rules are lazy, so that the forward pass stays fused; instead, the generic split (i.e. unfused) rules are used:

julia> using Yota, ChainRules

julia> ENV["JULIA_DEBUG"] = ChainRules;

julia> grad(x -> sum(sin.(x .- 1)), [1,2,3.0])
┌ Debug: split broadcasting derivatives
│   f = - (generic function with 253 methods)
│   N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:72
┌ Debug: split broadcasting forwards
│   frule_fun = frule (generic function with 773 methods)
│   f = sin (generic function with 14 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:107
(1.7507684116335782, (ChainRulesCore.ZeroTangent(), [1.0, 0.5403023058681398, -0.4161468365471424]))

pkg> st Yota ChainRules
Status `~/.julia/environments/v1.9/Project.toml`
  [082447d4] ChainRules v1.44.2
  [cd998857] Yota v0.7.4

@dfdx
Contributor Author

dfdx commented Aug 18, 2022

That makes sense, thank you. Yota doesn't trace broadcasted because it's defined in one of the built-in modules and is thus considered a primitive. Making it non-primitive is quite complicated in the current design but seems to be the most robust solution. I'll run an experiment and get back to you.

@mcabbott
Member

Ok, sorry about the extra work, I never looked closely at how the tracing works.

I guess the relevant method is one of those shown by @less broadcasted(atan, [1,2,3], 4) (plus a few keyword ones). They immediately call the method with the style, just one hop, and it looks like this hasn't changed in forever.
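The "one hop" can be reproduced by hand with Base alone. This is only a sketch of what the style-free method does (make the arguments broadcastable, combine their styles, then call the styled method), not the literal source; the variable names are made up for illustration.

```julia
using Base.Broadcast: broadcasted, broadcastable, combine_styles, materialize, DefaultArrayStyle

x = [1, 2, 3]
args = map(broadcastable, (x, 4))
style = combine_styles(args...)   # style for a Vector combined with a scalar

lazy_short = broadcasted(atan, x, 4)              # style-free entry point
lazy_styled = broadcasted(style, atan, args...)   # the method it forwards to

# Both routes build the same lazy Broadcasted object and the same result.
same_type = typeof(lazy_short) == typeof(lazy_styled)
same_values = materialize(lazy_short) == materialize(lazy_styled)
```

So a tracer that steps into the style-free method arrives, one call later, at a signature that the generic styled rule can match.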

Maybe standalone .+ needs thought though, that's newer.

@mcabbott
Member

One more case which may need to keep going is a call to broadcast spelled out:

julia> using Yota, ChainRules

julia> ENV["JULIA_DEBUG"] = ChainRules;

julia> Yota.grad(xs -> sum(abs.(xs)), [1,2,3.0])
┌ Debug: split broadcasting derivative
│   f = abs (generic function with 11 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:59
(6.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0]))

julia> Yota.grad(xs -> sum(broadcast(abs, xs)), [1,2,3.0])
ERROR: No deriative rule found for op %3 = broadcast(abs, %2)::Vector{Float64} , ...
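For context, a Base-only sketch of why tracing into broadcast is enough: broadcast(f, As...) materializes the same lazy object that broadcasted(f, As...) builds, so stepping into it reaches the broadcasted call that the rules target. The variable names are illustrative.

```julia
using Base.Broadcast: broadcasted, materialize

xs = [1, -2, 3.0]

via_broadcast = broadcast(abs, xs)                # spelled-out call
via_lazy = materialize(broadcasted(abs, xs))      # what it lowers to, roughly
via_dot = abs.(xs)                                # dot syntax, for comparison
```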

@mcabbott mcabbott added the Yota label Aug 20, 2022
@dfdx
Contributor Author

dfdx commented Aug 22, 2022

Marking broadcast() and broadcasted() without the style argument as traceable resolved the issue, thanks for the tip!

Just for reference, here's what I get now for the examples in this thread:

julia> using Yota, ChainRules
[ Info: Precompiling Yota [cd998857-8626-517d-b929-70ad188a48f0]

julia> ENV["JULIA_DEBUG"] = ChainRules;

julia> grad(x -> sum(sin.(x .- 1)), [1,2,3.0])
┌ Debug: broadcasting: minus 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:170
┌ Debug: split broadcasting forwards
│   frule_fun = frule (generic function with 773 methods)
│   f = sin (generic function with 14 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:107
(1.7507684116335782, (ChainRulesCore.ZeroTangent(), [1.0, 0.5403023058681398, -0.4161468365471424]))

julia> Yota.grad(xs -> sum(abs.(xs)), [1,2,3.0])
┌ Debug: split broadcasting derivative
│   f = abs (generic function with 11 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:59
(6.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0]))

julia> Yota.grad(xs -> sum(broadcast(abs, xs)), [1,2,3.0])
┌ Debug: split broadcasting derivative
│   f = abs (generic function with 11 methods)
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:59
(6.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0]))

@dfdx dfdx closed this as completed Aug 22, 2022