Skip to content

Commit

Permalink
Merge pull request #360 from robertknight/tensor-slice-with
Browse files Browse the repository at this point in the history
Add `TensorBase::slice_with` method
  • Loading branch information
robertknight authored Sep 17, 2024
2 parents 4079cc8 + 1ca14e0 commit 9b5d7fe
Show file tree
Hide file tree
Showing 15 changed files with 468 additions and 55 deletions.
2 changes: 1 addition & 1 deletion rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ impl<'a> Generator<'a> {

// Sample output token.
let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
let next_id = self.sampler.sample(logits.slice::<1, _>((0, -1)));
let next_id = self.sampler.sample(logits.slice_with((0, -1)));

// Update the self-attention key-value cache.
//
Expand Down
6 changes: 1 addition & 5 deletions rten-generate/src/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,7 @@ impl Sampler for TopKSampler {
let topk_index = multinomial(&mut self.rng.borrow_mut(), probs.nd_view())
.expect("probs should be non-empty and sum to 1");

let token_id = topk_indices
.slice::<0, _>(topk_index)
.item()
.copied()
.unwrap();
let token_id = topk_indices.slice_with(topk_index).item().copied().unwrap();
token_id as TokenId
}
}
Expand Down
4 changes: 2 additions & 2 deletions rten-tensor/src/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ pub fn copy_into_slice<'a, T: Clone>(
let mut dest = NdTensorViewMut::from_data(src.shape(), dest);
for i0 in 0..src.size(0) {
for i1 in 0..src.size(1) {
let src = src.slice::<2, _>([i0, i1]);
let dest = dest.slice_mut::<2, _>([i0, i1]);
let src = src.slice_with([i0, i1]);
let dest = dest.slice_with_mut([i0, i1]);
copy_blocked(src, dest);
}
}
Expand Down
61 changes: 61 additions & 0 deletions rten-tensor/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::errors::{DimensionError, FromDataError, ReshapeError, SliceError};
use crate::index_iterator::{DynIndices, NdIndices};
use crate::overlap::{is_contiguous, may_have_internal_overlap};
use crate::slice_range::{IntoSliceItems, SliceItem};
use crate::type_num::{ConstUInt, U0, U1, U2, U3, U4, U5};

/// Return true if `permutation` is a valid permutation of dimensions for
/// a tensor of rank `ndim`.
Expand Down Expand Up @@ -1367,6 +1368,66 @@ impl_remove_dim!(3, 2);
impl_remove_dim!(4, 3);
impl_remove_dim!(5, 4);

/// Trait for slicing a layout with a range.
///
/// `R` is the type of the slice range. `IdxCount` is a marker type indicating
/// the number of items in `R` that are indices, as opposed to ranges. Each
/// index removes one dimension from a static-rank layout, so `IdxCount`
/// determines the rank of the sliced layout at compile time. Callers obtain
/// it from `R` via the `IndexCount` trait (`R::Count`, see `type_num`).
pub trait SliceWith<R: IntoSliceItems, IdxCount: ConstUInt> {
/// The layout produced after slicing.
type Layout: Layout;

/// Slice the layout with a range.
///
/// Returns a tuple of `(offset_range, sliced_layout)` where `offset_range`
/// is the range of data from the original view that is used by the slice
/// and `sliced_layout` is the layout of the sliced view.
fn slice_with(&self, range: R) -> (Range<usize>, Self::Layout);
}

impl<R: IntoSliceItems, I: ConstUInt> SliceWith<R, I> for DynLayout {
    // A dynamic-rank layout stays dynamic-rank after slicing, regardless of
    // how many of the range items are indices, so `I` is unconstrained here.
    type Layout = DynLayout;

    fn slice_with(&self, range: R) -> (Range<usize>, Self::Layout) {
        let items = range.into_slice_items();
        self.slice_dyn(items.as_ref())
    }
}

impl<R: IntoSliceItems, const N: usize> SliceWith<R, U0> for NdLayout<N> {
    // A range containing zero index items (only ranges) removes no
    // dimensions, so the sliced layout keeps the input rank `N`.
    type Layout = NdLayout<N>;

    fn slice_with(&self, range: R) -> (Range<usize>, Self::Layout) {
        let items = range.into_slice_items();
        self.slice(items.as_ref())
    }
}

/// Implement [`SliceWith`] for a static-rank layout.
///
/// `$ndim` is the rank of the `NdLayout` being sliced, `$range_ndim` is the
/// marker type (`U1`..`U5`) for the number of index (non-range) items in the
/// slice range, and `$out_ndim` is the resulting rank. Each index removes one
/// dimension, so `$out_ndim` is `$ndim` minus the value of `$range_ndim`.
macro_rules! impl_slice_with {
($ndim:literal, $range_ndim:ty, $out_ndim:literal) => {
impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for NdLayout<$ndim> {
type Layout = NdLayout<$out_ndim>;

fn slice_with(&self, range: R) -> (Range<usize>, Self::Layout) {
self.slice(range.into_slice_items().as_ref())
}
}
};
}

// Implementations for every (input rank, index count) pair with input ranks
// 1..=5. The zero-index case is handled by the generic `SliceWith<R, U0>`
// impl above, so each rank-N layout here gets impls for 1..=N indices, with
// the output rank equal to the input rank minus the index count.
impl_slice_with!(1, U1, 0);
impl_slice_with!(2, U1, 1);
impl_slice_with!(2, U2, 0);
impl_slice_with!(3, U1, 2);
impl_slice_with!(3, U2, 1);
impl_slice_with!(3, U3, 0);
impl_slice_with!(4, U1, 3);
impl_slice_with!(4, U2, 2);
impl_slice_with!(4, U3, 1);
impl_slice_with!(4, U4, 0);
impl_slice_with!(5, U1, 4);
impl_slice_with!(5, U2, 3);
impl_slice_with!(5, U3, 2);
impl_slice_with!(5, U4, 1);
impl_slice_with!(5, U5, 0);

#[cfg(test)]
mod tests {
use std::iter::zip;
Expand Down
1 change: 1 addition & 0 deletions rten-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ mod overlap;
mod slice_range;
mod storage;
mod tensor;
pub mod type_num;

/// Trait for sources of random data for tensors, for use with [`Tensor::rand`].
pub trait RandomSource<T> {
Expand Down
109 changes: 102 additions & 7 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ use crate::iterators::{
};
use crate::layout::{
AsIndex, BroadcastLayout, DynLayout, IntoLayout, Layout, MatrixLayout, MutLayout, NdLayout,
OverlapPolicy, RemoveDim, ResizeLayout,
OverlapPolicy, RemoveDim, ResizeLayout, SliceWith,
};
use crate::overlap::may_have_internal_overlap;
use crate::storage::{CowData, IntoStorage, Storage, StorageMut, ViewData, ViewMutData};
use crate::type_num::IndexCount;
use crate::{Alloc, GlobalAlloc, IntoSliceItems, RandomSource, SliceItem};

/// The base type for multi-dimensional arrays. This consists of storage for
Expand Down Expand Up @@ -289,6 +290,35 @@ pub trait AsView: Layout {
self.view().slice_dyn(range)
}

/// Slice this tensor and return a view.
///
/// This is an alternative to [`slice`](Self::slice) and
/// [`slice_dyn`](Self::slice_dyn) that determines the dimension count of
/// the returned view automatically at compile time. If both this tensor's
/// layout and the range have a statically-known number of index terms,
/// the result will have a static rank. Otherwise it will have a dynamic
/// rank.
///
/// ```
/// use rten_tensor::prelude::*;
/// use rten_tensor::NdTensor;
///
/// let x = NdTensor::from([[1, 2], [3, 4]]);
/// let col = x.slice_with((.., 1)); // `col` is an `NdTensorView<i32, 1>`
/// assert_eq!(col.shape(), [2usize]);
/// assert_eq!(col.to_vec(), [2, 4]);
/// ```
#[allow(clippy::type_complexity)]
fn slice_with<R: IntoSliceItems + IndexCount>(
&self,
range: R,
) -> TensorBase<ViewData<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
where
Self::Layout: SliceWith<R, R::Count, Layout: MutLayout>,
{
self.view().slice_with(range)
}

/// Return a slice of this tensor as an owned tensor.
///
/// This is more expensive than [`slice`](AsView::slice) as it copies the
Expand Down Expand Up @@ -768,6 +798,24 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
}
}

/// Slice this tensor and return a mutable view.
///
/// See [`slice_with`](AsView::slice_with) for notes on the layout of
/// the returned view.
pub fn slice_with_mut<R: IntoSliceItems + IndexCount>(
&mut self,
range: R,
) -> TensorBase<ViewMutData<S::Elem>, <L as SliceWith<R, R::Count>>::Layout>
where
L: SliceWith<R, R::Count, Layout: MutLayout>,
{
let (offset_range, sliced_layout) = self.layout.slice_with(range);
TensorBase {
data: self.data.slice_mut(offset_range),
layout: sliced_layout,
}
}

/// Slice this tensor and return a dynamic-rank view.
///
/// Fails if the range has more dimensions than the view or is out of bounds
Expand Down Expand Up @@ -1456,6 +1504,21 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
}
}

/// Slice this tensor and return a view. See [`AsView::slice_with`].
pub fn slice_with<R: IntoSliceItems + IndexCount>(
&self,
range: R,
) -> TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>
where
L: SliceWith<R, R::Count, Layout: MutLayout>,
{
let (offset_range, sliced_layout) = self.layout.slice_with(range);
TensorBase {
data: self.data.slice(offset_range),
layout: sliced_layout,
}
}

/// Remove all size-one dimensions from this tensor.
///
/// See [AsView::squeezed].
Expand Down Expand Up @@ -3370,7 +3433,7 @@ mod tests {
}

#[test]
fn test_slice_with_ndlayout() {
fn test_slice_on_ndlayout() {
let data = vec![1., 2., 3., 4.];
let tensor = NdTensor::from_data([2, 2], data);

Expand All @@ -3389,7 +3452,7 @@ mod tests {
}

#[test]
fn test_slice_dyn_with_ndlayout() {
fn test_slice_dyn_on_ndlayout() {
let data = vec![1., 2., 3., 4.];
let tensor = NdTensor::from_data([2, 2], data);

Expand All @@ -3403,7 +3466,7 @@ mod tests {
}

#[test]
fn test_slice_with_dynlayout() {
fn test_slice_on_dynlayout() {
let data = vec![1., 2., 3., 4.];
let tensor = Tensor::from_data(&[2, 2], data);

Expand All @@ -3417,7 +3480,7 @@ mod tests {
}

#[test]
fn test_slice_dyn_with_dynlayout() {
fn test_slice_dyn_on_dynlayout() {
let data = vec![1., 2., 3., 4.];
let tensor = Tensor::from_data(&[2, 2], data);

Expand Down Expand Up @@ -3454,6 +3517,38 @@ mod tests {
assert_eq!(tensor.to_vec(), &[1., 2., 8., 9.]);
}

#[test]
fn test_slice_with() {
// Slice static-rank array. The rank of the slice is inferred.
let data = NdTensor::from([[[1, 2, 3], [4, 5, 6]]]);
let row = data.slice_with((0, 0));
assert_eq!(row.shape(), [3usize]);
assert_eq!(row.data().unwrap(), &[1, 2, 3]);

// Slice dynamic-rank array. The rank of the slice is dynamic.
let data = Tensor::from([[[1, 2, 3], [4, 5, 6]]]);
let row = data.slice_with((0, 0));
assert_eq!(row.shape(), [3usize]);
assert_eq!(row.data().unwrap(), &[1, 2, 3]);
}

#[test]
fn test_slice_with_mut() {
// Slice static-rank array. The rank of the slice is inferred.
let mut data = NdTensor::from([[[1, 2, 3], [4, 5, 6]]]);
let mut row = data.slice_with_mut((0, 0));
row[[0usize]] = 5;
assert_eq!(row.shape(), [3usize]);
assert_eq!(row.data().unwrap(), &[5, 2, 3]);

// Slice dynamic-rank array. The rank of the slice is dynamic.
let mut data = Tensor::from([[[1, 2, 3], [4, 5, 6]]]);
let mut row = data.slice_with_mut((0, 0));
row[[0]] = 10;
assert_eq!(row.shape(), [3usize]);
assert_eq!(row.data().unwrap(), &[10, 2, 3]);
}

#[test]
fn test_squeezed() {
let data = &[1., 2., 3., 4., 5., 6.];
Expand Down Expand Up @@ -3562,8 +3657,8 @@ mod tests {
#[test]
fn test_to_array() {
let tensor = NdTensor::arange(1., 5., None).into_shape([2, 2]);
let col0: [f32; 2] = tensor.view().transposed().slice::<1, _>(0).to_array();
let col1: [f32; 2] = tensor.view().transposed().slice::<1, _>(1).to_array();
let col0: [f32; 2] = tensor.view().transposed().slice_with(0).to_array();
let col1: [f32; 2] = tensor.view().transposed().slice_with(1).to_array();
assert_eq!(col0, [1., 3.]);
assert_eq!(col1, [2., 4.]);
}
Expand Down
Loading

0 comments on commit 9b5d7fe

Please sign in to comment.