Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: mapreduce type stability, VoA broadcast adjoints #325

Merged
merged 3 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 110 additions & 2 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@
end
end

@adjoint function Base.copy(u::VectorOfArray)

Check warning on line 103 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L103

Added line #L103 was not covered by tests
copy(u),
y -> (copy(y),)
end

@adjoint function DiffEqArray(u, t)
DiffEqArray(u, t),
y -> begin
Expand All @@ -117,19 +122,122 @@
A.x, literal_ArrayPartition_x_adjoint
end

@adjoint function Array(VA::AbstractVectorOfArray)
@adjoint function Base.Array(VA::AbstractVectorOfArray)

Check warning on line 125 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L125

Added line #L125 was not covered by tests
Array(VA),
y -> (Array(y),)
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)

Check warning on line 130 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L130

Added line #L130 was not covered by tests
view(A, I...),
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
end

# Build the projector for a `VectorOfArray`-like value, storing its `size` under the
# key `sz` (read back as `p.sz` by the projector call).
#
# The trailing comma matters: `(sz = size(a))` is a parenthesised assignment that
# evaluates to a plain `Tuple`, whereas `(sz = size(a),)` is the `NamedTuple` that
# `ProjectTo{P}` expects as its backing info.
function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray)
    return ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a),))
end

function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
# Project `x` back onto a `VectorOfArray`: reshape it to the stored size `p.sz` and
# split it along the last dimension into `p.sz[end]` slices.
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{AbstractArray,AbstractVectorOfArray})
    arr = reshape(x, p.sz)
    # Index every leading dimension with `:` so the slicing also works when the
    # reshaped array has more than two dimensions (`arr[:, i]` is only valid for 2-D).
    leading = ntuple(_ -> Colon(), ndims(arr) - 1)
    return VectorOfArray([arr[leading..., i] for i in 1:p.sz[end]])
end

@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray})

Check warning on line 142 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L142

Added line #L142 was not covered by tests
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
end
@adjoint function Broadcast.broadcasted(::typeof(+), x::Zygote.Numeric, y::AbstractVectorOfArray)
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)

Check warning on line 146 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L145-L146

Added lines #L145 - L146 were not covered by tests
end

# Negate a cotangent element-wise. A `nothing` cotangent is passed through unchanged.
_minus(::Nothing) = nothing
_minus(Δ) = broadcast(-, Δ)

Check warning on line 150 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L150

Added line #L150 was not covered by tests

@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})

Check warning on line 152 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L152

Added line #L152 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're never supposed to define specific broadcasts like this, and I don't think digging into Zygote's broadcast system is the answer. It just relies on the Julia-level broadcast.

x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
end
@adjoint function Broadcast.broadcasted(::typeof(*), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})

Check warning on line 155 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L155

Added line #L155 was not covered by tests
(
x.*y,
Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
)
end
@adjoint function Broadcast.broadcasted(::typeof(/), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})

Check warning on line 161 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L161

Added line #L161 was not covered by tests
res = x ./ y
res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))
end
@adjoint function Broadcast.broadcasted(::typeof(-), x::Zygote.Numeric, y::AbstractVectorOfArray)

Check warning on line 165 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L165

Added line #L165 was not covered by tests
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
end
@adjoint function Broadcast.broadcasted(::typeof(*), x::Zygote.Numeric, y::AbstractVectorOfArray)

Check warning on line 168 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L168

Added line #L168 was not covered by tests
(
x.*y,
Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
)
end
@adjoint function Broadcast.broadcasted(::typeof(/), x::Zygote.Numeric, y::AbstractVectorOfArray)
res = x ./ y
res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))

Check warning on line 176 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L174-L176

Added lines #L174 - L176 were not covered by tests
end
@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray)

Check warning on line 178 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L178

Added line #L178 was not covered by tests
.-x, Δ -> (nothing, _minus(Δ))
end

@adjoint function Broadcast.broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::AbstractVectorOfArray, exp::Val{p}) where p

Check warning on line 182 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L182

Added line #L182 was not covered by tests
y = Base.literal_pow.(^, x, exp)
y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
end

# Adjoint of `identity.(x)`: the primal is `x` itself and the incoming cotangent is
# forwarded unchanged (`nothing` for the function argument).
@adjoint function Broadcast.broadcasted(::typeof(identity), x::AbstractVectorOfArray)
    identity_pullback(Δ) = (nothing, Δ)
    return x, identity_pullback
end

@adjoint function Broadcast.broadcasted(::typeof(tanh), x::AbstractVectorOfArray)

Check warning on line 189 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L189

Added line #L189 was not covered by tests
y = tanh.(x)
y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
end

# Adjoint of `conj.(x)`: the pullback conjugates the incoming cotangent.
@adjoint function Broadcast.broadcasted(::typeof(conj), x::AbstractVectorOfArray)
    conj_pullback(z̄) = (nothing, conj.(z̄))
    return conj.(x), conj_pullback
end

# Adjoint of `real.(x)`: the pullback keeps only the real part of the cotangent.
@adjoint function Broadcast.broadcasted(::typeof(real), x::AbstractVectorOfArray)
    real_pullback(z̄) = (nothing, real.(z̄))
    return real.(x), real_pullback
end

# Adjoint of `imag.(x)`: the pullback returns `im` times the real part of the cotangent.
@adjoint function Broadcast.broadcasted(::typeof(imag), x::AbstractVectorOfArray)
    imag_pullback(z̄) = (nothing, im .* real.(z̄))
    return imag.(x), imag_pullback
end

# Adjoint of `abs2.(x)`: the pullback is `2 * real(z̄) * x`, applied element-wise.
@adjoint function Broadcast.broadcasted(::typeof(abs2), x::AbstractVectorOfArray)
    abs2_pullback(z̄) = (nothing, 2 .* real.(z̄) .* x)
    return abs2.(x), abs2_pullback
end

@adjoint function Broadcast.broadcasted(::typeof(+), a::AbstractVectorOfArray{<:Number}, b::Bool)
y = b === false ? a : a .+ b
y, Δ -> (nothing, Δ, nothing)

Check warning on line 208 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L206-L208

Added lines #L206 - L208 were not covered by tests
end
@adjoint function Broadcast.broadcasted(::typeof(+), b::Bool, a::AbstractVectorOfArray{<:Number})
y = b === false ? a : b .+ a
y, Δ -> (nothing, nothing, Δ)

Check warning on line 212 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L210-L212

Added lines #L210 - L212 were not covered by tests
end

@adjoint function Broadcast.broadcasted(::typeof(-), a::AbstractVectorOfArray{<:Number}, b::Bool)
y = b === false ? a : a .- b
y, Δ -> (nothing, Δ, nothing)

Check warning on line 217 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L215-L217

Added lines #L215 - L217 were not covered by tests
end
@adjoint function Broadcast.broadcasted(::typeof(-), b::Bool, a::AbstractVectorOfArray{<:Number})
b .- a, Δ -> (nothing, nothing, .-Δ)

Check warning on line 220 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L219-L220

Added lines #L219 - L220 were not covered by tests
end

@adjoint function Broadcast.broadcasted(::typeof(*), a::AbstractVectorOfArray{<:Number}, b::Bool)
if b === false
zero(a), Δ -> (nothing, zero(Δ), nothing)

Check warning on line 225 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L223-L225

Added lines #L223 - L225 were not covered by tests
else
a, Δ -> (nothing, Δ, nothing)

Check warning on line 227 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L227

Added line #L227 was not covered by tests
end
end
@adjoint function Broadcast.broadcasted(::typeof(*), b::Bool, a::AbstractVectorOfArray{<:Number})
if b === false
zero(a), Δ -> (nothing, nothing, zero(Δ))

Check warning on line 232 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L230-L232

Added lines #L230 - L232 were not covered by tests
else
a, Δ -> (nothing, nothing, Δ)

Check warning on line 234 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L234

Added line #L234 was not covered by tests
end
end

# Adjoint of broadcasting a numeric type constructor, `T.(x)`: the cotangent is passed
# through `Zygote._project` together with the original `x`.
@adjoint function Broadcast.broadcasted(::Type{T}, x::AbstractVectorOfArray) where {T<:Number}
    convert_pullback(ȳ) = (nothing, Zygote._project(x, ȳ))
    return T.(x), convert_pullback
end

Check warning on line 239 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L238-L239

Added lines #L238 - L239 were not covered by tests

function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
N = ndims(x̄)
if length(x) == length(x̄)
Expand Down
4 changes: 2 additions & 2 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
## Iterable Collection Constructs

Base.map(f, A::ArrayPartition) = ArrayPartition(map(x -> map(f, x), A.x))
function Base.mapreduce(f, op, A::ArrayPartition)
mapreduce(f, op, (mapreduce(f, op, x) for x in A.x))
# `mapreduce` over an `ArrayPartition`: iterate `A` element by element through a
# generator and forward any keyword arguments to the generic `mapreduce`.
# NOTE(review): the generator wrapper presumably keeps the inner call from dispatching
# back to this method — confirm.
function Base.mapreduce(f, op, A::ArrayPartition{T}; kwargs...) where {T}
    elements = (a for a in A)
    return mapreduce(f, op, elements; kwargs...)
end
Base.filter(f, A::ArrayPartition) = ArrayPartition(map(x -> filter(f, x), A.x))
Base.any(f, A::ArrayPartition) = any(f, (any(f, x) for x in A.x))
Expand Down
44 changes: 35 additions & 9 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,16 @@
VectorOfArray{eltype(T), N, typeof(vec)}(vec)
end
# Assume that the first element is representative of all other elements
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
# Generic constructor for an arbitrary `AbstractVector` of entries.
#
# The element type `T` and the dimensionality are taken from the first entry, which is
# assumed to be representative of all others. When every entry is an array or a
# `VectorOfArray`, the container type parameter is a `Vector` whose element type is the
# `Union` of the entries' concrete types; otherwise it is simply `typeof(vec)`.
function VectorOfArray(vec::AbstractVector)
    # `first` (rather than `vec[1]`) also works for vectors whose axes do not start at 1.
    representative = first(vec)
    T = eltype(representative)
    N = ndims(representative)
    A = if all(x isa Union{<:AbstractArray, <:AbstractVectorOfArray} for x in vec)
        Vector{Union{typeof.(vec)...}}
    else
        typeof(vec)
    end
    return VectorOfArray{T, N + 1, A}(vec)
end
# Constructor for a vector whose element type is a subtype of `AbstractArray{T, N}`:
# `T` and `N` come from the type parameters and the container type is `typeof(vec)`.
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray{T, N}} =
    VectorOfArray{T, N + 1, typeof(vec)}(vec)
Expand Down Expand Up @@ -482,21 +491,30 @@
return VA
end

# `stack` on a `VectorOfArray` delegates to `stack` on the underlying vector `VA.u`,
# forwarding the `dims` keyword.
Base.stack(VA::AbstractVectorOfArray; dims = :) = stack(VA.u; dims = dims)

# AbstractArray methods
# Create a `SubArray` over `A` for the indices `I`.
# NOTE(review): this looks modelled on the generic `Base.view` for `AbstractArray` — confirm
# it stays in sync with Base when the internals used here change.
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M}
@inline
# Convert the user-supplied indices with `to_indices`, then pass each through `Base.unalias`.
J = map(i->Base.unalias(A,i), to_indices(A, I))
# Bounds are validated on the converted indices.
@boundscheck checkbounds(A, J...)
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
end
function Base.SubArray(parent::AbstractVectorOfArray, indices::Tuple)
@inline
SubArray(IndexStyle(Base.viewindexing(indices), IndexStyle(parent)), parent, Base.ensure_indexable(indices), Base.index_dimsum(indices...))

Check warning on line 507 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L505-L507

Added lines #L505 - L507 were not covered by tests
end
# An index is reported as assigned exactly when `checkbounds(Bool, VA, idxs...)` accepts it.
function Base.isassigned(VA::AbstractVectorOfArray, idxs...)
    return checkbounds(Bool, VA, idxs...)
end

Check warning on line 509 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L509

Added line #L509 was not covered by tests
# Override for an `N`-dimensional `AbstractVectorOfArray` paired with an `NTuple{N,Bool}`:
# always returns `nothing`.
# NOTE(review): presumably this suppresses Base's parent/index mismatch error when
# building a `SubArray` — confirm.
function Base.check_parent_index_match(
        ::RecursiveArrayTools.AbstractVectorOfArray{T, N}, ::NTuple{N, Bool}) where {T, N}
    return nothing
end
# The dimensionality is the second type parameter `N`.
function Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N}
    return N
end
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
if checkbounds(Bool, VA.u, last(idx))
if last(idx) isa Integer
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)))
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)...))
else
return all(checkbounds.(Bool, VA.u[last(idx)], Base.front(idx)))
return all(checkbounds.(Bool, VA.u[last(idx)], tuple.(Base.front(idx))...))
end
end
return false
Expand Down Expand Up @@ -595,10 +613,14 @@
end

# statistics
@inline Base.sum(f, VA::AbstractVectorOfArray) = sum(f, Array(VA))
@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(Array(VA); kwargs...)
@inline Base.prod(f, VA::AbstractVectorOfArray) = prod(f, Array(VA))
@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(Array(VA); kwargs...)
# `sum(VA)` is `sum(identity, VA)`; the mapped form reduces with `Base.add_sum` via
# `mapreduce`. Keyword arguments are forwarded in both cases.
@inline function Base.sum(VA::AbstractVectorOfArray; kwargs...)
    return sum(identity, VA; kwargs...)
end
@inline Base.sum(f, VA::AbstractVectorOfArray; kwargs...) =
    mapreduce(f, Base.add_sum, VA; kwargs...)
# `prod(VA)` is `prod(identity, VA)`; the mapped form reduces with `Base.mul_prod` via
# `mapreduce`. Keyword arguments are forwarded in both cases.
@inline function Base.prod(VA::AbstractVectorOfArray; kwargs...)
    return prod(identity, VA; kwargs...)
end
@inline Base.prod(f, VA::AbstractVectorOfArray; kwargs...) =
    mapreduce(f, Base.mul_prod, VA; kwargs...)

@inline Statistics.mean(VA::AbstractVectorOfArray; kwargs...) = mean(Array(VA); kwargs...)
@inline function Statistics.median(VA::AbstractVectorOfArray; kwargs...)
Expand Down Expand Up @@ -638,8 +660,12 @@
end

Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u)
function Base.mapreduce(f, op, A::AbstractVectorOfArray)
mapreduce(f, op, (mapreduce(f, op, x) for x in A.u))

# Generic `mapreduce` for a `VectorOfArray`: take a view of `A` with a `:` in every
# dimension and run `mapreduce` on that view, forwarding keyword arguments.
function Base.mapreduce(f, op, A::AbstractVectorOfArray; kwargs...)
    colons = ntuple(_ -> Colon(), ndims(A))
    full_view = view(A, colons...)
    return mapreduce(f, op, full_view; kwargs...)
end
# One-dimensional case backed by an `AbstractVector{T}`: reduce the underlying vector
# `A.u` directly.
Base.mapreduce(f, op, A::AbstractVectorOfArray{T,1,<:AbstractVector{T}}; kwargs...) where {T} =
    mapreduce(f, op, A.u; kwargs...)

## broadcasting
Expand Down
23 changes: 22 additions & 1 deletion test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,27 @@ end

# Loss on a `VectorOfArray` built from five scaled copies of `x`: the sum of squared
# magnitudes of `_x .- 1`. The reduction must use `_x` (not `x`); otherwise the
# `VectorOfArray` code path is never exercised by the gradient test.
function loss7(x)
    _x = VectorOfArray([x .* i for i in 1:5])
    return sum(abs2, _x .- 1)
end

# use a bunch of broadcasts to test all the adjoints
# Loss that pushes a `VectorOfArray` through a long chain of broadcast operations and
# finally reduces it with `sum(abs2, ...)`. Each statement reassigns `res`, so the
# order of the statements is part of the test.
function loss8(x)
_x = VectorOfArray([x .* i for i in 1:5])
res = copy(_x)
# addition: with another VectorOfArray, then with an integer literal
res = res .+ _x
res = res .+ 1
# multiplication: with `_x`, with a float literal, then with itself
res = res .* _x
res = res .* 2.0
res = res .* res
# division: by a float literal, then by `_x`
res = res ./ 2.0
res = res ./ _x
# subtraction with a scalar on the left, then unary minus
res = 3.0 .- res
res = .-res
# literal power wrapped in a broadcasted `identity`
res = identity.(Base.literal_pow.(^, res, Val(2)))
res = tanh.(res)
# introduce an imaginary part so the complex-related broadcasts below are exercised
res = res .+ im .* res
res = conj.(res) .+ real.(res) .+ imag.(res) .+ abs2.(res)
return sum(abs2, res)
end

x = float.(6:10)
Expand All @@ -51,3 +71,4 @@ loss(x)
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
38 changes: 38 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,44 @@ push!(testda, [-1, -2, -3, -4])
@test_throws MethodError push!(testda, [-1 -2 -3 -4])
@test_throws MethodError push!(testda, [-1 -2; -3 -4])

# Type inference: each `@inferred` call fails if the return type of the wrapped call
# is not concretely inferred.
@inferred sum(testva)
# nested case: a VectorOfArray whose single entry is itself a VectorOfArray
@inferred sum(VectorOfArray([VectorOfArray([zeros(4,4)])]))
@inferred mapreduce(string, *, testva)

# mapreduce: string concatenation makes the visiting order of the elements observable
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@test mapreduce(x -> string(x) * "q", *, testva) == "1q2q3q4q5q6q7q8q9q"

# Reductions along each dimension must agree with the same reduction on the dense
# `Array` built from the VectorOfArray.
testvb = VectorOfArray([rand(1:10, 3, 3, 3) for _ in 1:4])
arrvb = Array(testvb)
for i in 1:ndims(arrvb)
@test sum(arrvb; dims=i) == sum(testvb; dims=i)
@test prod(arrvb; dims=i) == prod(testvb; dims=i)
@test mapreduce(string, *, arrvb; dims=i) == mapreduce(string, *, testvb; dims=i)
end

# Test when ndims == 1: the VectorOfArray wraps a plain vector of numbers, and the
# reductions are compared against the same reductions on its `Array` conversion.
testvb = VectorOfArray(collect(1.0:0.1:2.0))
arrvb = Array(testvb)
@test sum(arrvb) == sum(testvb)
@test prod(arrvb) == prod(testvb)
@test mapreduce(string, *, arrvb) == mapreduce(string, *, testvb)

# view: for several mixes of integer and `:` indices, a view of the VectorOfArray must
# have the same size and the same elements as the corresponding view of the dense array.
testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3])
arrvc = Array(testvc)
for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (:, :, :)]
arr_view = view(arrvc, idxs...)
voa_view = view(testvc, idxs...)
@test size(arr_view) == size(voa_view)
@test all(arr_view .== voa_view)
end

# test stack: by default each inner vector becomes a column of the result; with
# `dims = 1` each inner vector becomes a row.
@test stack(testva) == [1 4 7; 2 5 8; 3 6 9]
@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9]

# convert array from VectorOfArray/DiffEqArray
t = 1:8
recs = [rand(10, 7) for i in 1:8]
Expand Down
8 changes: 8 additions & 0 deletions test/partitions_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
@inferred recursive_one(x)
@inferred recursive_bottom_eltype(x)

# mapreduce
# For the mixed Int/Float64 partition `x`, the inferred return type of `sum` is allowed
# to be `Union{Int, Float64}`.
@inferred Union{Int, Float64} sum(x)
# nested and single-partition cases
@inferred sum(ArrayPartition(ArrayPartition(zeros(4,4))))
@inferred sum(ArrayPartition(ArrayPartition(zeros(4))))
@inferred sum(ArrayPartition(zeros(4,4)))
@inferred mapreduce(string, *, x)
# string concatenation makes the visiting order of the elements observable
@test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q"

# broadcasting
_scalar_op(y) = y + 1
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:
Expand Down
Loading