Add opting out of rules #398
Changes from 4 commits
23ec91d
8fa0ecb
c542c2f
a8ffabc
a66f5e7
db35df7
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Opting out of rules | ||
|
||
It is common to define rules fairly generically. | ||
Often they are as generic as (or more generic than) the original primal method. | ||
Sometimes this is not the correct behaviour. | ||
Sometimes the AD can do better than this human-defined rule. | ||
If this is generally the case, then we should not have the rule defined at all. | ||
But if it is only the case for a particular set of types, then we want to opt out for just that set. | ||
This is done with the [`@opt_out`](@ref) macro. | ||
|
||
Consider an `rrule` for `sum` (the following is simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself): | ||
```julia | ||
function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:) | ||
    y = sum(x; dims=dims) | ||
    project = ProjectTo(x) | ||
    function sum_pullback(ȳ) | ||
        # broadcasting the two works out the size no matter what `dims` is | ||
        # project makes sure we stay in the same vector subspace as `x` | ||
        # (no putting in off-diagonal entries in Diagonal etc.) | ||
        x̄ = project(broadcast(last∘tuple, x, ȳ)) | ||
        return (NoTangent(), x̄) | ||
    end | ||
    return y, sum_pullback | ||
end | ||
``` | ||
|
||
That is a fairly reasonable `rrule` for the vast majority of cases. | ||
|
||
You might have a custom array type for which you could write a faster rule. | ||
For example, the pullback for summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`. | ||
To do that, you can indeed write another more specific [`rrule`](@ref). | ||
But another case is where the AD system itself would generate more optimized code. | ||
|
||
For example, the [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type. | ||
Its `sum` method basically just calls `sum` on its parent. | ||
It is entirely conceivable[^1] that the AD system can do better than our `rrule` here, | ||
for example by avoiding the overhead of [`project`ing](@ref ProjectTo). | ||
|
||
To opt out of the generic `rrule` and let the AD system do its own thing, we use the | ||
[`@opt_out`](@ref) macro to say that the rule should not be used for `sum` of `NamedDimsArray`s. | ||
|
||
```julia | ||
@opt_out rrule(::typeof(sum), ::NamedDimsArray) | ||
``` | ||
|
||
We could even opt out for all one-argument functions: | ||
```julia | ||
@opt_out rrule(::Any, ::NamedDimsArray) | ||
``` | ||
Though this is likely to cause some method ambiguities. | ||
|
||
Similarly, `frule`s can be opted out of with `@opt_out frule`. | ||
Opting out also works for rules that take a [`RuleConfig`](@ref config). | ||
|
||
|
||
## How to support this (for AD implementers) | ||
|
||
We provide two ways to know that a rule has been opted out of. | ||
|
||
### `rrule` / `frule` returns `nothing` | ||
|
||
`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`. | ||
|
||
If you are in a position to generate code in response to values returned by function calls, then you can do something like: | ||
```julia | ||
res = rrule(f, xs) | ||
if res === nothing | ||
    y, pullback = perform_ad_via_decomposition(f, xs)  # do AD without hitting the rrule | ||
else | ||
    y, pullback = res | ||
end | ||
``` | ||
The Julia compiler will specialize based on inferring the return type of `rrule`, and so can remove that branch. | ||
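
The branch-on-`nothing` pattern can be tried out without any AD machinery. The following is a minimal sketch under stated assumptions: `square`, `toy_rrule`, `fallback_ad` and `value_and_pullback` are hypothetical names invented for illustration, not part of ChainRulesCore, and `fallback_ad` merely stands in for whatever "AD by decomposition" your system performs.

```julia
# All names below are hypothetical stand-ins, not ChainRulesCore API.
square(x) = x * x

# A generic hand-written rule: returns the primal value and a pullback.
toy_rrule(::typeof(square), x::Number) = (square(x), dy -> 2x * dy)

# What an opt-out amounts to: a more specific method that returns `nothing`.
toy_rrule(::typeof(square), x::Float32) = nothing

# Stand-in for "do AD without hitting the rule" (assumed; a real AD system
# would differentiate the primal code here instead).
fallback_ad(::typeof(square), x) = (x * x, dy -> (x + x) * dy)

function value_and_pullback(f, x)
    res = toy_rrule(f, x)
    if res === nothing
        # The rule was opted out of for this signature: fall back.
        return fallback_ad(f, x)
    else
        return res
    end
end

y, pb = value_and_pullback(square, 3.0)      # uses the generic rule
y32, pb32 = value_and_pullback(square, 3f0)  # rule opted out: uses the fallback
```

Because which `toy_rrule` method is hit depends only on the argument types, the `res === nothing` check is something the compiler can in principle resolve during specialization.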
|
||
### `no_rrule` / `no_frule` has a method | ||
|
||
`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref). | ||
The body of this method doesn't matter; what matters is that it adds an entry to the method table. | ||
A simple thing you can do with this is not support opting out. | ||
To do this, filter all methods from the `rrule`/`frule` method table that also occur in the `no_rrule`/`no_frule` table. | ||
This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` (and thus prevents your library from erroring). | ||
This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule. | ||
|
||
For a more complex approach, you can use this to generate code that triggers your AD. | ||
If for a given signature there is a method in the `no_rrule`/`no_frule` method table that is more specific than the one that would be hit from the `rrule`/`frule` table | ||
(excluding the one that matches exactly, which will return `nothing`), then you know that the rule should not be used. | ||
You can, likely by looking at the primal method table, work out which method you would have hit if the rule had not been defined, | ||
and then `invoke` it. | ||
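
One way to read the method-table check is: take the most specific method of the rule function and of the opt-out function for the call's argument types, drop the leading function-type parameter from each signature, and test whether the opt-out signature is a subtype of the rule signature. The sketch below does this with toy functions. `half`, `toy_rule`, `toy_no_rule`, `arg_sig` and `is_opted_out` are hypothetical names, and the sketch assumes signatures without `where` clauses so that `m.sig.parameters` can be used directly.

```julia
# All names below are hypothetical stand-ins, not ChainRulesCore API.
half(x) = x / 2

# Rule table: a generic rule, plus the `nothing`-returning method an opt-out adds.
toy_rule(::typeof(half), x::Number) = (half(x), dy -> dy / 2)
toy_rule(::typeof(half), x::Float32) = nothing

# Opt-out table: a catch-all fallback, plus one entry per opted-out signature.
toy_no_rule(::Any, ::Vararg{Any}) = nothing
toy_no_rule(::typeof(half), ::Float32) = nothing

# Signature type-tuple of a method without the leading function type.
# Assumes the method has no `where` clause, so `m.sig` is a plain `Tuple` type.
arg_sig(m::Method) = Tuple{m.sig.parameters[2:end]...}

function is_opted_out(f, args...)
    call_types = Tuple{typeof(f), map(typeof, args)...}
    rule_method = which(toy_rule, call_types)        # most specific rule method
    no_rule_method = which(toy_no_rule, call_types)  # most specific opt-out method
    return arg_sig(no_rule_method) <: arg_sig(rule_method)
end

opted_f32 = is_opted_out(half, 1f0)  # the Float32 signature was opted out
opted_f64 = is_opted_out(half, 1.0)  # only the catch-all matches, so no opt-out
```

Signatures with type parameters would need the `UnionAll` wrapper handled before comparing, which this sketch deliberately leaves out.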
|
||
|
||
|
||
[^1]: It is also possible that this is not the case. Benchmark your real use cases. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -139,3 +139,52 @@ const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance | |
function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) | ||
return rrule_kwfunc(kws, rrule, args...) | ||
end | ||
|
||
############################################################## | ||
### Opt out functionality | ||
|
||
const NO_RRULE_DOC = """ | ||
no_rrule | ||
should we maybe use
It's not really internal. I will instead put a warning in the docstring. |
||
|
||
This is an implementation detail for opting out of [`rrule`](@ref). | ||
It follows the signature for `rrule` exactly. | ||
We use it as a way to store a collection of type-tuples in its method-table. | ||
If something has this defined, it means that it must also have an `rrule` | ||
that returns `nothing`. | ||
|
||
### Mechanics | ||
note: when the text below says methods `==` or `<:` it actually means: | ||
`parameters(m.sig)[2:end]` (i.e. the signature type tuple) rather than the method object `m` itself. | ||
|
||
To decide whether to opt out using this mechanism: | ||
- find the most specific method of `rrule` | ||
- find the most specific method of `no_rrule` | ||
- if the method of `no_rrule` `<:` the method of `rrule`, then you should opt out | ||
|
||
To just ignore the fact that rules can be opted out of, and that some rules thus return | ||
`nothing`, filter the list of methods of `rrule` to remove those that are `==` to ones | ||
that occur in the method table of `no_rrule`. | ||
|
||
Note also that when doing this you must still handle falling back from a rule with config to | ||
should we expand on how to do this in the docs? Maybe just mentioning and seeing an example in Nabla or Zygote is enough
I think we do expand on this in the docs already. |
||
a rule without config. | ||
|
||
On the other hand, if your AD can work with `rrule`s that return `nothing`, then it is | ||
simpler to just use that mechanism for opting out, and you don't need to worry about this | ||
at all. | ||
""" | ||
|
||
""" | ||
$NO_RRULE_DOC | ||
|
||
See also [`ChainRulesCore.no_frule`](@ref). | ||
""" | ||
function no_rrule end | ||
no_rrule(::Any, ::Vararg{Any}) = nothing | ||
|
||
""" | ||
$(replace(NO_RRULE_DOC, "rrule"=>"frule")) | ||
|
||
See also [`ChainRulesCore.no_rrule`](@ref). | ||
""" | ||
function no_frule end | ||
no_frule(ȧrgs, f, ::Vararg{Any}) = nothing |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -148,4 +148,32 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) | |
@test_skip ∂xr isa Float64 # to be made true with projection | ||
@test_skip ∂xr ≈ real(∂x) | ||
end | ||
|
||
|
||
@testset "@opt_out" begin | ||
first_oa(x, y) = x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
@scalar_rule(first_oa(x, y), (1, 0)) | ||
@opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 | ||
@opt_out( | ||
ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 | ||
|
||
) | ||
|
||
@testset "rrule" begin | ||
@test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) | ||
@test rrule(first_oa, 3f0, 4f0) === nothing | ||
|
||
@test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m | ||
m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 | ||
end) | ||
end | ||
|
||
@testset "frule" begin | ||
@test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) | ||
@test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing | ||
|
||
@test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m | ||
m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 | ||
end) | ||
end | ||
end | ||
end |
I think this escaping will mean this doesn't work in anything that doesn't have ChainRulesCore imported.
It needs a test in the isolated scope testset.
But I am ok leaving that for a follow up PR.