Skip to content

Commit

Permalink
Merge pull request #363 from robertknight/slice-with-dyn
Browse files Browse the repository at this point in the history
Support ranges with unknown index count in `slice_with`
  • Loading branch information
robertknight authored Sep 18, 2024
2 parents a57f594 + 36ea116 commit a98f09d
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 37 deletions.
2 changes: 1 addition & 1 deletion rten-tensor/src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ impl<'a, T, L: MutLayout> Iterator for InnerIterDyn<'a, T, L> {
fn next(&mut self) -> Option<Self::Item> {
self.outer_indices.next().map(|idx| {
let slice_items = to_slice_items(&idx);
self.view.slice_dyn(slice_items.as_slice())
self.view.slice_with(slice_items.as_slice())
})
}

Expand Down
25 changes: 22 additions & 3 deletions rten-tensor/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::errors::{DimensionError, FromDataError, ReshapeError, SliceError};
use crate::index_iterator::{DynIndices, NdIndices};
use crate::overlap::{is_contiguous, may_have_internal_overlap};
use crate::slice_range::{IntoSliceItems, SliceItem};
use crate::type_num::{ConstUInt, U0, U1, U2, U3, U4, U5};
use crate::type_num::{OptionalUInt, Unknown, U0, U1, U2, U3, U4, U5};

/// Return true if `permutation` is a valid permutation of dimensions for
/// a tensor of rank `ndim`.
Expand Down Expand Up @@ -1372,7 +1372,7 @@ impl_remove_dim!(5, 4);
///
/// `R` is the type of the slice range. `IdxCount` is a marker type indicating
/// the number of items in `R` that are indices, as opposed to ranges.
pub trait SliceWith<R: IntoSliceItems, IdxCount: ConstUInt> {
pub trait SliceWith<R: IntoSliceItems, IdxCount: OptionalUInt> {
/// The layout produced after slicing.
type Layout: Layout;

Expand All @@ -1384,7 +1384,7 @@ pub trait SliceWith<R: IntoSliceItems, IdxCount: ConstUInt> {
fn slice_with(&self, range: R) -> (Range<usize>, Self::Layout);
}

impl<R: IntoSliceItems, I: ConstUInt> SliceWith<R, I> for DynLayout {
impl<R: IntoSliceItems, L: MutLayout> SliceWith<R, Unknown> for L {
type Layout = DynLayout;

fn slice_with(&self, range: R) -> (Range<usize>, Self::Layout) {
Expand All @@ -1400,6 +1400,25 @@ impl<R: IntoSliceItems, const N: usize> SliceWith<R, U0> for NdLayout<N> {
}
}

macro_rules! impl_slice_with_dynlayout {
($range_ndim:ty) => {
impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for DynLayout {
type Layout = DynLayout;

fn slice_with(&self, range: R) -> (Range<usize>, Self::Layout) {
self.slice_dyn(range.into_slice_items().as_ref())
}
}
};
}

impl_slice_with_dynlayout!(U0);
impl_slice_with_dynlayout!(U1);
impl_slice_with_dynlayout!(U2);
impl_slice_with_dynlayout!(U3);
impl_slice_with_dynlayout!(U4);
impl_slice_with_dynlayout!(U5);

macro_rules! impl_slice_with {
($ndim:literal, $range_ndim:ty, $out_ndim:literal) => {
impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for NdLayout<$ndim> {
Expand Down
2 changes: 1 addition & 1 deletion rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3544,7 +3544,7 @@ mod tests {
// Slice dynamic-rank array. The rank of the slice is dynamic.
let mut data = Tensor::from([[[1, 2, 3], [4, 5, 6]]]);
let mut row = data.slice_with_mut((0, 0));
row[[0]] = 10;
row[[0usize]] = 10;
assert_eq!(row.shape(), [3usize]);
assert_eq!(row.data().unwrap(), &[10, 2, 3]);
}
Expand Down
79 changes: 49 additions & 30 deletions rten-tensor/src/type_num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

use std::ops::{Range, RangeFrom, RangeFull, RangeTo};

/// Type representing an integer whose value is unknown at compile time.
pub struct Unknown {}

/// Type representing the integer value 0.
pub struct U0 {}

Expand All @@ -25,29 +28,34 @@ pub struct U4 {}
pub struct U5 {}

/// Trait providing the integer value of a `U<N>` type (eg. [`U0`]).
pub trait ConstUInt {
const VALUE: usize;
///
/// The value can be unknown to represent numbers that are known only at
/// runtime.
pub trait OptionalUInt {
const VALUE: Option<usize>;
}

macro_rules! impl_const_int {
($type:ty, $val:literal) => {
impl ConstUInt for $type {
const VALUE: usize = $val;
($type:ty, $val:expr) => {
impl OptionalUInt for $type {
const VALUE: Option<usize> = $val;
}
};
}
impl_const_int!(U0, 0);
impl_const_int!(U1, 1);
impl_const_int!(U2, 2);
impl_const_int!(U3, 3);
impl_const_int!(U4, 4);
impl_const_int!(U5, 5);

/// Trait that computes the sum of [`ConstUInt`] types.
impl_const_int!(Unknown, None);
impl_const_int!(U0, Some(0));
impl_const_int!(U1, Some(1));
impl_const_int!(U2, Some(2));
impl_const_int!(U3, Some(3));
impl_const_int!(U4, Some(4));
impl_const_int!(U5, Some(5));

/// Trait that computes the sum of [`OptionalUInt`] types.
///
/// It is implemented for 2-tuples, as well as arrays of either `U0` or `U1`.
pub trait Add {
type Result: ConstUInt;
type Result: OptionalUInt;
}

macro_rules! impl_add {
Expand Down Expand Up @@ -105,7 +113,10 @@ impl_add_ones!(5, U5);
pub trait IsIndex {
/// Associated type that is either [`U0`] or [`U1`] indicating whether this
/// type is an index.
type IsIndex: ConstUInt;
///
/// The value can also be [`Unknown`] to indicate a value that may be either
/// an index or a range.
type IsIndex: OptionalUInt;
}

macro_rules! impl_is_index {
Expand Down Expand Up @@ -147,16 +158,17 @@ impl<T> IsIndex for RangeFrom<T> {
///
/// ```
/// use rten_tensor::type_num::IndexCount;
/// assert_eq!((.., 1..2).index_count(), 0);
/// assert_eq!((0, 1..2).index_count(), 1);
/// assert_eq!((0, 1).index_count(), 2);
/// assert_eq!((.., 1..2).index_count(), Some(0));
/// assert_eq!((0, 1..2).index_count(), Some(1));
/// assert_eq!((0, 1).index_count(), Some(2));
/// assert_eq!([0, 1].as_slice().index_count(), None);
/// ```
pub trait IndexCount {
/// Type representing the count value.
type Count: ConstUInt;
type Count: OptionalUInt;

/// Returns [`Count`](IndexCount::Count) as a numeric value.
fn index_count(&self) -> usize {
fn index_count(&self) -> Option<usize> {
Self::Count::VALUE
}
}
Expand Down Expand Up @@ -224,27 +236,34 @@ where
type Count = <[T::IsIndex; N] as Add>::Result;
}

impl<'a, T> IndexCount for &'a [T] {
type Count = Unknown;
}

#[cfg(test)]
mod tests {
use super::IndexCount;

#[test]
fn test_index_count() {
// Single values
assert_eq!((0).index_count(), 1);
assert_eq!((..).index_count(), 0);
assert_eq!((..1).index_count(), 0);
assert_eq!((1..).index_count(), 0);
assert_eq!((1..2).index_count(), 0);
assert_eq!((0).index_count(), Some(1));
assert_eq!((..).index_count(), Some(0));
assert_eq!((..1).index_count(), Some(0));
assert_eq!((1..).index_count(), Some(0));
assert_eq!((1..2).index_count(), Some(0));

// Tuples
assert_eq!((0,).index_count(), 1);
assert_eq!((0, ..).index_count(), 1);
assert_eq!((0, .., 2).index_count(), 2);
assert_eq!((0, .., 2, ..).index_count(), 2);
assert_eq!((0, .., 2, .., 3).index_count(), 3);
assert_eq!((0,).index_count(), Some(1));
assert_eq!((0, ..).index_count(), Some(1));
assert_eq!((0, .., 2).index_count(), Some(2));
assert_eq!((0, .., 2, ..).index_count(), Some(2));
assert_eq!((0, .., 2, .., 3).index_count(), Some(3));

// Arrays
assert_eq!([1, 2, 3].index_count(), 3);
assert_eq!([1, 2, 3].index_count(), Some(3));

// Slices
assert_eq!([1, 2, 3].as_slice().index_count(), None);
}
}
2 changes: 1 addition & 1 deletion src/ops/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ fn reduce<T: Copy, R: Reducer<T>>(
SliceItem::Index(idx as isize)
}
}));
let slice = input.slice_dyn(inner_range.as_slice());
let slice = input.slice_with(inner_range.as_slice());
let reduced = reducer.reduce(slice.iter().copied());
reduced_data.push(reduced);
}
Expand Down
2 changes: 1 addition & 1 deletion src/ops/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub fn split<T: Copy>(

split_start += split_size;

input.slice_dyn(slice_range.as_slice()).to_tensor_in(pool)
input.slice_with(slice_range.as_slice()).to_tensor_in(pool)
})
.collect();

Expand Down

0 comments on commit a98f09d

Please sign in to comment.