Skip to content

Commit

Permalink
Support simple index tensors in slice_off()
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl authored and holl- committed Jan 8, 2025
1 parent ea6cd2d commit 7c6e444
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,16 @@ def slice_off(x, *slices: Dict[str, Union[slice, int, str]]):
if not slices:
return x
x_shape = shape(x)
def to_slices(s):
if isinstance(s, Tensor):
assert len(s.shape.channel) == 1, f"Indices tensors must have a single channel dim but got {s}"
dims = s.shape.channel.item_names[0]
indices = s.numpy([..., channel])
slices = [{d: i for d, i in zip(dims, idx)} for idx in indices]
return slices
assert isinstance(s, dict), f"Not a valid slice: {s}"
return [s]
slices = sum([to_slices(s) for s in slices], [])
dims = set().union(*[s.keys() for s in slices])
dims = x_shape.only(dims).names
depth = max(len(s) for s in slices)
Expand Down

0 comments on commit 7c6e444

Please sign in to comment.