Skip to content

Commit

Permalink
chore: remove projection rule
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed May 16, 2024
1 parent 53cd1ae commit c8afcd0
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,19 @@ end
view(A, I...), view_adjoint
end

function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray)
ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
end

function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{
AbstractArray, AbstractVectorOfArray})
if eltype(x) <: Number
arr = reshape(x, p.sz)
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
elseif eltype(x) <: AbstractArray
return VectorOfArray(x)
end
end
# function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray)
# ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
# end
#
# function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{
# AbstractArray, AbstractVectorOfArray})
# if eltype(x) <: Number
# arr = reshape(x, p.sz)
# return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
# elseif eltype(x) <: AbstractArray
# return VectorOfArray(x)
# end
# end

@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray,
y::Union{Zygote.Numeric, AbstractVectorOfArray})
Expand Down

0 comments on commit c8afcd0

Please sign in to comment.