From edb45ef4999db3ce3b37546f5993e613971e9670 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 2 Feb 2025 19:30:43 +0100 Subject: [PATCH] Fix sparse slicing --- phiml/math/_sparse.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py index 6f245541..1c7c332c 100644 --- a/phiml/math/_sparse.py +++ b/phiml/math/_sparse.py @@ -424,6 +424,8 @@ def _getitem(self, selection: dict) -> 'Tensor': assert sel.step in (None, 1), f"Only step size 1 supported for sparse indexing but got {sel.step}" start = sel.start or 0 stop = self._dense_shape[dim].size if sel.stop is None else sel.stop + if stop < 0: + stop += self._shape.get_size(dim) keep &= (start <= dim_indices) & (dim_indices < stop) from . import vec indices -= vec('sparse_idx', **{d: start if d == dim else 0 for d in indices.sparse_idx.item_names}) @@ -621,6 +623,8 @@ def _getitem(self, selection: dict) -> 'Tensor': raise NotImplementedError("Slicing not yet supported for batched sparse tensors") start = ptr_sel.start or 0 stop = compressed.volume if ptr_sel.stop is None else ptr_sel.stop + if stop < 0: + stop += compressed.size pointers = pointers[start:stop+1] indices = indices[{instance(indices).name: slice(int(pointers[0]), int(pointers[-1]))}] values = values[{instance(values).name: slice(int(pointers[0]), int(pointers[-1]))}] @@ -639,6 +643,8 @@ def _getitem(self, selection: dict) -> 'Tensor': assert ind_sel.step in (None, 1), f"Only step size 1 supported for sparse indexing but got {ind_sel.step}" start = ind_sel.start or 0 stop = uncompressed.volume if ind_sel.stop is None else ind_sel.stop + if stop < 0: + stop += uncompressed.size keep = (start <= indices) & (indices < stop) from ._ops import where values = where(keep, values, 0)