Opt out of CR broadcasting #1263
src/compiler/chainrules.jl
@@ -18,6 +18,9 @@ such that if a suitable rule is defined later, the generated function will recom
function has_chain_rrule(T)
  config_T, arg_Ts = Iterators.peel(T.parameters)
  configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...})

  isnothing(configured_rrule_m) && return false, nothing # too crude, surely
Something is wrong with the handling of methods produced by `@opt_out`. Without this line, you get the following error:
julia> gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),) # first test from "tricky broadcasting" in features.jl
ERROR: type Nothing has no field method
Stacktrace:
[1] getproperty(x::Nothing, f::Symbol)
@ Base ./Base.jl:37
[2] has_chain_rrule(T::Type)
@ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:24
[3] #s263#1180
@ ./compiler/interface2.jl:20 [inlined]
[4] var"#s263#1180"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[5] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:585
[6] _pullback
@ ./REPL[13]:1 [inlined]
[7] _pullback(ctx::Zygote.Context, f::var"#449#450", args::Tuple{Int64, Int64})
@ Zygote ./compiler/interface2.jl:0
[8] _pullback(f::Function, args::Tuple{Int64, Int64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:34
[9] pullback(f::Function, args::Tuple{Int64, Int64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:40
[10] gradient(f::Function, args::Tuple{Int64, Int64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:75
[11] top-level scope
@ REPL[13]:1
Paging @mzgubic who I think knows about this story. What's the right way to handle this `nothing`?
Hmm, I've forked your branch and commented out the line above, but can't reproduce this behaviour.
(opt_out_broadcasting) pkg> st
Status `~/JuliaEnvs/test/opt_out_broadcasting/Project.toml`
[082447d4] ChainRules v1.39.1
[d360d2e6] ChainRulesCore v1.15.3 `dev/ChainRulesCore`
[e88e6eb3] Zygote v0.6.41 `dev/Zygote`
julia> using Zygote
julia> using ChainRulesCore
julia> using ChainRules
julia> gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
true
Right, I see no error in this configuration. ChainRules v1.39.1 is the tagged version, which does not have a rule matching this PR's `@opt_out`.
The error occurs when using this PR together with JuliaDiff/ChainRules.jl#644 (now rebased onto 1.39.1 too), and it still happens for me.
Ah, wow, missed the context entirely. In any case, I think what is happening is:
julia> config_T = Zygote.ZygoteRuleConfig{Zygote.Context};
julia> arg_Ts = Any[typeof(Base.Broadcast.broadcasted), typeof(+), Tuple{Int64, Int64}, Matrix{Float64}];
julia> configured_rrule_m = IRTools.meta(Tuple{typeof(rrule), config_T, arg_Ts...})
# should not be nothing
# and the reason is the method error here
julia> rrule(Zygote.ZygoteRuleConfig(), Base.Broadcast.broadcasted, +, (1, 2), rand(2, 2))
ERROR: MethodError: rrule(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::typeof(Base.Broadcast.broadcasted), ::typeof(+), ::Tuple{Int64, Int64}, ::Matrix{Float64}) is ambiguous. Candidates:
rrule(::RuleConfig{>:HasReverseMode}, ::typeof(Base.Broadcast.broadcasted), ::typeof(+), xs::Union{Number, Base.Broadcast.Broadcasted, AbstractArray{<:Number}, Tuple{Vararg{Number, var"#s3204"}} where var"#s3204"}...) in ChainRules at /Users/mzgubic/JuliaEnvs/test/opt_out_broadcasting/dev/ChainRules/src/rulesets/Base/broadcast.jl:157
rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Base.Broadcast.broadcasted), f::F, args::Vararg{Any, N}) where {F, N} in Zygote
Possible fix, define
rrule(::Zygote.ZygoteRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(+), ::Vararg{Union{Number, Base.Broadcast.Broadcasted, AbstractArray{<:Number}, Tuple{Vararg{Number, var"#s3204"}} where var"#s3204"}, N}) where N
Stacktrace:
[1] top-level scope
@ REPL[54]:1
Perhaps we could throw a MethodError directly, e.g. as in https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/6925da14c12e3d743c8d3620db8a8bee1433d5c3/src/testers.jl#L128
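A minimal sketch of that idea, assuming nothing about Zygote's internals: `checked_lookup` and its `lookup` argument are hypothetical names, with `lookup` standing in for whatever returns the method metadata (or `nothing`). Rather than letting a later field access on `nothing` fail, the wrapper raises a `MethodError` that names the function and the argument types.

```julia
# Hypothetical helper, not part of Zygote, IRTools or ChainRulesCore.
# `lookup` stands in for a metadata lookup that returns `nothing` when it
# cannot resolve a unique method for `argtypes`.
function checked_lookup(lookup, f, argtypes::Type{<:Tuple})
    m = lookup(argtypes)
    if isnothing(m)
        # Fail early, pointing at the function and signature that could not
        # be resolved, instead of a later "type Nothing has no field" error.
        throw(MethodError(f, argtypes))
    end
    return m
end
```

Whether failing loudly or returning `false, nothing` is the better choice inside `has_chain_rrule` depends on whether a missing method is ever expected there; the thread leaves that open.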
Thanks, I didn't realise there was an ambiguity. A nicer error would have helped me, but...
The ambiguity is this: I want to skip the "generic" CR rule, so I call this, and the new method is less specific than some other `rrule`s for specific operations like `+`:
julia> @macroexpand1 ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
quote
((ChainRulesCore).no_rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any, N}) where {F, N}) = ChainRulesCore.nothing
((ChainRulesCore).rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any, N}) where {F, N}) = ChainRulesCore.nothing
end
Those more specific rules aren't so different from Zygote's, so it may not matter whose rule is called. But ideally this PR should probably opt out of all of them.
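The same shape of ambiguity can be reproduced without any packages. In this sketch every name (`AbstractConfig`, `MyConfig`, `bcast`, `rule`) is a hypothetical stand-in, not an identifier from Zygote or ChainRules: one method is generic in the config but specific in the function, the other is specific in the config but generic in the function, so neither is more specific than the other.

```julia
# Hypothetical stand-ins for the rule config types and `broadcasted`.
abstract type AbstractConfig end
struct MyConfig <: AbstractConfig end
bcast(args...) = nothing

# "Library" rule: any config, but only for `+` on numbers.
rule(::AbstractConfig, ::typeof(bcast), ::typeof(+), xs::Number...) = :library_rule

# "Opt-out" method: one specific config, but any function and arguments.
rule(::MyConfig, ::typeof(bcast), f::F, args::Vararg{Any,N}) where {F,N} = nothing

# Both methods match this call and neither dominates, so dispatch fails.
ambiguous = try
    rule(MyConfig(), bcast, +, 1, 2)
    false
catch err
    err isa MethodError
end

# With a different function only the opt-out method applies.
other = rule(MyConfig(), bcast, *, 1, 2)

# A method at least as specific as both candidates removes the ambiguity.
rule(::MyConfig, ::typeof(bcast), ::typeof(+), xs::Number...) = nothing
resolved = rule(MyConfig(), bcast, +, 1, 2)
```

The last definition corresponds to opting out of one specific rule as well, which has to be repeated for every such rule that exists.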
No, this won't work. We can't opt out of e.g. the broadcasting rules NNlib defines. We need another plan here.
This reverts commit 10a2186.
This adds an `@opt_out` so as not to use the broadcasting rule defined in JuliaDiff/ChainRules.jl#644. Without this, the `rrule` for `broadcasted(f, args...)` is applied before the `@adjoint` rule for `broadcasted(::AbstractArrayStyle, f, args...)`. And then many tests here fail.
To stop the new `rrule` being used on older Zygote versions, I think they will need an upper bound on CR added to the registry.