Skip to content

Commit

Permalink
Simplify SliceWith usage by changing Layout assoc type bound
Browse files Browse the repository at this point in the history
Promise that the `Layout` associated type is a `MutLayout`, so we don't have to
add that bound everywhere that uses it.
  • Loading branch information
robertknight committed Sep 19, 2024
1 parent 5a7ca59 commit f7a5523
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion rten-tensor/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,7 @@ impl_remove_dim!(5, 4);
/// the number of items in `R` that are indices, as opposed to ranges.
pub trait SliceWith<R: IntoSliceItems, IdxCount: OptionalUInt> {
/// The layout produced after slicing.
type Layout: Layout;
type Layout: MutLayout;

/// Slice the layout with a range.
///
Expand Down
12 changes: 6 additions & 6 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ pub trait AsView: Layout {
range: R,
) -> TensorBase<ViewData<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
where
Self::Layout: SliceWith<R, R::Count, Layout: MutLayout>,
Self::Layout: SliceWith<R, R::Count>,
{
self.view().slice_with(range)
}
Expand All @@ -330,7 +330,7 @@ pub trait AsView: Layout {
SliceError,
>
where
Self::Layout: SliceWith<R, R::Count, Layout: MutLayout>,
Self::Layout: SliceWith<R, R::Count>,
{
self.view().try_slice_with(range)
}
Expand Down Expand Up @@ -825,7 +825,7 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
range: R,
) -> TensorBase<ViewMutData<S::Elem>, <L as SliceWith<R, R::Count>>::Layout>
where
L: SliceWith<R, R::Count, Layout: MutLayout>,
L: SliceWith<R, R::Count>,
{
self.try_slice_with_mut(range).expect("slice failed")
}
Expand Down Expand Up @@ -853,7 +853,7 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
range: R,
) -> Result<TensorBase<ViewMutData<S::Elem>, <L as SliceWith<R, R::Count>>::Layout>, SliceError>
where
L: SliceWith<R, R::Count, Layout: MutLayout>,
L: SliceWith<R, R::Count>,
{
let (offset_range, sliced_layout) = self.layout.slice_with(range)?;
Ok(TensorBase {
Expand Down Expand Up @@ -1541,7 +1541,7 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
range: R,
) -> TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>
where
L: SliceWith<R, R::Count, Layout: MutLayout>,
L: SliceWith<R, R::Count>,
{
self.try_slice_with(range).expect("slice failed")
}
Expand All @@ -1554,7 +1554,7 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
range: R,
) -> Result<TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>, SliceError>
where
L: SliceWith<R, R::Count, Layout: MutLayout>,
L: SliceWith<R, R::Count>,
{
let (offset_range, sliced_layout) = self.layout.slice_with(range)?;
Ok(TensorBase {
Expand Down

0 comments on commit f7a5523

Please sign in to comment.