Opt out of CR broadcasting #1263

Closed
wants to merge 5 commits into from

mcabbott
Member

@mcabbott mcabbott commented Jul 14, 2022

This adds an @opt_out so as not to use the broadcasting rule defined in JuliaDiff/ChainRules.jl#644.

Without this, the rrule for broadcasted(f, args...) is applied before the @adjoint rule for broadcasted(::AbstractArrayStyle, f, args...), and many tests here then fail.

To stop the new rrule being used on older Zygote versions, I think those versions will need an upper bound on ChainRules (CR) added to the registry.

@@ -18,6 +18,9 @@ such that if a suitable rule is defined later, the generated function will recom
function has_chain_rrule(T)
  config_T, arg_Ts = Iterators.peel(T.parameters)
  configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...})

  isnothing(configured_rrule_m) && return false, nothing # too crude, surely
Member Author

Something is wrong with the handling of methods produced by @opt_out. Without this line, you get the following error:

julia> gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)  # first test from "tricky broadcasting" in features.jl
ERROR: type Nothing has no field method
Stacktrace:
  [1] getproperty(x::Nothing, f::Symbol)
    @ Base ./Base.jl:37
  [2] has_chain_rrule(T::Type)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:24
  [3] #s263#1180
    @ ./compiler/interface2.jl:20 [inlined]
  [4] var"#s263#1180"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [5] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:585
  [6] _pullback
    @ ./REPL[13]:1 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::var"#449#450", args::Tuple{Int64, Int64})
    @ Zygote ./compiler/interface2.jl:0
  [8] _pullback(f::Function, args::Tuple{Int64, Int64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:34
  [9] pullback(f::Function, args::Tuple{Int64, Int64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:40
 [10] gradient(f::Function, args::Tuple{Int64, Int64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:75
 [11] top-level scope
    @ REPL[13]:1
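
A minimal sketch of this failure mode, with no Zygote involved. Here `lookup` is a hypothetical stand-in for the `meta` call, assumed to return `nothing` when it finds no usable method. Reading a field from that result fails in the same way as frame [1] of the stack trace, and an `isnothing` guard like the one in the diff avoids it.

```julia
# Hypothetical stand-in for `meta`: returns `nothing` when no usable method is found.
lookup(found::Bool) = found ? (method = :some_method,) : nothing

# Unguarded version: reads a field without checking for `nothing`.
unguarded(found) = lookup(found).method

# Guarded version, following the `isnothing(...) && return false, nothing` line.
function guarded(found)
    m = lookup(found)
    isnothing(m) && return false, nothing
    return true, m.method
end

# The unguarded version throws when the lookup fails.
threw = try
    unguarded(false)
    false
catch
    true
end
```

The guard only hides the symptom: it reports "no rule" without saying why the lookup failed.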

Member Author

Paging @mzgubic who I think knows about this story. What's the right way to handle this nothing?

Member Author

Bump... either @mzgubic or @oxinabox wrote this, I think.

Collaborator

Hmm, I've forked your branch and commented out the line above, but can't reproduce this behaviour.

(opt_out_broadcasting) pkg> st
      Status `~/JuliaEnvs/test/opt_out_broadcasting/Project.toml`
  [082447d4] ChainRules v1.39.1
  [d360d2e6] ChainRulesCore v1.15.3 `dev/ChainRulesCore`
  [e88e6eb3] Zygote v0.6.41 `dev/Zygote`

julia> using Zygote

julia> using ChainRulesCore

julia> using ChainRules

julia> gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
true

Member Author

@mcabbott mcabbott Jul 26, 2022

Right, I see no error in this configuration. ChainRules v1.39.1 is the tagged version, which does not have a rule matching this PR's @opt_out.

The error occurs when using this PR together with JuliaDiff/ChainRules.jl#644 (now rebased onto 1.39.1 too), and it still happens for me.

Collaborator

@mzgubic mzgubic Jul 27, 2022

Ah, wow, missed the context entirely. In any case, I think what is happening is:

julia> config_T = Zygote.ZygoteRuleConfig{Zygote.Context};

julia> arg_Ts = Any[typeof(Base.Broadcast.broadcasted), typeof(+), Tuple{Int64, Int64}, Matrix{Float64}];

julia> configured_rrule_m = IRTools.meta(Tuple{typeof(rrule), config_T, arg_Ts...})
# should not be nothing

# and the reason is the method error here
julia> rrule(Zygote.ZygoteRuleConfig(), Base.Broadcast.broadcasted, +, (1, 2), rand(2, 2))
ERROR: MethodError: rrule(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::typeof(Base.Broadcast.broadcasted), ::typeof(+), ::Tuple{Int64, Int64}, ::Matrix{Float64}) is ambiguous. Candidates:
  rrule(::RuleConfig{>:HasReverseMode}, ::typeof(Base.Broadcast.broadcasted), ::typeof(+), xs::Union{Number, Base.Broadcast.Broadcasted, AbstractArray{<:Number}, Tuple{Vararg{Number, var"#s3204"}} where var"#s3204"}...) in ChainRules at /Users/mzgubic/JuliaEnvs/test/opt_out_broadcasting/dev/ChainRules/src/rulesets/Base/broadcast.jl:157
  rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Base.Broadcast.broadcasted), f::F, args::Vararg{Any, N}) where {F, N} in Zygote
Possible fix, define
  rrule(::Zygote.ZygoteRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(+), ::Vararg{Union{Number, Base.Broadcast.Broadcasted, AbstractArray{<:Number}, Tuple{Vararg{Number, var"#s3204"}} where var"#s3204"}, N}) where N
Stacktrace:
 [1] top-level scope
   @ REPL[54]:1

Perhaps we could throw a MethodError directly, e.g. as in https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/6925da14c12e3d743c8d3620db8a8bee1433d5c3/src/testers.jl#L128
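
A rough sketch of what throwing directly could look like, under stated assumptions: `toy_rule`, `find_rule` and `checked_rule` are hypothetical names, not Zygote or ChainRules functions, and `find_rule` only models the "no matching method" case rather than the ambiguity above. The point is just that the caller sees a `MethodError` naming the function and arguments, instead of a field access on `nothing`.

```julia
# Hypothetical rule with a single method, for illustration only.
toy_rule(x::Int) = :rule

# Hypothetical stand-in for the method lookup: `nothing` when no method matches.
find_rule(f, args) = hasmethod(f, typeof(args)) ? which(f, typeof(args)) : nothing

function checked_rule(f, args...)
    m = find_rule(f, args)
    # Surface the failed lookup as a MethodError for `f` and its arguments.
    isnothing(m) && throw(MethodError(f, args))
    return m
end

ok = checked_rule(toy_rule, 1)      # a Method object

err = try
    checked_rule(toy_rule, "one")   # no method for a String argument
catch e
    e
end
```

Whether this is safe to do inside Zygote's generated function is a separate question that this sketch does not address.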

Member Author

Thanks, I didn't realise there was an ambiguity. A nicer error would have helped me, but...

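The ambiguity is this: I want to skip the "generic" ChainRules rule, so I call @opt_out as shown here. The method it defines is more specific in its config argument (ZygoteRuleConfig) but less specific in its remaining arguments than some other rrules for particular operations like +: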

julia> @macroexpand1 ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
quote
    ((ChainRulesCore).no_rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any, N}) where {F, N}) = ChainRulesCore.nothing
    ((ChainRulesCore).rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any, N}) where {F, N}) = ChainRulesCore.nothing
end

Those more specific rules aren't so different from Zygote's, so it may not matter whose rule is called. But ideally this PR should probably opt out of all of them.
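
The same ambiguity can be reproduced in a few lines without any AD packages. This is a toy sketch: `AbstractConfig`, `ZConfig` and `toy_rrule` are made-up names standing in for `RuleConfig`, `ZygoteRuleConfig` and `rrule`. The third method has the shape of the "Possible fix" in the error message above.

```julia
abstract type AbstractConfig end
struct ZConfig <: AbstractConfig end   # plays the role of ZygoteRuleConfig

# "Library" rule: generic config, specific operation and argument types.
toy_rrule(::AbstractConfig, ::typeof(+), xs::Number...) = :library_rule

# "Opt-out" method: specific config, generic function and arguments.
toy_rrule(::ZConfig, f::F, args::Vararg{Any,N}) where {F,N} = nothing

# Neither method is more specific than the other, so this call is ambiguous.
first_call = try
    toy_rrule(ZConfig(), +, 1, 2.0)
catch err
    err
end

# A method that is more specific than both resolves the ambiguity.
toy_rrule(::ZConfig, ::typeof(+), xs::Vararg{Number,N}) where {N} = nothing

second_call = toy_rrule(ZConfig(), +, 1, 2.0)
```

Each specialised rule needs its own extra method like this, which is why opting out one rule at a time scales poorly.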

Member Author

No, this won't work. We can't opt out of e.g. the broadcasting rules NNlib defines. We need another plan here.

@mcabbott mcabbott added the ChainRules adjoint -> rrule, and further integration label Jul 22, 2022