From 1f9b57774bda4f96121f5a226e40941e3da1b73a Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal
Date: Fri, 5 Jan 2024 17:43:07 +0530
Subject: [PATCH] feat: add and test adjoints for broadcast arithmetic on VoA

---
Reviewer notes (commentary only; this text sits between the `---` separator
and the first diff header, so it is not part of the commit message or of any
hunk):

* Purpose: adds Zygote `@adjoint` rules for `Base.copy`, `Base.view`, and for
  broadcasted `+`, `-`, `*`, `/`, `literal_pow`, `identity`, `tanh`, `conj`,
  `real`, `imag`, `abs2` and `Type{T}` conversion on `AbstractVectorOfArray`,
  including the `Bool`-operand special cases; restricts the `ProjectTo`
  functor to `Union{AbstractArray,AbstractVectorOfArray}`; fixes `loss7` to
  use `_x` and adds `loss8`, which chains the new broadcasts.
* Line structure was rebuilt from the hunk headers. The counts agree with the
  headers and the diffstat (110 + 22 insertions, 2 + 1 deletions).
  Indentation and the exact position of the two blank context lines around
  the `Base.view` adjoint are reconstructed -- TODO confirm against the tree.
* NOTE(review): the `From:` header carries no e-mail address here; presumably
  it was stripped. Restore it before feeding this file to `git am`.
* NOTE(review): in the unchanged `ProjectTo` context line, `(sz = size(a))`
  is a parenthesised assignment, not the one-element NamedTuple
  `(sz = size(a),)`. It looks like a plain tuple reaches
  `ProjectTo{VectorOfArray}` -- confirm, and fix in a follow-up if so.
* NOTE(review): the `Base.view` pullback returns `view(y, I...)`, i.e. it
  indexes the incoming cotangent with the forward indices instead of placing
  it into something shaped like `A` -- confirm this is intended.
* NOTE(review): the tests compare Zygote and ForwardDiff gradients with `==`.
  `loss8` goes through `tanh` and division, so `≈` may be the safer check --
  confirm whether exact equality is relied upon.

 ext/RecursiveArrayToolsZygoteExt.jl | 112 +++++++++++++++++++++++++++-
 test/adjoints.jl                    |  23 +++++-
 2 files changed, 132 insertions(+), 3 deletions(-)

diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl
index ead05de1..ac2f1622 100644
--- a/ext/RecursiveArrayToolsZygoteExt.jl
+++ b/ext/RecursiveArrayToolsZygoteExt.jl
@@ -100,6 +100,11 @@ end
     end
 end
 
+@adjoint function Base.copy(u::VectorOfArray)
+    copy(u),
+    y -> (copy(y),)
+end
+
 @adjoint function DiffEqArray(u, t)
     DiffEqArray(u, t),
     y -> begin
@@ -117,19 +122,122 @@ end
     A.x, literal_ArrayPartition_x_adjoint
 end
 
-@adjoint function Array(VA::AbstractVectorOfArray)
+@adjoint function Base.Array(VA::AbstractVectorOfArray)
     Array(VA),
     y -> (Array(y),)
 end
 
+@adjoint function Base.view(A::AbstractVectorOfArray, I...)
+    view(A, I...),
+    y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
+end
 
 ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
 
-function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
+function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{AbstractArray,AbstractVectorOfArray})
     arr = reshape(x, p.sz)
     return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
 end
 
+@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray})
+    broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
+end
+@adjoint function Broadcast.broadcasted(::typeof(+), x::Zygote.Numeric, y::AbstractVectorOfArray)
+    broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
+end
+
+_minus(Δ) = .-Δ
+_minus(::Nothing) = nothing
+
+@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
+    x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
+end
+@adjoint function Broadcast.broadcasted(::typeof(*), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
+    (
+        x.*y,
+        Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
+    )
+end
+@adjoint function Broadcast.broadcasted(::typeof(/), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
+    res = x ./ y
+    res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))
+end
+@adjoint function Broadcast.broadcasted(::typeof(-), x::Zygote.Numeric, y::AbstractVectorOfArray)
+    x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
+end
+@adjoint function Broadcast.broadcasted(::typeof(*), x::Zygote.Numeric, y::AbstractVectorOfArray)
+    (
+        x.*y,
+        Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
+    )
+end
+@adjoint function Broadcast.broadcasted(::typeof(/), x::Zygote.Numeric, y::AbstractVectorOfArray)
+    res = x ./ y
+    res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))
+end
+@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray)
+    .-x, Δ -> (nothing, _minus(Δ))
+end
+
+@adjoint function Broadcast.broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::AbstractVectorOfArray, exp::Val{p}) where p
+    y = Base.literal_pow.(^, x, exp)
+    y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
+end
+
+@adjoint Broadcast.broadcasted(::typeof(identity), x::AbstractVectorOfArray) = x, Δ -> (nothing, Δ)
+
+@adjoint function Broadcast.broadcasted(::typeof(tanh), x::AbstractVectorOfArray)
+    y = tanh.(x)
+    y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
+end
+
+@adjoint Broadcast.broadcasted(::typeof(conj), x::AbstractVectorOfArray) =
+    conj.(x), z̄ -> (nothing, conj.(z̄))
+
+@adjoint Broadcast.broadcasted(::typeof(real), x::AbstractVectorOfArray) =
+    real.(x), z̄ -> (nothing, real.(z̄))
+
+@adjoint Broadcast.broadcasted(::typeof(imag), x::AbstractVectorOfArray) =
+    imag.(x), z̄ -> (nothing, im .* real.(z̄))
+
+@adjoint Broadcast.broadcasted(::typeof(abs2), x::AbstractVectorOfArray) =
+    abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x)
+
+@adjoint function Broadcast.broadcasted(::typeof(+), a::AbstractVectorOfArray{<:Number}, b::Bool)
+    y = b === false ? a : a .+ b
+    y, Δ -> (nothing, Δ, nothing)
+end
+@adjoint function Broadcast.broadcasted(::typeof(+), b::Bool, a::AbstractVectorOfArray{<:Number})
+    y = b === false ? a : b .+ a
+    y, Δ -> (nothing, nothing, Δ)
+end
+
+@adjoint function Broadcast.broadcasted(::typeof(-), a::AbstractVectorOfArray{<:Number}, b::Bool)
+    y = b === false ? a : a .- b
+    y, Δ -> (nothing, Δ, nothing)
+end
+@adjoint function Broadcast.broadcasted(::typeof(-), b::Bool, a::AbstractVectorOfArray{<:Number})
+    b .- a, Δ -> (nothing, nothing, .-Δ)
+end
+
+@adjoint function Broadcast.broadcasted(::typeof(*), a::AbstractVectorOfArray{<:Number}, b::Bool)
+    if b === false
+        zero(a), Δ -> (nothing, zero(Δ), nothing)
+    else
+        a, Δ -> (nothing, Δ, nothing)
+    end
+end
+@adjoint function Broadcast.broadcasted(::typeof(*), b::Bool, a::AbstractVectorOfArray{<:Number})
+    if b === false
+        zero(a), Δ -> (nothing, nothing, zero(Δ))
+    else
+        a, Δ -> (nothing, nothing, Δ)
+    end
+end
+
+@adjoint Broadcast.broadcasted(::Type{T}, x::AbstractVectorOfArray) where {T<:Number} =
+    T.(x), ȳ -> (nothing, Zygote._project(x, ȳ),)
+
 function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
     N = ndims(x̄)
     if length(x) == length(x̄)
diff --git a/test/adjoints.jl b/test/adjoints.jl
index e06035af..c657dcf9 100644
--- a/test/adjoints.jl
+++ b/test/adjoints.jl
@@ -39,7 +39,27 @@ end
 
 function loss7(x)
     _x = VectorOfArray([x .* i for i in 1:5])
-    return sum(abs2, x .- 1)
+    return sum(abs2, _x .- 1)
+end
+
+# use a bunch of broadcasts to test all the adjoints
+function loss8(x)
+    _x = VectorOfArray([x .* i for i in 1:5])
+    res = copy(_x)
+    res = res .+ _x
+    res = res .+ 1
+    res = res .* _x
+    res = res .* 2.0
+    res = res .* res
+    res = res ./ 2.0
+    res = res ./ _x
+    res = 3.0 .- res
+    res = .-res
+    res = identity.(Base.literal_pow.(^, res, Val(2)))
+    res = tanh.(res)
+    res = res .+ im .* res
+    res = conj.(res) .+ real.(res) .+ imag.(res) .+ abs2.(res)
+    return sum(abs2, res)
 end
 
 x = float.(6:10)
@@ -51,3 +71,4 @@ loss(x)
 @test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
 @test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
 @test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
+@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)