Skip to content

Commit

Permalink
feat: add and test adjoints for broadcast arithmetic on VoA
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 5, 2024
1 parent 1ecc966 commit 1f9b577
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 3 deletions.
112 changes: 110 additions & 2 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ end
end
end

@adjoint function Base.copy(u::VectorOfArray)

Check warning on line 103 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L103

Added line #L103 was not covered by tests
copy(u),
y -> (copy(y),)
end

@adjoint function DiffEqArray(u, t)
DiffEqArray(u, t),
y -> begin
Expand All @@ -117,19 +122,122 @@ end
A.x, literal_ArrayPartition_x_adjoint
end

@adjoint function Array(VA::AbstractVectorOfArray)
@adjoint function Base.Array(VA::AbstractVectorOfArray)

Check warning on line 125 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L125

Added line #L125 was not covered by tests
Array(VA),
y -> (Array(y),)
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)

Check warning on line 130 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L130

Added line #L130 was not covered by tests
view(A, I...),
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
end

ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))

function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{AbstractArray,AbstractVectorOfArray})
arr = reshape(x, p.sz)
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
end

@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray})

Check warning on line 142 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L142

Added line #L142 was not covered by tests
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
end
@adjoint function Broadcast.broadcasted(::typeof(+), x::Zygote.Numeric, y::AbstractVectorOfArray)
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)

Check warning on line 146 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L145-L146

Added lines #L145 - L146 were not covered by tests
end

_minus(Δ) = .-Δ
_minus(::Nothing) = nothing

Check warning on line 150 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L150

Added line #L150 was not covered by tests

@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})

Check warning on line 152 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L152

Added line #L152 was not covered by tests
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
end
@adjoint function Broadcast.broadcasted(::typeof(*), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})

Check warning on line 155 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L155

Added line #L155 was not covered by tests
(
x.*y,
Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
)
end
@adjoint function Broadcast.broadcasted(::typeof(/), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})

Check warning on line 161 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L161

Added line #L161 was not covered by tests
res = x ./ y
res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))
end
@adjoint function Broadcast.broadcasted(::typeof(-), x::Zygote.Numeric, y::AbstractVectorOfArray)

Check warning on line 165 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L165

Added line #L165 was not covered by tests
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
end
@adjoint function Broadcast.broadcasted(::typeof(*), x::Zygote.Numeric, y::AbstractVectorOfArray)

Check warning on line 168 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L168

Added line #L168 was not covered by tests
(
x.*y,
Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
)
end
@adjoint function Broadcast.broadcasted(::typeof(/), x::Zygote.Numeric, y::AbstractVectorOfArray)
res = x ./ y
res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))

Check warning on line 176 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L174-L176

Added lines #L174 - L176 were not covered by tests
end
@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray)

Check warning on line 178 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L178

Added line #L178 was not covered by tests
.-x, Δ -> (nothing, _minus(Δ))
end

@adjoint function Broadcast.broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::AbstractVectorOfArray, exp::Val{p}) where p

Check warning on line 182 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L182

Added line #L182 was not covered by tests
y = Base.literal_pow.(^, x, exp)
y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
end

@adjoint Broadcast.broadcasted(::typeof(identity), x::AbstractVectorOfArray) = x, Δ -> (nothing, Δ)

@adjoint function Broadcast.broadcasted(::typeof(tanh), x::AbstractVectorOfArray)

Check warning on line 189 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L189

Added line #L189 was not covered by tests
y = tanh.(x)
y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
end

@adjoint Broadcast.broadcasted(::typeof(conj), x::AbstractVectorOfArray) =
conj.(x), z̄ -> (nothing, conj.(z̄))

@adjoint Broadcast.broadcasted(::typeof(real), x::AbstractVectorOfArray) =
real.(x), z̄ -> (nothing, real.(z̄))

@adjoint Broadcast.broadcasted(::typeof(imag), x::AbstractVectorOfArray) =
imag.(x), z̄ -> (nothing, im .* real.(z̄))

@adjoint Broadcast.broadcasted(::typeof(abs2), x::AbstractVectorOfArray) =
abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x)

@adjoint function Broadcast.broadcasted(::typeof(+), a::AbstractVectorOfArray{<:Number}, b::Bool)
y = b === false ? a : a .+ b
y, Δ -> (nothing, Δ, nothing)

Check warning on line 208 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L206-L208

Added lines #L206 - L208 were not covered by tests
end
@adjoint function Broadcast.broadcasted(::typeof(+), b::Bool, a::AbstractVectorOfArray{<:Number})
y = b === false ? a : b .+ a
y, Δ -> (nothing, nothing, Δ)

Check warning on line 212 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L210-L212

Added lines #L210 - L212 were not covered by tests
end

@adjoint function Broadcast.broadcasted(::typeof(-), a::AbstractVectorOfArray{<:Number}, b::Bool)
y = b === false ? a : a .- b
y, Δ -> (nothing, Δ, nothing)

Check warning on line 217 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L215-L217

Added lines #L215 - L217 were not covered by tests
end
@adjoint function Broadcast.broadcasted(::typeof(-), b::Bool, a::AbstractVectorOfArray{<:Number})
b .- a, Δ -> (nothing, nothing, .-Δ)

Check warning on line 220 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L219-L220

Added lines #L219 - L220 were not covered by tests
end

@adjoint function Broadcast.broadcasted(::typeof(*), a::AbstractVectorOfArray{<:Number}, b::Bool)
if b === false
zero(a), Δ -> (nothing, zero(Δ), nothing)

Check warning on line 225 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L223-L225

Added lines #L223 - L225 were not covered by tests
else
a, Δ -> (nothing, Δ, nothing)

Check warning on line 227 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L227

Added line #L227 was not covered by tests
end
end
@adjoint function Broadcast.broadcasted(::typeof(*), b::Bool, a::AbstractVectorOfArray{<:Number})
if b === false
zero(a), Δ -> (nothing, nothing, zero(Δ))

Check warning on line 232 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L230-L232

Added lines #L230 - L232 were not covered by tests
else
a, Δ -> (nothing, nothing, Δ)

Check warning on line 234 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L234

Added line #L234 was not covered by tests
end
end

@adjoint Broadcast.broadcasted(::Type{T}, x::AbstractVectorOfArray) where {T<:Number} =
T.(x), ȳ -> (nothing, Zygote._project(x, ȳ),)

Check warning on line 239 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L238-L239

Added lines #L238 - L239 were not covered by tests

function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
N = ndims(x̄)
if length(x) == length(x̄)
Expand Down
23 changes: 22 additions & 1 deletion test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,27 @@ end

function loss7(x)
_x = VectorOfArray([x .* i for i in 1:5])
return sum(abs2, x .- 1)
return sum(abs2, _x .- 1)
end

# use a bunch of broadcasts to test all the adjoints
function loss8(x)
_x = VectorOfArray([x .* i for i in 1:5])
res = copy(_x)
res = res .+ _x
res = res .+ 1
res = res .* _x
res = res .* 2.0
res = res .* res
res = res ./ 2.0
res = res ./ _x
res = 3.0 .- res
res = .-res
res = identity.(Base.literal_pow.(^, res, Val(2)))
res = tanh.(res)
res = res .+ im .* res
res = conj.(res) .+ real.(res) .+ imag.(res) .+ abs2.(res)
return sum(abs2, res)
end

x = float.(6:10)
Expand All @@ -51,3 +71,4 @@ loss(x)
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)

0 comments on commit 1f9b577

Please sign in to comment.