diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 1106188a..aae2dddc 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -265,6 +265,16 @@ def slice_off(x, *slices: Dict[str, Union[slice, int, str]]): if not slices: return x x_shape = shape(x) + def to_slices(s): + if isinstance(s, Tensor): + assert len(s.shape.channel) == 1, f"Indices tensors must have a single channel dim but got {s}" + dims = s.shape.channel.item_names[0] + indices = s.numpy([..., channel]) + slices = [{d: i for d, i in zip(dims, idx)} for idx in indices] + return slices + assert isinstance(s, dict), f"Not a valid slice: {s}" + return [s] + slices = sum([to_slices(s) for s in slices], []) dims = set().union(*[s.keys() for s in slices]) dims = x_shape.only(dims).names depth = max(len(s) for s in slices)