From 36ea1162efdea2e28fb8a7e2119ee4600c822af3 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Wed, 18 Sep 2024 08:44:42 +0100 Subject: [PATCH] Support ranges with unknown index count in `slice_with` Support slicing tensors with ranges that have a runtime-determined index count in `TensorBase::slice_with`. In such cases the resulting view has a dynamic-rank layout. This allows this method to replace existing usage of `slice_dyn`. --- rten-tensor/src/iterators.rs | 2 +- rten-tensor/src/layout.rs | 25 ++++++++++-- rten-tensor/src/tensor.rs | 2 +- rten-tensor/src/type_num.rs | 79 ++++++++++++++++++++++-------------- src/ops/reduce.rs | 2 +- src/ops/split.rs | 2 +- 6 files changed, 75 insertions(+), 37 deletions(-) diff --git a/rten-tensor/src/iterators.rs b/rten-tensor/src/iterators.rs index 786cfa0c..7bb00e80 100644 --- a/rten-tensor/src/iterators.rs +++ b/rten-tensor/src/iterators.rs @@ -779,7 +779,7 @@ impl<'a, T, L: MutLayout> Iterator for InnerIterDyn<'a, T, L> { fn next(&mut self) -> Option { self.outer_indices.next().map(|idx| { let slice_items = to_slice_items(&idx); - self.view.slice_dyn(slice_items.as_slice()) + self.view.slice_with(slice_items.as_slice()) }) } diff --git a/rten-tensor/src/layout.rs b/rten-tensor/src/layout.rs index ab9f9526..f043464d 100644 --- a/rten-tensor/src/layout.rs +++ b/rten-tensor/src/layout.rs @@ -7,7 +7,7 @@ use crate::errors::{DimensionError, FromDataError, ReshapeError, SliceError}; use crate::index_iterator::{DynIndices, NdIndices}; use crate::overlap::{is_contiguous, may_have_internal_overlap}; use crate::slice_range::{IntoSliceItems, SliceItem}; -use crate::type_num::{ConstUInt, U0, U1, U2, U3, U4, U5}; +use crate::type_num::{OptionalUInt, Unknown, U0, U1, U2, U3, U4, U5}; /// Return true if `permutation` is a valid permutation of dimensions for /// a tensor of rank `ndim`. @@ -1372,7 +1372,7 @@ impl_remove_dim!(5, 4); /// /// `R` is the type of the slice range. `IdxCount` is a marker type indicating /// the number of items in `R` that are indices, as opposed to ranges. -pub trait SliceWith { +pub trait SliceWith { /// The layout produced after slicing. type Layout: Layout; @@ -1384,7 +1384,7 @@ pub trait SliceWith { fn slice_with(&self, range: R) -> (Range, Self::Layout); } -impl SliceWith for DynLayout { +impl SliceWith for L { type Layout = DynLayout; fn slice_with(&self, range: R) -> (Range, Self::Layout) { @@ -1400,6 +1400,25 @@ impl SliceWith for NdLayout { } } +macro_rules! impl_slice_with_dynlayout { + ($range_ndim:ty) => { + impl SliceWith for DynLayout { + type Layout = DynLayout; + + fn slice_with(&self, range: R) -> (Range, Self::Layout) { + self.slice_dyn(range.into_slice_items().as_ref()) + } + } + }; +} + +impl_slice_with_dynlayout!(U0); +impl_slice_with_dynlayout!(U1); +impl_slice_with_dynlayout!(U2); +impl_slice_with_dynlayout!(U3); +impl_slice_with_dynlayout!(U4); +impl_slice_with_dynlayout!(U5); + macro_rules! impl_slice_with { ($ndim:literal, $range_ndim:ty, $out_ndim:literal) => { impl SliceWith for NdLayout<$ndim> { diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs index 1196b0f0..ba96b0e0 100644 --- a/rten-tensor/src/tensor.rs +++ b/rten-tensor/src/tensor.rs @@ -3544,7 +3544,7 @@ mod tests { // Slice dynamic-rank array. The rank of the slice is dynamic. let mut data = Tensor::from([[[1, 2, 3], [4, 5, 6]]]); let mut row = data.slice_with_mut((0, 0)); - row[[0]] = 10; + row[[0usize]] = 10; assert_eq!(row.shape(), [3usize]); assert_eq!(row.data().unwrap(), &[10, 2, 3]); } diff --git a/rten-tensor/src/type_num.rs b/rten-tensor/src/type_num.rs index e1da4b2c..9030279c 100644 --- a/rten-tensor/src/type_num.rs +++ b/rten-tensor/src/type_num.rs @@ -6,6 +6,9 @@ use std::ops::{Range, RangeFrom, RangeFull, RangeTo}; +/// Type representing an integer whose value is unknown at compile time. +pub struct Unknown {} + /// Type representing the integer value 0. pub struct U0 {} @@ -25,29 +28,34 @@ pub struct U4 {} pub struct U5 {} /// Trait providing the integer value of a `U` type (eg. [`U0`]). -pub trait ConstUInt { - const VALUE: usize; +/// +/// The value can be unknown to represent numbers that are known only at +/// runtime. +pub trait OptionalUInt { + const VALUE: Option; } macro_rules! impl_const_int { - ($type:ty, $val:literal) => { - impl ConstUInt for $type { - const VALUE: usize = $val; + ($type:ty, $val:expr) => { + impl OptionalUInt for $type { + const VALUE: Option = $val; } }; } -impl_const_int!(U0, 0); -impl_const_int!(U1, 1); -impl_const_int!(U2, 2); -impl_const_int!(U3, 3); -impl_const_int!(U4, 4); -impl_const_int!(U5, 5); -/// Trait that computes the sum of [`ConstUInt`] types. +impl_const_int!(Unknown, None); +impl_const_int!(U0, Some(0)); +impl_const_int!(U1, Some(1)); +impl_const_int!(U2, Some(2)); +impl_const_int!(U3, Some(3)); +impl_const_int!(U4, Some(4)); +impl_const_int!(U5, Some(5)); + +/// Trait that computes the sum of [`OptionalUInt`] types. /// /// It is implemented for 2-tuples, as well as arrays of either `U0` or `U1`. pub trait Add { - type Result: ConstUInt; + type Result: OptionalUInt; } macro_rules! impl_add { @@ -105,7 +113,10 @@ impl_add_ones!(5, U5); pub trait IsIndex { /// Associated type that is either [`U0`] or [`U1`] indicating whether this /// type is an index. - type IsIndex: ConstUInt; + /// + /// The value can also be [`Unknown`] to indicate a value that may be either + /// an index or a range. + type IsIndex: OptionalUInt; } macro_rules! impl_is_index { @@ -147,16 +158,17 @@ impl IsIndex for RangeFrom { /// /// ``` /// use rten_tensor::type_num::IndexCount; -/// assert_eq!((.., 1..2).index_count(), 0); -/// assert_eq!((0, 1..2).index_count(), 1); -/// assert_eq!((0, 1).index_count(), 2); +/// assert_eq!((.., 1..2).index_count(), Some(0)); +/// assert_eq!((0, 1..2).index_count(), Some(1)); +/// assert_eq!((0, 1).index_count(), Some(2)); +/// assert_eq!([0, 1].as_slice().index_count(), None); /// ``` pub trait IndexCount { /// Type representing the count value. - type Count: ConstUInt; + type Count: OptionalUInt; /// Returns [`Count`](IndexCount::Count) as a numeric value. - fn index_count(&self) -> usize { + fn index_count(&self) -> Option { Self::Count::VALUE } } @@ -224,6 +236,10 @@ where type Count = <[T::IsIndex; N] as Add>::Result; } +impl<'a, T> IndexCount for &'a [T] { + type Count = Unknown; +} + #[cfg(test)] mod tests { use super::IndexCount; @@ -231,20 +247,23 @@ mod tests { #[test] fn test_index_count() { // Single values - assert_eq!((0).index_count(), 1); - assert_eq!((..).index_count(), 0); - assert_eq!((..1).index_count(), 0); - assert_eq!((1..).index_count(), 0); - assert_eq!((1..2).index_count(), 0); + assert_eq!((0).index_count(), Some(1)); + assert_eq!((..).index_count(), Some(0)); + assert_eq!((..1).index_count(), Some(0)); + assert_eq!((1..).index_count(), Some(0)); + assert_eq!((1..2).index_count(), Some(0)); // Tuples - assert_eq!((0,).index_count(), 1); - assert_eq!((0, ..).index_count(), 1); - assert_eq!((0, .., 2).index_count(), 2); - assert_eq!((0, .., 2, ..).index_count(), 2); - assert_eq!((0, .., 2, .., 3).index_count(), 3); + assert_eq!((0,).index_count(), Some(1)); + assert_eq!((0, ..).index_count(), Some(1)); + assert_eq!((0, .., 2).index_count(), Some(2)); + assert_eq!((0, .., 2, ..).index_count(), Some(2)); + assert_eq!((0, .., 2, .., 3).index_count(), Some(3)); // Arrays - assert_eq!([1, 2, 3].index_count(), 3); + assert_eq!([1, 2, 3].index_count(), Some(3)); + + // Slices + assert_eq!([1, 2, 3].as_slice().index_count(), None); } } diff --git a/src/ops/reduce.rs b/src/ops/reduce.rs index 48f2883f..9e4105b2 100644 --- a/src/ops/reduce.rs +++ b/src/ops/reduce.rs @@ -352,7 +352,7 @@ fn reduce>( SliceItem::Index(idx as isize) } })); - let slice = input.slice_dyn(inner_range.as_slice()); + let slice = input.slice_with(inner_range.as_slice()); let reduced = reducer.reduce(slice.iter().copied()); reduced_data.push(reduced); } diff --git a/src/ops/split.rs b/src/ops/split.rs index 2d96a964..c6790e87 100644 --- a/src/ops/split.rs +++ b/src/ops/split.rs @@ -40,7 +40,7 @@ pub fn split( split_start += split_size; - input.slice_dyn(slice_range.as_slice()).to_tensor_in(pool) + input.slice_with(slice_range.as_slice()).to_tensor_in(pool) }) .collect();