diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 26613aee..5eb46b39 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -139,19 +139,19 @@ end view(A, I...), view_adjoint end -function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) - ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) -end - -function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ - AbstractArray, AbstractVectorOfArray}) - if eltype(x) <: Number - arr = reshape(x, p.sz) - return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) - elseif eltype(x) <: AbstractArray - return VectorOfArray(x) - end -end +# function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) +# ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) +# end +# +# function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ +# AbstractArray, AbstractVectorOfArray}) +# if eltype(x) <: Number +# arr = reshape(x, p.sz) +# return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) +# elseif eltype(x) <: AbstractArray +# return VectorOfArray(x) +# end +# end @adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray})