Skip to content

Commit

Permalink
Merge pull request #366 from robertknight/slice-copy-inferred-type
Browse files Browse the repository at this point in the history
Infer layout for `slice_copy` based on layout and range
  • Loading branch information
robertknight authored Sep 20, 2024
2 parents aecc370 + 1861a89 commit 8af2ddf
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::fmt::Debug;
use std::mem::MaybeUninit;
use std::ops::{Index, IndexMut, Range};

Expand Down Expand Up @@ -339,27 +340,42 @@ pub trait AsView: Layout {
///
/// This is more expensive than [`slice`](AsView::slice) as it copies the
/// data, but is more flexible as it supports ranges with negative steps.
fn slice_copy<R: Clone + IntoSliceItems>(&self, range: R) -> Tensor<Self::Elem>
#[allow(clippy::type_complexity)]
fn slice_copy<R: Clone + IntoSliceItems + IndexCount>(
&self,
range: R,
) -> TensorBase<Vec<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
where
Self::Elem: Clone,
Self::Layout: SliceWith<
R,
R::Count,
Layout: for<'a> Layout<Index<'a>: TryFrom<&'a [usize], Error: Debug>>,
>,
{
self.slice_copy_in(GlobalAlloc::new(), range)
}

/// Variant of [`slice_copy`](AsView::slice_copy) which takes an allocator.
fn slice_copy_in<A: Alloc, R: Clone + IntoSliceItems>(
#[allow(clippy::type_complexity)]
fn slice_copy_in<A: Alloc, R: Clone + IntoSliceItems + IndexCount>(
&self,
pool: A,
range: R,
) -> Tensor<Self::Elem>
) -> TensorBase<Vec<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
where
Self::Elem: Clone,
Self::Layout: SliceWith<
R,
R::Count,
Layout: for<'a> Layout<Index<'a>: TryFrom<&'a [usize], Error: Debug>>,
>,
{
// Fast path for slice ranges supported by `Tensor::slice`. This includes
// all ranges except those with a negative step. This benefits from
// optimizations that `Tensor::to_tensor` has for slices that are already
// contiguous or have a small number of dims.
if let Ok(slice_view) = self.try_slice_dyn(range.clone()) {
if let Ok(slice_view) = self.try_slice_with(range.clone()) {
return slice_view.to_tensor_in(pool);
}

Expand Down Expand Up @@ -388,7 +404,9 @@ pub trait AsView: Layout {
sliced_data.set_len(sliced_len);
}

Tensor::from_data(&sliced_shape, sliced_data)
let sliced_shape = sliced_shape.as_slice().try_into().expect("slice failed");

TensorBase::from_data(sliced_shape, sliced_data)
}

/// Return a view of this tensor with all dimensions of size 1 removed.
Expand Down

0 comments on commit 8af2ddf

Please sign in to comment.