diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 96727f7..f869a9a 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -137,7 +137,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) - y1 = ncp.sum(y1, axis=0) + y1 = ncp.sum(ncp.vstack(y1), axis=0) y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) return y