Skip to content

Commit

Permalink
fix: ChainRulesCore.ProjectTo implementation for VoA
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 5, 2024
1 parent 433a780 commit 8e5f902
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ end

ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))

function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{AbstractArray,AbstractVectorOfArray})
arr = reshape(x, p.sz)
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
end
Expand Down

0 comments on commit 8e5f902

Please sign in to comment.