Skip to content

Commit

Permalink
Rules for foldl and accumulate (#526)
Browse files Browse the repository at this point in the history
* rrule for foldl + tests

* rrule for accumulate + tests

* rrule for cumsum + tests

* rule for sum(::Tuple)

* tests + tweaks

* rm cumsum

* comments

* rm comments + old tests

* test fixes

* skip tuples on 1.0

* version bump

* two suggestions, no more pi

* tidying

* updates to use Tuple ProjectTo, comments, tidying

* more

* fixes

* one more

* fixup for 1.0

* fix 1.0, comment

* fix 1.6 too?

* one more
  • Loading branch information
mcabbott authored Nov 14, 2021
1 parent edf3a1f commit a751937
Show file tree
Hide file tree
Showing 4 changed files with 361 additions and 14 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.12.1"
version = "1.13.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1.1"
ChainRulesCore = "1.10"
ChainRulesTestUtils = "1"
Compat = "3.35"
FiniteDifferences = "0.12.8"
Expand Down
152 changes: 152 additions & 0 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
##### `sum(x)`
#####

# Forward rule for `sum` of a tuple: returns the primal `sum(x)` together with the sum of
# the tangent entries that accompany `x`.
function frule((_, xdot), ::typeof(sum), x::Tuple)
    total = sum(x)
    total_dot = sum(xdot)
    return total, total_dot
end
# Forward rule for `sum(x; dims)`: the same reduction (same `dims`) is applied to the
# primal `x` and to its tangent, and both results are returned.
function frule((_, xdot), ::typeof(sum), x; dims=:)
    reduce_over(a) = sum(a; dims=dims)
    return reduce_over(x), reduce_over(xdot)
end
Expand Down Expand Up @@ -324,3 +327,152 @@ end
end
return dx
end

#####
##### `foldl`
#####

# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when
# this `f` is stateful, in which case the gradient must be calculated in the reverse order.

# The implementation aims to be efficient for both tuples and arrays, although using accumulate
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
# loop can be a few times faster. Note also that it does not return a gradient for `init`.

# Reverse rule for `foldl(op, x)` on an array or tuple `x`, differentiating `op` through
# `rrule_via_ad`. Returns the primal `y` and the pullback `unfoldl`, whose result holds
# `NoTangent()` for `foldl` itself, a tangent for `op`, and a tangent for `x`.
# No tangent is returned for the `init` keyword.
function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple};
init=_InitialValue()
) where {G}
# Without `init`, the first element of `x` seeds the fold and `op` only sees the rest.
list, start = if init === _InitialValue()
_drop1(x), first(x)
else
# Case with init keyword is simpler to understand first!
_reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy)
end
# `hobbits` collects one `(value, pullback)` pair per call of `op`.
hobbits = accumulate(list; init=(start, nothing)) do (a,_), b
# Here `a` is what we would normally carry forward, and `_` ignores
# the previous iteration's pullback function (needed later),
# while `b` is the fresh input from `list` as usual.
c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here!
# We don't really need to store every `c`, last one is `foldl` output.
# (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
end
# The last carried value is the result of the fold.
y = first(last(hobbits))
# Remember the axes of `x` and its projector, both are used by the pullback.
axe = axes(x)
project = ProjectTo(x)
# Pullback: walks `hobbits` in reverse, handing each stored pullback the `da` produced
# by the previous step (the very first step receives `dy`).
function unfoldl(dy)
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
# Don't need to store every `da`, need one for the next iteration + maybe last
end
# The tangent of `op` is the sum of the first entries (`ds`) over all calls.
dop = sum(first, trio)
# The last entries (`db`) belong to the elements of `list`; reverse again to restore order.
dx = map(last, _reverse1(trio))
if init === _InitialValue()
# `hobbits` is one short: the first element of `x` was the seed, so its tangent
# is the `da` of the final reverse step, prepended here.
dx = _vcat1(trio[end][2], dx)
end
return (NoTangent(), dop, project(_reshape1(dx, axe)))
end
return y, unfoldl
end


#####
##### Iterator-or-Tuple functions
#####

# This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays,
# and also provides some alternatives for versions of Julia where iterators weren't supported.
# Inspired by `Base._reverse`, used in defn of `foldr`.

# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
# be replaced by _peel1 like Iterators.peel

# Generic (non-tuple) methods depend on the Julia version; tuple methods follow below.
if VERSION < v"1.6"
    # Old versions don't support accumulate(::itr), nor multi-dim reverse,
    # so these eagerly produce vectors.
    function _reverse1(x)
        return reverse(vec(x))
    end
    function _drop1(x)
        return vec(x)[2:end]
    end
    function _zip2(x, y)
        return collect(zip(x, y))
    end
else
    # Lazy iterator versions.
    function _reverse1(x)
        return Iterators.reverse(x)
    end
    function _drop1(x)
        return Iterators.drop(x, 1)
    end
    function _zip2(x, y)  # for `accumulate`, below
        return zip(x, y)
    end
end

# Tuple methods: these always return tuples.
function _reverse1(x::Tuple)
    return reverse(x)
end
function _drop1(x::Tuple)
    return Base.tail(x)
end
# Only applies when both tuples have the same length `N`.
function _zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where {N}
    return ntuple(N) do i
        (x[i], y[i])
    end
end

# Sentinel type: the `foldl` and `accumulate` rules in this file use `_InitialValue()` as
# the default of their `init` keyword, and compare `init === _InitialValue()` to detect
# that the caller did not supply `init`.
struct _InitialValue end # Old versions don't have `Base._InitialValue`

# Prepend one item `x` to a collection: vectors are extended with `vcat`, tuples by
# splatting. An array-valued `x` is wrapped first, so that it is added as a single
# element rather than concatenated element-by-element.
function _vcat1(x, rest::AbstractVector)
    return vcat(x, rest)
end
function _vcat1(x::AbstractArray, rest::AbstractVector)
    wrapped = [x]
    return vcat(wrapped, rest)
end
function _vcat1(x, rest::Tuple)
    return (x, rest...)
end

# Reshape arrays to `axe`; tuples have no shape to change and are returned untouched.
function _reshape1(x::AbstractArray, axe)
    return reshape(x, axe)
end
function _reshape1(x::Tuple, _)
    return x
end

# Replace a `Tangent` by whatever `ChainRulesCore.backing` returns for it; any other
# input is passed through unchanged.
function _no_tuple_tangent(dx::Tangent)
    return ChainRulesCore.backing(dx)
end
function _no_tuple_tangent(dx)
    return dx
end


#####
##### `accumulate`
#####

# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`.

# Reverse rule for `accumulate(op, x)` on an array or tuple `x`, differentiating `op`
# through `rrule_via_ad`. The forward pass is itself an `accumulate` which carries
# `(value, pullback)` pairs; the pullback `decumulate` runs a second `accumulate` over
# those pairs in reverse order. The pullback returns `NoTangent()` for `accumulate`,
# a tangent for `op`, and a tangent for `x`; there is no tangent for `init`.
#
# Only `dims === nothing`, or `dims == 1` with a vector or tuple `x`, is accepted;
# anything else throws an `ArgumentError`.
function rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
    init=_InitialValue(), dims=nothing
) where {G}
    dims_ok = isnothing(dims) || (dims == 1 && x isa Base.AbstractVecOrTuple)
    # It's not supported by AD either, so no point calling back, and no regression:
    #   gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
    #   ERROR: Mutating arrays is not supported
    dims_ok || throw(ArgumentError(
        "accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
    ))
    # Without `init`, the first element of `x` seeds the scan and `op` only sees the rest.
    list, start = if init === _InitialValue()
        _drop1(x), first(x)
    else
        x, init
    end
    # One `(value, pullback)` pair per call of `op`; the do-block returns the tuple.
    hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
        c, back = rrule_via_ad(config, op, a, b)
    end
    y = map(first, hobbits)
    if init === _InitialValue()
        # `hobbits` is one short, and first one doesn't invoke `op`
        y = _vcat1(first(x), y)
    end
    axe = axes(x)
    project = ProjectTo(x)
    function decumulate(dy)
        dy_plain = _no_tuple_tangent(unthunk(dy))
        # Pair each stored pullback with the incoming tangent of the output it produced.
        rev_list = if init === _InitialValue()
            if VERSION >= v"1.6"
                # Here we rely on `zip` to stop early. Being explicit with _reverse1(_drop1(...))
                # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
                _zip2(_reverse1(hobbits), _reverse1(dy_plain))
            else
                # However, on 1.0 and some others, zip does not stop early. But since accumulate
                # also doesn't work on iterators, `_drop1` doesn't make one, so this should work:
                _zip2(_reverse1(hobbits), _reverse1(_drop1(dy_plain)))
                # What an awful tangle.
            end
        else
            _zip2(_reverse1(hobbits), _reverse1(dy_plain))
        end
        # Each step adds the tangent carried from later outputs (`dc`) to the direct
        # tangent of this output (`dz`) before calling the stored pullback.
        trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz)
            ds, da, db = back(dc + dz)
            # Don't need to store every 'da', but need for next iteration, and the last one.
        end
        dop = sum(first, trio)
        dx = map(last, _reverse1(trio))
        # Identity check (`===`), matching the forward pass: `==` would return `missing`
        # for `init=missing` and make this `if` throw.
        if init === _InitialValue()
            # `hobbits` is one short, and the first one is weird
            dx = _vcat1(trio[end][2] + dy_plain[1], dx)
        end
        return (NoTangent(), dop, project(_reshape1(dx, axe)))
    end
    return _reshape1(y, axe), decumulate
end
134 changes: 125 additions & 9 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights)
struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end

@testset "Maps and Reductions" begin
const CFG = ChainRulesTestUtils.ADviaRuleConfig()

@testset "Reductions" begin
@testset "sum(::Tuple)" begin
test_frule(sum, Tuple(rand(5)))
end
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
# Forward
test_frule(sum, rand(5); fkwargs=(;dims=dims))
Expand Down Expand Up @@ -79,12 +84,11 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
test_rrule(sum, inv, transpose(view(x, 1, :)))

# Make sure we preserve type for StaticArrays
ADviaRuleConfig = ChainRulesTestUtils.ADviaRuleConfig
_, pb = rrule(ADviaRuleConfig(), sum, abs, @SVector[1.0, -3.0])
_, pb = rrule(CFG, sum, abs, @SVector[1.0, -3.0])
@test pb(1.0) isa Tuple{NoTangent, NoTangent, SVector{2, Float64}}

# make sure we preserve type for Diagonal
_, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0]))
_, pb = rrule(CFG, sum, abs, Diagonal([1.0, -3.0]))
@test pb(1.0)[3] isa Diagonal

# Boolean -- via @non_differentiable, test that this isn't ambiguous
Expand Down Expand Up @@ -173,7 +177,64 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
@test unthunk(rrule(prod, v)[2](1f0)[2]) == zeros(4)
test_rrule(prod, v)
end
end # prod
end # prod

# Tests of the `foldl` rrule on arrays: hand-computed gradients, execution order with a
# stateful `Counter`, gradients with respect to the function, and finite differencing.
@testset "foldl(f, ::Array)" begin
    # Simple
    y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1)
    @test y1 == 6
    @test b1(7) == (NoTangent(), NoTangent(), [42, 21, 14])

    y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4]) # without init, needs vcat
    @test y2 == 0
    @test b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0]) # matrix, needs reshape

    # Test execution order
    c5 = Counter()
    y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11])
    @test c5 == Counter(2)
    @test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), [5, 7, 11])
    @test b5(1) == (NoTangent(), NoTangent(), [12*32, 12*42, 22])
    @test c5 == Counter(42)

    c6 = Counter()
    y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11], init=3)
    @test c6 == Counter(3)
    @test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), [5, 7, 11], init=3)
    @test b6(1) == (NoTangent(), NoTangent(), [63*33*13, 43*13, 23])
    @test c6 == Counter(63)

    # Test gradient of function
    y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11])
    @test y7 == foldl((x,y)->x*y*3, [5, 7, 11])
    @test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2310,), [693, 495, 315])

    y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11], init=3)
    @test y8 == 2_537_535 == foldl((x,y)->x*y*13, [5, 7, 11], init=3)
    @test b8(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 585585,), [507507, 362505, 230685])
    # To find these numbers:
    # ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
    # ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string

    # Finite differencing
    test_rrule(foldl, /, 1 .+ rand(3,4))
    test_rrule(foldl, *, rand(ComplexF64,3,4); fkwargs=(; init=rand(ComplexF64)))
    test_rrule(foldl, +, rand(ComplexF64,7); fkwargs=(; init=rand(ComplexF64)))
    test_rrule(foldl, max, rand(3); fkwargs=(; init=999))
end
# Tests of the `foldl` rrule on tuples; only run on Julia 1.5 and later.
VERSION >= v"1.5" && @testset "foldl(f, ::Tuple)" begin
    y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1)
    @test y1 == 6
    @test b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14))

    y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4))
    @test y2 == 0
    @test b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0))

    # Finite differencing
    test_rrule(foldl, /, Tuple(1 .+ rand(5)))
    test_rrule(foldl, *, Tuple(rand(ComplexF64, 5)))
end
end

@testset "Accumulations" begin
Expand All @@ -188,14 +249,14 @@ end
@testset "higher dimensions, dims=$dims" for dims in (1,2,3)
m = round.(10 .* randn(4,5), sigdigits=3)
test_rrule(cumprod, m; fkwargs=(;dims=dims), atol=0.1)
m[2,2] = 0
m[2,4] = 0
m[2, 2] = 0
m[2, 4] = 0
test_rrule(cumprod, m; fkwargs=(;dims=dims))

t = round.(10 .* randn(3,3,3), sigdigits=3)
test_rrule(cumprod, t; fkwargs=(;dims=dims))
t[2,2,2] = 0
t[2,3,3] = 0
t[2, 2, 2] = 0
t[2, 3, 3] = 0
test_rrule(cumprod, t; fkwargs=(;dims=dims))
end

Expand All @@ -211,5 +272,60 @@ end
back = rrule(cumprod, Diagonal([1, 2]); dims=1)[2]
@test unthunk(back(fill(0.5, 2, 2))[2]) [1/2 0; 0 0] # ProjectTo'd to Diagonal now
end
end # cumprod

# Tests of the `accumulate` rrule on arrays: hand-computed gradients, execution order with
# a stateful `Counter`, gradients with respect to the function, and finite differencing.
@testset "accumulate(f, ::Array)" begin
    # Simple
    y1, b1 = rrule(CFG, accumulate, *, [1, 2, 3, 4]; init=1)
    @test y1 == [1, 2, 6, 24]
    @test b1([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6])

    if VERSION >= v"1.5"
        y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
        @test y2 ≈ accumulate(/, [1 2; 3 4])
        @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
    end

    # Test execution order
    c3 = Counter()
    y3, b3 = rrule(CFG, accumulate, c3, [5, 7, 11]; init=3)
    @test c3 == Counter(3)
    @test y3 == [8, 30, 123] == accumulate(Counter(), [5, 7, 11]; init=3)
    @test b3([1, 1, 1]) == (NoTangent(), NoTangent(), [29169, 602, 23]) # the 23 is clear!

    c4 = Counter()
    y4, b4 = rrule(CFG, accumulate, c4, [5, 7, 11])
    @test c4 == Counter(2)
    @test y4 == [5, (5+7)*1, ((5+7)*1 + 11)*2] == accumulate(Counter(), [5, 7, 11])
    @test b4([1, 1, 1]) == (NoTangent(), NoTangent(), [417, 42*(1 + 12), 22])

    # Test gradient of function
    y7, b7 = rrule(CFG, accumulate, Multiplier(3), [5, 7, 11])
    @test y7 == accumulate((x,y)->x*y*3, [5, 7, 11])
    @test b7([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2345,), [715, 510, 315])

    y8, b8 = rrule(CFG, accumulate, Multiplier(13), [5, 7, 11], init=3)
    @test y8 == [195, 17745, 2537535] == accumulate((x,y)->x*y*13, [5, 7, 11], init=3)
    @test b8([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 588330,), [511095, 365040, 230685])
    # To find these numbers:
    # ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13)
    # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string

    # Finite differencing
    test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
    if VERSION >= v"1.5"
        test_rrule(accumulate, /, 1 .+ rand(3, 4))
        test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
    end
end
# Tests of the `accumulate` rrule on tuples; only run on Julia 1.5 and later.
if VERSION >= v"1.5"
    @testset "accumulate(f, ::Tuple)" begin
        # Simple
        primal, pullback = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
        @test primal == (1, 2, 6, 24)
        expected_dx = Tangent{NTuple{4,Int}}(33, 16, 10, 6)
        @test pullback((1, 1, 1, 1)) == (NoTangent(), NoTangent(), expected_dx)

        # Finite differencing
        test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
        test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
    end
end
end
Loading

2 comments on commit a751937

@mzgubic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/49341

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.13.0 -m "<description of version>" a75193768775975fac5578c89d1e5f50d7f358c2
git push origin v1.13.0

Please sign in to comment.