From c8afcd0fdfdeeca3e3c7124a1ffa7d04d80e6adc Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 16 May 2024 15:18:24 +0000 Subject: [PATCH 1/2] chore: remove projection rule --- ext/RecursiveArrayToolsZygoteExt.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 26613aee..5eb46b39 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -139,19 +139,19 @@ end view(A, I...), view_adjoint end -function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) - ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) -end - -function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ - AbstractArray, AbstractVectorOfArray}) - if eltype(x) <: Number - arr = reshape(x, p.sz) - return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) - elseif eltype(x) <: AbstractArray - return VectorOfArray(x) - end -end +# function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) +# ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) +# end +# +# function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ +# AbstractArray, AbstractVectorOfArray}) +# if eltype(x) <: Number +# arr = reshape(x, p.sz) +# return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) +# elseif eltype(x) <: AbstractArray +# return VectorOfArray(x) +# end +# end @adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray}) From c51d5d180a981a365b737e14b3827e05707211d2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 16 May 2024 13:13:13 -0400 Subject: [PATCH 2/2] Update ext/RecursiveArrayToolsZygoteExt.jl --- ext/RecursiveArrayToolsZygoteExt.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 5eb46b39..9b253a4f 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -139,19 +139,6 @@ end view(A, I...), view_adjoint end -# function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) -# ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) -# end -# -# function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ -# AbstractArray, AbstractVectorOfArray}) -# if eltype(x) <: Number -# arr = reshape(x, p.sz) -# return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) -# elseif eltype(x) <: AbstractArray -# return VectorOfArray(x) -# end -# end @adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray})