Enhancement proposal: Modular tape caching #234
Comments
Use the
Thanks for the reply. Forgive me for not understanding fully, but do you think you could expand a little on how
It would not be compiled by default, but you can choose to compile two different tapes, one for each branch. I think you might also be able to do that lazily.
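For concreteness, here is a minimal sketch of that idea using only the public `GradientTape`/`compile` API (the branch functions and the hand-written selection logic below are illustrative, not something ReverseDiff provides for you):

```julia
using ReverseDiff
using ReverseDiff: GradientTape, compile, gradient!

# The two branches of a piecewise function, written as separate functions.
branch_hi(x) = sum(x)^3
branch_lo(x) = sum(x)^2

x0 = ones(3)
# One compiled tape per branch; these could also be built lazily on first use.
tape_hi = compile(GradientTape(branch_hi, x0))
tape_lo = compile(GradientTape(branch_lo, x0))

# The branch is chosen at run time instead of being baked into a single compiled tape.
# Inputs must have the same length as the x0 used for tracing.
piecewise_gradient(x) = sum(x) > 1 ? gradient!(tape_hi, x) : gradient!(tape_lo, x)
```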
Could you give me a starting point that I could expand on? I'm not too familiar with
It's not easy. If you are already familiar with ReverseDiff, try reading https://github.com/JuliaDiff/AbstractDifferentiation.jl/blob/master/ext/AbstractDifferentiationReverseDiffExt.jl to understand the AD API. Then you will need to address JuliaDiff/AbstractDifferentiation.jl#41. After that, it should be easy to put together an MWE. If you are interested in spending time on this, we can schedule a call to go through the work required to get it done.
I had a quick read of the above links as well as the AbstractDifferentiation PR about ReverseDiff. I see that it's a relatively difficult problem to solve at such a high level (for all backends) due to type stability. I'd be interested in working on it.
Saw the GSoC idea this proposal is referring to; very interesting stuff. One question from me: would this help with representing dynamically-bounded loops on the tape without requiring recompilation? I can think of a few cases related to sequence/time-series modelling where it would be nice not to eat tracing + tape compilation latency every time the input length changes. Some mechanism for caching sub-tapes seems like a necessary prerequisite for that, but I'm not sure if it falls under the scope of this proposal.
Based on my (limited) understanding of the problem, I think the answer is no. That said, Mohamed may have a better idea of how to deal with it. Maybe Julia can do more than JAX in this regard?
If you have a specific example, we can think about it.
The ultimate use case I have in mind is an RNN, but here is a simpler dependency-free example:

```julia
function f(xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        s += i * x
    end
    return s
end
```

```julia
julia> tape = ReverseDiff.GradientTape(f, ones(5))
typename(ReverseDiff.GradientTape)(f)
julia> ReverseDiff.gradient!(tape, ones(5))
5-element Vector{Float64}:
1.0
2.0
3.0
4.0
5.0
julia> ReverseDiff.gradient!(tape, ones(3))
5-element Vector{Float64}:
1.0
2.0
3.0
4.0
5.0
julia> ReverseDiff.gradient!(tape, ones(10))
ERROR: BoundsError: attempt to access 5-element Vector{Float64} at index [1:10]
Stacktrace:
[1] throw_boundserror(A::Vector{Float64}, I::Tuple{UnitRange{Int64}})
@ Base ./abstractarray.jl:744
[2] checkbounds
@ ./abstractarray.jl:709 [inlined]
[3] _copyto_impl!(dest::Vector{Float64}, doffs::Int64, src::Vector{Float64}, soffs::Int64, n::Int64)
@ Base ./array.jl:325
[4] copyto!
@ ./array.jl:319 [inlined]
[5] copyto!
@ ./array.jl:342 [inlined]
[6] value!
@ ~/.julia/packages/ReverseDiff/wIfrd/src/tracked.jl:156 [inlined]
[7] seeded_forward_pass!
@ ~/.julia/packages/ReverseDiff/wIfrd/src/api/tape.jl:41 [inlined]
[8] gradient!
@ ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:79 [inlined]
[9] gradient!(tape::ReverseDiff.GradientTape{typeof(f), ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, input::Vector{Float64})
@ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:63
[10] top-level scope
@ REPL[18]:1
```

It would be nice to have a way to specify "don't unroll this loop" when tracing so that the same tape could be re-used for different input lengths.
For loops are not possible to intercept with ReverseDiff because they are not functions, but if the loop is wrapped in a function, that function can be intercepted. In this case, you can define an rrule for this function which calls RD with no tape caching. This is possible now with AbstractDifferentiation.

```julia
for (i, x) in enumerate(xs)
    s += i * x
end
```
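A minimal sketch of that suggestion (the `weighted_sum`/`_weighted_sum` names and the use of a `ReverseDiff.@grad` rule rather than an AbstractDifferentiation rule are my own choices here); the pullback re-traces the loop on every call, so no tape caching is involved:

```julia
using ReverseDiff
using ReverseDiff: TrackedArray, track, @grad, value

# Plain implementation of the loop; never overloaded, so ReverseDiff can trace it freshly.
function _weighted_sum(xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        s += i * x
    end
    return s
end

# Public entry point: tracked inputs are intercepted and handled by the custom rule below.
weighted_sum(xs) = _weighted_sum(xs)
weighted_sum(xs::TrackedArray) = track(weighted_sum, xs)

@grad function weighted_sum(xs)
    xv = value(xs)
    # Differentiate the plain loop with an uncompiled, un-cached ReverseDiff pass.
    pullback(Δ) = (Δ .* ReverseDiff.gradient(_weighted_sum, xv),)
    return _weighted_sum(xv), pullback
end
```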
Thanks Mohamed. I'm aware of the custom rule path, but the hope was to make use of tape caching (or I'd resort to using Zygote). Perhaps this example better describes my motivation:

```julia
function scan(f, xs, init)
    ys = empty(xs)
    h = init
    for x in xs
        h, y = f(h, x)
        push!(ys, y)
    end
    return h, ys
end
```

@jacobusmmsmit probably recognizes this as
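For reference, a small worked example of this `scan` (my own illustration): with `f(h, x) = (h + x, h + x)` it produces a running sum.

```julia
h, ys = scan((h, x) -> (h + x, h + x), [1.0, 2.0, 3.0], 0.0)
# h == 6.0, ys == [1.0, 3.0, 6.0]
```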
After a call with Mohamed, I think the implementation we decided to try out will address

[API for this part is a work in progress, I mean so is this whole thing, but this part especially]

```julia
struct CachedReverseDiffBackend{F, T} # Could also be parametric in backend type
    func::F
    compiled_tape::T
    # Constructor to compile the tape given inputs
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(construct_tape(f, x)) # pseudo RD code
        T = typeof(compiled_tape)
        return new{F, T}(f, compiled_tape)
    end
end

compiled_f = CachedReverseDiffBackend(f, x) # where typeof(x) == eltype(xs)
```

Then we make the cached backend callable (with the caveat that

```julia
const CRDB = CachedReverseDiffBackend # alias for brevity
(b::CRDB)(y) = call_func(b, y)
call_func(b::CRDB, y) = b.func(y)
```

and define a custom rule for our function

```julia
function call_func(b::CRDB, y::TrackedArray)
    return ReverseDiff.track(call_func, b, y)
end

import AbstractDifferentiation as AD
ReverseDiff.@grad function call_func(b::CRDB, y)
    return AD.value_and_pullback_function(b, y) # to be implemented
end
```

Now we can pass

So in the end we have an outer uncompiled tape which contains calls to inner compiled tapes.
My previous comment was discussing the compiled-tape-inside-an-uncompiled-tape case, but the uncompiled-tape-inside-a-compiled-tape case is easier to address. I'm leaving this comment as some documentation of how this is already possible, but it could use some development to make it easier to use. At the end I do have a question about how

Example showing it's already possible

Some setup:

```julia
using ReverseDiff
using ReverseDiff: TrackedArray, track, @grad, value, GradientTape, compile, gradient!, gradient
```

First we define a function with branches. Compiling a tape with branches in it is currently a very dangerous operation, as it will compile without complaining but silently return the wrong answer.

```julia
branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2
_branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2 # function used as a reference
```

Then we define a custom gradient with some logging to show that the right thing is happening each time.

```julia
branching_f(x::TrackedArray) = track(branching_f, x)
@grad function branching_f(x)
    xv = value(x)
    function grad_branching(Δ)
        @show sum(xv)
        if sum(xv) > 1
            println("High branch")
            return (3*sum(xv)^2*Δ, )
        else
            println("Low branch")
            return (2*sum(xv)*Δ, )
        end
    end
    return branching_f(xv), grad_branching
end
```

Now we construct the tapes and test that everything is running as expected:

```julia
# Construct and compile the tape
input = [0.0, 1.1, 1.0]
branching_tape = compile(GradientTape(branching_f, input))
_branching_tape = compile(GradientTape(_branching_f, input)) # This tape should ignore the branch
# One input for each branch in the function
input_low = [0.1, 0.2, 0.3]
input_high = [1.1, 1.2, 1.3]
# Test for correctness of implementation
grad_low = gradient(_branching_f, input_low)
grad_high = gradient(_branching_f, input_high)
grad_low == gradient(branching_f, input_low)
grad_high == gradient(branching_f, input_high)
# An example of the method working
grad_low == gradient!(branching_tape, input_low) # true
grad_low == gradient!(_branching_tape, input_low) # false
grad_high == gradient!(branching_tape, input_high) # true
grad_high == gradient!(_branching_tape, input_high) # true (but for the wrong reason)
```

Where to go from here

So, in a way, there we go. We can do modular tape caching already! But this is all very manual. It would be very nice if we could have this done automatically, such as:

Automatic detection of branches and a warning:

```julia
julia> compile(GradientTape(_branching_f, input))
Warning: woah buddy, you've got a branch in that function of yours, I don't think you want to compile it!
```

or automatic detection of branches and not compiling the branch sources (not ideal):

```julia
julia> compile(GradientTape(outer_function_with_inner_branch, my_input)) # Automatic modularisation
Warning: The tape of `outer_function_with_inner_branch` has branches because of `inner_function`,
this function was not compiled
```

or allowing users to define static arguments à la JAX:

```julia
inner_function(x, y) = x > 0 ? 2y : 3y^2
sa_inner_function = @static_arguments(inner_function, x)
outer_function_with_inner_branch(z) = sum(z) * sa_inner_function(z[1], z[2])
```

or, ultimately, automatic detection of branches and not compiling the branch sources with respect to those arguments:

```julia
inner_function(x, y) = x > 0 ? 2y : 3y^2
outer_function_with_inner_branch(z) = sum(z) * inner_function(z[1], z[2])
compile(GradientTape(outer_function_with_inner_branch, my_input)) # All good, works as if it were uncompiled but with compiled performance where possible.
```

A question

What I'd like to ask is about how

```julia
@grad function branching_f(x)
    xv = value(x)
    sum_xv = sum(xv) # This part is constant when compiled
    function grad_branching(Δ)
        (sum_xv > 1 ? 3*sum_xv^2*Δ : 2*sum_xv*Δ,) # Doesn't work at all
    end
    return branching_f(xv), grad_branching
end
```
Ok, I've got a draft implementation for defining cached sub-tapes:

```julia
import AbstractDifferentiation as AD
using ReverseDiff
using ReverseDiff: @grad, compile, GradientTape
import AbstractDifferentiation: primal_value, pullback_function, value_and_pullback_function
struct CachedReverseDiffBackend{F,T} <: AD.AbstractBackend # Could also be parametric in backend type
    func::F
    compiled_tape::T
    # Constructor to compile the tape given inputs
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(GradientTape(f, x)) # pseudo RD code
        T = typeof(compiled_tape)
        return new{F,T}(f, compiled_tape)
    end
end

const CRDB = CachedReverseDiffBackend # alias for brevity
(b::CRDB)(x) = call_func(b, x)
call_func(b::CRDB, x) = b.func(x)

function call_func(b::CRDB, x::ReverseDiff.TrackedArray)
    return ReverseDiff.track(call_func, b, x)
end

@grad function call_func(b::CRDB, x)
    return value_and_pullback_function(b, x)
end

primal_value(::CRDB, xs, _) = primal_value(xs) # is this ok?

function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)
    function pullback_f(Δ)
        (Δ*ReverseDiff.gradient!(cb.compiled_tape, xv), ) # no space to cache output :/
    end
    return yv, pullback_f
end
```

Should this backend be a real backend i.e. should it define a

Here's an example of how it would be used:

```julia
using BenchmarkTools
g(xs) = sum(abs2, xs)
xs = [1.0, 2.0, 3.0]
const crdb = CRDB(g, xs) # must be declared const otherwise type unstable when called
gt = compile(GradientTape(g, xs)) # RD code
# Check gradients work as intended :)
ReverseDiff.gradient(g, xs .+ 1)
ReverseDiff.gradient!(gt, xs .+ 1)
ReverseDiff.gradient!(crdb.compiled_tape, xs .+ 1)
# All return the same thing
# Define an outer function
f_nocompile(xs) = 2g(xs) # use the original `g`
f_compile(xs) = 2crdb(xs) # use the `g` with a compiled gradient
# Primal timings
@btime f_nocompile($xs) # 4.000 ns (0 allocations: 0 bytes)
@btime f_compile($xs) # 4.000 ns (0 allocations: 0 bytes)
# Gradient timings
@btime ReverseDiff.gradient(f_nocompile, $xs) # 961.750 ns (32 allocations: 1.34 KiB)
@btime ReverseDiff.gradient(f_compile, $xs) # 1.092 μs (17 allocations: 1008 bytes)
# Double-compile also works
fnc_tape = compile(GradientTape(f_nocompile, xs))
fc_tape = compile(GradientTape(f_compile, xs))
@btime ReverseDiff.gradient!(fnc_tape, $xs) # 521.266 ns (1 allocation: 80 bytes)
@btime ReverseDiff.gradient!(fc_tape, $xs) # 847.889 ns (3 allocations: 240 bytes)
```

As talked about in this issue, caching interfaces should be addressed, as this is, I think, where the performance difference comes from.
First, sorry for the really late response.
Correct. Documentation PRs are welcome :) The function

```julia
function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)
    function pullback_f(Δ)
        (Δ*ReverseDiff.gradient!(cb.compiled_tape, xv), ) # no space to cache output :/
    end
    return yv, pullback_f
end
```
Your implementation right now seems to only work for scalar-valued functions. You might want to generalise it, and then, yes, making it a primitive will give you all the other methods for free. Check the ReverseDiff backend implementation in AbstractDifferentiation for reference.
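For illustration, one possible (untested) way to generalise the draft to vector-valued functions is to cache a compiled `JacobianTape` and return a vector-Jacobian product from the pullback; the `CachedJacobianBackend` name below is hypothetical:

```julia
using ReverseDiff
using ReverseDiff: JacobianTape, compile, jacobian!

struct CachedJacobianBackend{F,T}
    func::F
    compiled_tape::T
    function CachedJacobianBackend(f::F, x) where {F}
        tape = compile(JacobianTape(f, x))
        return new{F,typeof(tape)}(f, tape)
    end
end

function value_and_pullback_function(cb::CachedJacobianBackend, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)
    function pullback_f(Δ)
        J = jacobian!(cb.compiled_tape, xv) # m×n Jacobian of `func` at xv
        return (J' * Δ,)                    # vector-Jacobian product, length n
    end
    return yv, pullback_f
end
```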
Try profiling to see where the performance difference comes from. Also try a function with more inputs, which might be more representative of how people use ReverseDiff; most people would not use ReverseDiff for a function of 3 variables. If allocations are the bottleneck in your function, then we need to consider reducing those, but let's first check that: 1) that's the case with profiling, and 2) that's a real problem you will run into when using the package for real-sized problems.
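For example, a quick way to do that with the standard library profiler, reusing the `f_compile` and `xs` definitions from the benchmark above:

```julia
using Profile

Profile.clear()
# Run the gradient many times so the profiler collects enough samples.
@profile for _ in 1:10_000
    ReverseDiff.gradient(f_compile, xs)
end
Profile.print()
```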
Problem
Compilation can't be used with run-time control flow. This stops some code from taking advantage of tape compilation.
Possible solution
Enable ReverseDiff's tape caching functionality to be used in cases with run-time control flow by introducing guarded/sub-tapes which are recompiled automatically if the instructions they contain are invalidated by a user-specified guard statement.
My implementation idea is that these guarded/sub-tapes live directly on normal compiled tapes as another type of `AbstractInstruction` (if I'm correct in assuming that it doesn't fit inside `SpecialInstruction`). Here's a quick-and-dirty non-implementation showcasing the idea in action:
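A hypothetical sketch of this (the `GuardedTape` type, `gradient_guarded!`, and the choice of `length` as the guard are illustrative only, not existing ReverseDiff API):

```julia
using ReverseDiff
using ReverseDiff: GradientTape, compile, gradient!

# The dynamically-bounded loop from the discussion above.
function f(xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        s += i * x
    end
    return s
end

# A cached compiled tape together with a user-specified guard; when the guard's value
# changes, the cached tape is considered invalid and is retraced/recompiled.
mutable struct GuardedTape{F,G}
    f::F
    guard::G           # e.g. `length`
    guard_value::Any   # guard value at the time the cached tape was built
    tape::Any          # cached compiled tape (left untyped for simplicity)
end

GuardedTape(f, guard) = GuardedTape(f, guard, nothing, nothing)

function gradient_guarded!(gt::GuardedTape, x)
    g = gt.guard(x)
    if gt.tape === nothing || g != gt.guard_value
        gt.tape = compile(GradientTape(gt.f, x))  # guard invalidated: recompile
        gt.guard_value = g
    end
    return gradient!(gt.tape, x)
end

gt = GuardedTape(f, length)
gradient_guarded!(gt, ones(5))   # traces and compiles for length 5
gradient_guarded!(gt, ones(5))   # reuses the cached tape
gradient_guarded!(gt, ones(10))  # guard changed, so the tape is recompiled
```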
The soul of this is borrowed from JAX's `static_argnums`/`static_argnames` in `jit`, where users can specify arguments that, if changed, trigger the lookup/recompilation step. This is essentially value dispatch. I'm not sure about its performance implications.

Impact
The original context of this project is the Turing package. Gradient-based methods like HMC and NUTS are the state of the art for MCMC sampling and, as stated on Turing's GSoC projects page, their performance is greatly improved by the caching features of ReverseDiff. However, this is not universally applicable, and more complicated models using other packages will normally contain unavoidable control flow.
More generally, the ability to efficiently differentiate through control flow would allow ReverseDiff to be recommended more universally in packages that currently rely on ForwardDiff. AD backend selection is a great feature of the SciML ecosystem, and many of its packages, such as Optimization, could benefit from this contribution, which would make AD backend selection a potential performance footgun rather than an (admittedly blatant but not trivial) correctness one.
While next-generation AD backends such as Diffractor and Enzyme are a hot topic in the ecosystem at the moment, ReverseDiff is a package that has stood the test of time for its reliability and performance. For workloads such as those found in Turing, it is almost always faster than Zygote out of the box, especially in compiled mode. Zygote may sometimes be faster, but it requires far more hand-tuning to reach the necessary speeds, most of which is inaccessible to end users.
ReverseDiff has a clear niche in the AD backend ecosystem: its target users are moderately performance-sensitive with medium-to-high-dimensional problems, and it covers these cases very well with little to no hand-tuning. While Enzyme has incredible performance, which is a feature for the most performance-critical applications, it is neither trivial to use and tune, nor can it be applied in every situation due to some compatibility issues. In a similar vein, Zygote is a high-performance solution that works great for applications heavy in linear algebra, but it often requires significant hand-tuning.