Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fallible variants of TensorBase::slice_with #365

Merged
merged 5 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions rten-tensor/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
use std::error::Error;
use std::fmt::{Display, Formatter};

use crate::slice_range::SliceRange;

/// Error in a tensor operation if the dimension count is incorrect.
#[derive(Debug, PartialEq)]
pub struct DimensionError {}
Expand Down Expand Up @@ -47,28 +49,82 @@ impl Error for FromDataError {}
#[derive(Clone, Debug, PartialEq)]
pub enum SliceError {
/// The slice spec has more dimensions than the tensor being sliced.
TooManyDims,
TooManyDims {
/// Number of axes in the tensor.
ndim: usize,
/// Number of items in the slice spec.
range_ndim: usize,
},

/// An index in the slice spec is out of bounds for the corresponding tensor
/// dimension.
InvalidIndex,
InvalidIndex {
/// Axis that the error applies to.
axis: usize,
/// Index in the slice range.
index: isize,
/// Size of the dimension.
size: usize,
},

/// A range in the slice spec is out of bounds for the corresponding tensor
/// dimension.
InvalidRange,
InvalidRange {
/// Axis that the error applies to.
axis: usize,

/// The range item.
range: SliceRange,

/// Size of the dimension.
size: usize,
},

/// The step in a slice range is negative, in a context where this is not
/// supported.
InvalidStep,
InvalidStep {
/// Axis that the error applies to.
axis: usize,

/// Size of the dimension.
step: isize,
},

/// There is a mismatch between the actual and expected number of axes
/// in the output slice.
OutputDimsMismatch { actual: usize, expected: usize },
}

impl Display for SliceError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
SliceError::TooManyDims => write!(f, "slice spec has too many dims"),
SliceError::InvalidIndex => write!(f, "slice index is invalid"),
SliceError::InvalidRange => write!(f, "slice range is invalid"),
SliceError::InvalidStep => write!(f, "slice step is invalid"),
SliceError::TooManyDims { ndim, range_ndim } => {
write!(
f,
"slice range has {} items but tensor has only {} dims",
range_ndim, ndim
)
}
SliceError::InvalidIndex { axis, index, size } => write!(
f,
"slice index {} is invalid for axis ({}) of size {}",
index, axis, size
),
SliceError::InvalidRange { axis, range, size } => write!(
f,
"slice range {:?} is invalid for axis ({}) of size {}",
range, axis, size
),
SliceError::InvalidStep { axis, step } => {
write!(f, "slice step {} is invalid for axis {}", step, axis)
}
SliceError::OutputDimsMismatch { actual, expected } => {
write!(
f,
"slice output dims {} does not match expected dims {}",
actual, expected
)
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion rten-tensor/src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ impl LaneRanges {
(0..end).into()
})
.collect();
let (_range, sliced) = layout.slice_dyn(&slice_starts);
let (_range, sliced) = layout.slice_dyn(&slice_starts).unwrap();
let offsets = Offsets::new(&sliced);
LaneRanges {
offsets,
Expand Down
Loading
Loading