Skip to content

Commit

Permalink
Fix sparse slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Feb 2, 2025
1 parent e6cbfcc commit edb45ef
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions phiml/math/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ def _getitem(self, selection: dict) -> 'Tensor':
assert sel.step in (None, 1), f"Only step size 1 supported for sparse indexing but got {sel.step}"
start = sel.start or 0
stop = self._dense_shape[dim].size if sel.stop is None else sel.stop
if stop < 0:
stop += self._shape.get_size(dim)
keep &= (start <= dim_indices) & (dim_indices < stop)
from . import vec
indices -= vec('sparse_idx', **{d: start if d == dim else 0 for d in indices.sparse_idx.item_names})
Expand Down Expand Up @@ -621,6 +623,8 @@ def _getitem(self, selection: dict) -> 'Tensor':
raise NotImplementedError("Slicing not yet supported for batched sparse tensors")
start = ptr_sel.start or 0
stop = compressed.volume if ptr_sel.stop is None else ptr_sel.stop
if stop < 0:
stop += compressed.size
pointers = pointers[start:stop+1]
indices = indices[{instance(indices).name: slice(int(pointers[0]), int(pointers[-1]))}]
values = values[{instance(values).name: slice(int(pointers[0]), int(pointers[-1]))}]
Expand All @@ -639,6 +643,8 @@ def _getitem(self, selection: dict) -> 'Tensor':
assert ind_sel.step in (None, 1), f"Only step size 1 supported for sparse indexing but got {ind_sel.step}"
start = ind_sel.start or 0
stop = uncompressed.volume if ind_sel.stop is None else ind_sel.stop
if stop < 0:
stop += uncompressed.size
keep = (start <= indices) & (indices < stop)
from ._ops import where
values = where(keep, values, 0)
Expand Down

0 comments on commit edb45ef

Please sign in to comment.