diff --git a/rten-examples/src/bert_qa.rs b/rten-examples/src/bert_qa.rs
index e7350250..d77d7f59 100644
--- a/rten-examples/src/bert_qa.rs
+++ b/rten-examples/src/bert_qa.rs
@@ -159,13 +159,13 @@ fn extract_nbest_answers<'a>(
     let min_start = 1; // Ignore [CLS] token at start.
     let max_end = end_probs.size(1) - 1; // Ignore [SEP] token at end.
     let mut span_scores: Vec<(usize, usize, f32)> = start_probs
-        .slice::<1, _>((0, min_start..max_end))
+        .slice((0, min_start..max_end))
         .iter()
         .enumerate()
         .map(|(start_pos, start_score)| {
             let start_pos = start_pos + min_start;
             let (relative_end_pos, end_score) = end_probs
-                .slice::<1, _>((0, start_pos..(start_pos + max_answer_len).min(max_end)))
+                .slice((0, start_pos..(start_pos + max_answer_len).min(max_end)))
                 .iter()
                 .enumerate()
                 .max_by(|(_pos_a, score_a), (_pos_b, score_b)| score_a.total_cmp(score_b))
diff --git a/rten-examples/src/deeplab.rs b/rten-examples/src/deeplab.rs
index f9fe6747..7f19d5a3 100644
--- a/rten-examples/src/deeplab.rs
+++ b/rten-examples/src/deeplab.rs
@@ -131,7 +131,7 @@ fn main() -> Result<(), Box<dyn Error>> {
     output.permute(&[0, 2, 3, 1]); // (N,class,H,W) => (N,H,W,class)
     let seg_classes: NdTensor<i32, 2> = output
-        .slice_dyn(0)
+        .slice(0)
         .arg_max(-1, false /* keep_dims */)?
         .try_into()?;
     let [out_height, out_width] = seg_classes.shape();
diff --git a/rten-examples/src/depth_anything.rs b/rten-examples/src/depth_anything.rs
index 4149b1c9..6414b5b1 100644
--- a/rten-examples/src/depth_anything.rs
+++ b/rten-examples/src/depth_anything.rs
@@ -105,7 +105,7 @@ fn main() -> Result<(), Box<dyn Error>> {
 
     // Resize output map back to original input size and write to file.
     let resized = output.resize_image([orig_height, orig_width])?;
-    let resized = resized.slice::<3, _>(0);
+    let resized = resized.nd_view::<4>().slice(0);
     write_image(&args.output, resized)?;
 
     Ok(())
diff --git a/rten-examples/src/jina_similarity.rs b/rten-examples/src/jina_similarity.rs
index 26930b0d..5ac57f50 100644
--- a/rten-examples/src/jina_similarity.rs
+++ b/rten-examples/src/jina_similarity.rs
@@ -102,7 +102,7 @@ fn embed_sentence_batch(
         let token_ids = encoded.token_ids();
         for (tid, input_id) in token_ids
             .iter()
-            .zip(input_ids.slice_mut_dyn((i, ..token_ids.len())).iter_mut())
+            .zip(input_ids.slice_mut((i, ..token_ids.len())).iter_mut())
         {
             *input_id = *tid as i32;
         }
@@ -110,10 +110,10 @@
 
     // Generate attention mask, set to 1 for non-padding tokens and 0 for
    // padding tokens.
-    let mut attention_mask = Tensor::zeros(&[batch, max_sequence_len]);
+    let mut attention_mask = NdTensor::zeros([batch, max_sequence_len]);
     for (i, encoded) in encoded.iter().enumerate() {
         attention_mask
-            .slice_mut::<1, _>((i, ..encoded.token_ids().len()))
+            .slice_mut((i, ..encoded.token_ids().len()))
             .fill(1i32);
     }
 
@@ -127,9 +127,9 @@
 
     // Generate token type IDs if this model needs them. These are all zeros
     // since each item has just one sequence.
-    let type_ids: Tensor<i32>;
+    let type_ids: NdTensor<i32, 2>;
     if let Some(type_ids_id) = model.find_node("token_type_ids") {
-        type_ids = Tensor::zeros(&[batch, max_sequence_len]);
+        type_ids = NdTensor::zeros([batch, max_sequence_len]);
         inputs.push((type_ids_id, type_ids.view().into()));
     }
 
@@ -146,7 +146,7 @@
             // Take the mean of the non-padding elements along the sequence
             // dimension.
             let seq_len = input.token_ids().len();
-            item.slice_dyn(..seq_len)
+            item.slice(..seq_len)
                 .reduce_mean(Some(&[0]), false /* keep_dims */)
                 .unwrap()
         })
@@ -231,7 +231,7 @@ fn main() -> Result<(), Box<dyn Error>> {
 
     // (1, embed) @ (embed, batch) => (1, batch)
     let similarities = embeddings
-        .slice::<2, _>(..1)
+        .slice(..1)
         .matmul(embeddings.transposed().into())?;
 
     // Sort results by similarity to the query.
@@ -240,7 +240,7 @@
     // all be "high" values (close to 1.0). They should be used only for
     // comparison with other scores.
     let mut scores: Vec<(usize, f32)> = similarities
-        .slice_dyn(0)
+        .slice(0)
         .iter()
         .copied()
         .enumerate()
diff --git a/rten-examples/src/piper.rs b/rten-examples/src/piper.rs
index 809fcdfe..3e9a2fbf 100644
--- a/rten-examples/src/piper.rs
+++ b/rten-examples/src/piper.rs
@@ -223,7 +223,7 @@ fn main() -> Result<(), Box<dyn Error>> {
 
     // Convert audio samples from float to 16-bit ints and write to output .wav
     // file.
-    let int_samples = audio_float_to_int16(samples.slice::<1, _>((0, 0, 0)), None);
+    let int_samples = audio_float_to_int16(samples.slice((0, 0, 0)), None);
 
     let wav_file = BufWriter::new(File::create("output.wav")?);
     let mut wav_writer = WavWriter::new(
diff --git a/rten-examples/src/rmbg.rs b/rten-examples/src/rmbg.rs
index 1a65b502..4a4aec0f 100644
--- a/rten-examples/src/rmbg.rs
+++ b/rten-examples/src/rmbg.rs
@@ -119,7 +119,7 @@ fn main() -> Result<(), Box<dyn Error>> {
     let bg_color = [0., 1., 0.]; // RGB
     fill_mask(
         image.view_mut(),
-        background_mask.slice::<2, _>([0, 0]), // Extract first mask and channel
+        background_mask.slice([0, 0]), // Extract first mask and channel
         bg_color,
     );
diff --git a/rten-examples/src/segment_anything.rs b/rten-examples/src/segment_anything.rs
index 9ac63665..463dc552 100644
--- a/rten-examples/src/segment_anything.rs
+++ b/rten-examples/src/segment_anything.rs
@@ -207,11 +207,9 @@ fn main() -> Result<(), Box<dyn Error>> {
     // Resize the output mask to match the original image and save to disk.
     let pred_masks: NdTensor<f32, 5> = pred_masks.try_into()?;
     let [_batch, _point_batch, _mask, mask_h, mask_w] = pred_masks.shape();
-    let best_mask = pred_masks
-        .slice::<2, _>((0, 0, 0))
-        .reshaped([1, 1, mask_h, mask_w]);
+    let best_mask = pred_masks.slice((0, 0, 0)).reshaped([1, 1, mask_h, mask_w]);
     let resized_mask = best_mask.resize_image([image_h, image_w])?;
-    write_image("segmented.png", resized_mask.slice::<3, _>(0).nd_view())?;
+    write_image("segmented.png", resized_mask.nd_view::<4>().slice(0))?;
 
     Ok(())
 }
diff --git a/rten-examples/src/trocr.rs b/rten-examples/src/trocr.rs
index 0e68c7f9..c72730b0 100644
--- a/rten-examples/src/trocr.rs
+++ b/rten-examples/src/trocr.rs
@@ -97,7 +97,7 @@ fn main() -> Result<(), Box<dyn Error>> {
     image.insert_axis(0); // Add batch dim
 
     // From `image_size` in config.json.
-    let mut image = image.resize_image([384, 384])?;
+    let mut image: NdTensor<_, 4> = image.resize_image([384, 384])?.try_into()?;
 
     // Values taken from `preprocessor_config.json`.
     let mean = [0.5, 0.5, 0.5];
diff --git a/rten-examples/src/yolo.rs b/rten-examples/src/yolo.rs
index a14bac63..db9baec4 100644
--- a/rten-examples/src/yolo.rs
+++ b/rten-examples/src/yolo.rs
@@ -146,11 +146,11 @@ fn main() -> Result<(), Box<dyn Error>> {
     let scale_x = image_width as f32 / model_in_w as f32;
 
     // [batch, n_boxes, coord]
-    let boxes = output.slice::<3, _>((.., ..4, ..)).permuted([0, 2, 1]);
+    let boxes = output.slice((.., ..4, ..)).permuted([0, 2, 1]);
 
     // [batch, n_classes, n_boxes]. The `n_boxes` coord is last because that
     // is what `non_max_suppression` requires.
-    let scores = output.slice::<3, _>((.., 4.., ..));
+    let scores = output.slice((.., 4.., ..));
 
     let iou_threshold = 0.3;
     let score_threshold = 0.25;
diff --git a/rten-generate/src/generator.rs b/rten-generate/src/generator.rs
index 27604c2b..d6d3c666 100644
--- a/rten-generate/src/generator.rs
+++ b/rten-generate/src/generator.rs
@@ -602,7 +602,7 @@ impl<'a> Generator<'a> {
 
         // Sample output token.
         let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
-        let next_id = self.sampler.sample(logits.slice_with((0, -1)));
+        let next_id = self.sampler.sample(logits.slice((0, -1)));
 
         // Update the self-attention key-value cache.
         //
diff --git a/rten-generate/src/sampler.rs b/rten-generate/src/sampler.rs
index 536e844a..6dade8b6 100644
--- a/rten-generate/src/sampler.rs
+++ b/rten-generate/src/sampler.rs
@@ -97,7 +97,7 @@ impl Sampler for TopKSampler {
         let topk_index = multinomial(&mut self.rng.borrow_mut(), probs.nd_view())
             .expect("probs should be non-empty and sum to 1");
 
-        let token_id = topk_indices.slice_with(topk_index).item().copied().unwrap();
+        let token_id = topk_indices.slice(topk_index).item().copied().unwrap();
 
         token_id as TokenId
     }
 }
diff --git a/rten-imageproc/src/drawing.rs b/rten-imageproc/src/drawing.rs
index d3755c1b..47e34b3a 100644
--- a/rten-imageproc/src/drawing.rs
+++ b/rten-imageproc/src/drawing.rs
@@ -465,7 +465,7 @@ impl<'a, T: Copy + Default> Painter<'a, T> {
     pub fn draw_polygon(&mut self, points: &[Point]) {
         for i in 0..3 {
             draw_polygon(
-                self.surface.slice_with_mut([i]),
+                self.surface.slice_mut([i]),
                 points,
                 self.state.stroke[i],
                 self.state.stroke_width,
@@ -600,9 +600,9 @@ mod tests {
         let expected_g = expected_r.map(|&x| if x == r { g } else { 0 });
         let expected_b = expected_r.map(|&x| if x == r { b } else { 0 });
 
-        compare_images(img.slice_with([0]), expected_r.view());
-        compare_images(img.slice_with([1]), expected_g.view());
-        compare_images(img.slice_with([2]), expected_b.view());
+        compare_images(img.slice([0]), expected_r.view());
+        compare_images(img.slice([1]), expected_g.view());
+        compare_images(img.slice([2]), expected_b.view());
     }
 
     #[test]
diff --git a/rten-imageproc/src/normalize.rs b/rten-imageproc/src/normalize.rs
index 624c93b6..b1597db6 100644
--- a/rten-imageproc/src/normalize.rs
+++ b/rten-imageproc/src/normalize.rs
@@ -32,7 +32,7 @@ pub fn normalize_image(
     for chan in 0..n_chans {
         let inv_std_dev = 1. / std_dev[chan];
-        img.slice_with_mut(chan)
+        img.slice_mut(chan)
             .apply(|x| (x - mean[chan]) * inv_std_dev);
     }
 }
diff --git a/rten-tensor/src/copy.rs b/rten-tensor/src/copy.rs
index 2f123c0f..813e4bb2 100644
--- a/rten-tensor/src/copy.rs
+++ b/rten-tensor/src/copy.rs
@@ -205,8 +205,8 @@ pub fn copy_into_slice<'a, T: Clone>(
     let mut dest = NdTensorViewMut::from_data(src.shape(), dest);
     for i0 in 0..src.size(0) {
         for i1 in 0..src.size(1) {
-            let src = src.slice_with([i0, i1]);
-            let dest = dest.slice_with_mut([i0, i1]);
+            let src = src.slice([i0, i1]);
+            let dest = dest.slice_mut([i0, i1]);
             copy_blocked(src, dest);
         }
     }
@@ -437,7 +437,7 @@ fn copy_range_into_slice_inner(
     } else {
         // Iterate over views of outermost dimension and recurse.
         for i0 in ranges[0] {
-            let src_slice = src.slice_dyn(i0);
+            let src_slice = src.slice(i0);
             let (dest_slice, dest_tail) = dest.split_at_mut(src_slice.len());
 
             copy_range_into_slice_inner(src_slice, dest_slice, &ranges[1..]);
diff --git a/rten-tensor/src/iterators.rs b/rten-tensor/src/iterators.rs
index a4f8d819..cb31aedc 100644
--- a/rten-tensor/src/iterators.rs
+++ b/rten-tensor/src/iterators.rs
@@ -779,7 +779,7 @@ impl<'a, T, L: MutLayout> Iterator for InnerIterDyn<'a, T, L> {
     fn next(&mut self) -> Option<Self::Item> {
         self.outer_indices.next().map(|idx| {
             let slice_items = to_slice_items(&idx);
-            self.view.slice_with(slice_items.as_slice())
+            self.view.slice(slice_items.as_slice())
         })
     }
 
@@ -855,7 +855,7 @@ impl<'a, T, L: MutLayout> Iterator for InnerIterDynMut<'a, T, L> {
     fn next(&mut self) -> Option<Self::Item> {
         self.outer_indices.next().map(|idx| {
             let slice_items = to_slice_items(&idx);
-            let view: TensorViewMut<'_, T> = self.view.slice_mut_dyn(slice_items.as_slice());
+            let view: TensorViewMut<'_, T> = self.view.slice_mut(slice_items.as_slice());
             unsafe {
                 // Safety: Outer view is non-broadcasting, and we increment the
                 // outer index each time, so returned views will not overlap.
diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs
index 5ca4005f..9a5ae58a 100644
--- a/rten-tensor/src/tensor.rs
+++ b/rten-tensor/src/tensor.rs
@@ -262,68 +262,36 @@ pub trait AsView: Layout {
         self.view().transposed()
     }
 
-    /// Slice this tensor and return a dynamic-rank view.
-    ///
-    /// Fails if the range has more dimensions than the view or is out of bounds
-    /// for any dimension.
-    fn try_slice_dyn<R: IntoSliceItems>(
-        &self,
-        range: R,
-    ) -> Result<TensorView<Self::Elem>, SliceError> {
-        self.view().try_slice_dyn(range)
-    }
-
-    /// Slice this tensor and return a static-rank view with `M` dimensions.
-    ///
-    /// Use [AsView::slice_dyn] instead if the number of dimensions in the
-    /// returned view is unknown at compile time.
-    ///
-    /// This method is cheap as it does not copy the data, but does not support
-    /// ranges with negative steps. For that use [`slice_copy`](AsView::slice_copy).
-    ///
-    /// Panics if the dimension count of the result is not `M`.
-    fn slice<const M: usize, R: IntoSliceItems>(&self, range: R) -> NdTensorView<Self::Elem, M> {
-        self.view().slice(range)
-    }
-
-    /// Slice this tensor and return a dynamic-rank view.
-    fn slice_dyn<R: IntoSliceItems>(&self, range: R) -> TensorView<Self::Elem> {
-        self.view().slice_dyn(range)
-    }
-
     /// Slice this tensor and return a view.
     ///
-    /// This is an alternative to [`slice`](Self::slice) and
-    /// [`slice_dyn`](Self::slice_dyn) that determines the dimension count of
-    /// the returned view automatically at compile time. If both this tensor's
-    /// layout and the range have a statically-known number of index terms,
-    /// the result will have a static rank. Otherwise it will have a dynamic
-    /// rank.
+    /// If both this tensor's layout and the range have a statically-known
+    /// number of index terms, the result will have a static rank. Otherwise it
+    /// will have a dynamic rank.
     ///
     /// ```
     /// use rten_tensor::prelude::*;
     /// use rten_tensor::NdTensor;
     ///
     /// let x = NdTensor::from([[1, 2], [3, 4]]);
-    /// let col = x.slice_with((.., 1)); // `col` is an `NdTensorView<i32, 1>`
+    /// let col = x.slice((.., 1)); // `col` is an `NdTensorView<i32, 1>`
     /// assert_eq!(col.shape(), [2usize]);
     /// assert_eq!(col.to_vec(), [2, 4]);
     /// ```
     #[allow(clippy::type_complexity)]
-    fn slice_with<R: IntoSliceItems + IndexCount>(
+    fn slice<R: IntoSliceItems + IndexCount>(
         &self,
         range: R,
     ) -> TensorBase<ViewData<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
     where
         Self::Layout: SliceWith<R, R::Count>,
     {
-        self.view().slice_with(range)
+        self.view().slice(range)
     }
 
-    /// A variant of [`slice_with`](Self::slice_with) that returns a result
+    /// A variant of [`slice`](Self::slice) that returns a result
     /// instead of panicking.
     #[allow(clippy::type_complexity)]
-    fn try_slice_with<R: IntoSliceItems + IndexCount>(
+    fn try_slice<R: IntoSliceItems + IndexCount>(
         &self,
         range: R,
     ) -> Result<
@@ -333,7 +301,7 @@ pub trait AsView: Layout {
     where
         Self::Layout: SliceWith<R, R::Count>,
     {
-        self.view().try_slice_with(range)
+        self.view().try_slice(range)
     }
 
     /// Return a slice of this tensor as an owned tensor.
@@ -375,7 +343,7 @@
         // all ranges except those with a negative step. This benefits from
         // optimizations that `Tensor::to_tensor` has for slices that are already
         // contiguous or have a small number of dims.
-        if let Ok(slice_view) = self.try_slice_with(range.clone()) {
+        if let Ok(slice_view) = self.try_slice(range.clone()) {
             return slice_view.to_tensor_in(pool);
         }
@@ -804,69 +772,24 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
         }
     }
 
-    /// Slice this tensor and return a static-rank view with `M` dimensions.
-    ///
-    /// Use [AsView::slice_dyn] instead if the number of dimensions in the
-    /// returned view is unknown at compile time.
-    ///
-    /// Panics if the dimension count is not `M`.
-    pub fn slice_mut<const M: usize, R: IntoSliceItems>(
-        &mut self,
-        range: R,
-    ) -> NdTensorViewMut<S::Elem, M> {
-        let range = range.into_slice_items();
-        let (offset_range, sliced_layout) =
-            self.layout.slice(range.as_ref()).expect("slice failed");
-        NdTensorViewMut {
-            data: self.data.slice_mut(offset_range),
-            layout: sliced_layout,
-        }
-    }
-
-    /// Slice this tensor and return a dynamic-rank view.
-    pub fn slice_mut_dyn<R: IntoSliceItems>(&mut self, range: R) -> TensorViewMut<S::Elem> {
-        let range = range.into_slice_items();
-        let (offset_range, sliced_layout) =
-            self.layout.slice_dyn(range.as_ref()).expect("slice failed");
-        TensorViewMut {
-            data: self.data.slice_mut(offset_range),
-            layout: sliced_layout,
-        }
-    }
-
     /// Slice this tensor and return a mutable view.
     ///
-    /// See [`slice_with`](AsView::slice_with) for notes on the layout of
-    /// the returned view.
-    pub fn slice_with_mut<R: IntoSliceItems + IndexCount>(
+    /// See [`slice`](AsView::slice) for notes on the layout of the returned
+    /// view.
+    pub fn slice_mut<R: IntoSliceItems + IndexCount>(
         &mut self,
         range: R,
     ) -> TensorBase<ViewMutData<S::Elem>, <L as SliceWith<R, R::Count>>::Layout>
     where
         L: SliceWith<R, R::Count>,
     {
-        self.try_slice_with_mut(range).expect("slice failed")
+        self.try_slice_mut(range).expect("slice failed")
    }
 
-    /// Slice this tensor and return a dynamic-rank view.
-    ///
-    /// Fails if the range has more dimensions than the view or is out of bounds
-    /// for any dimension.
-    pub fn try_slice_mut<R: IntoSliceItems>(
-        &mut self,
-        range: R,
-    ) -> Result<TensorViewMut<S::Elem>, SliceError> {
-        let (offset_range, layout) = self.layout.slice_dyn(range.into_slice_items().as_ref())?;
-        Ok(TensorBase {
-            data: self.data.slice_mut(offset_range),
-            layout,
-        })
-    }
-
-    /// A variant of [`slice_with_mut`](Self::slice_with_mut) that returns a
+    /// A variant of [`slice_mut`](Self::slice_mut) that returns a
     /// result instead of panicking.
     #[allow(clippy::type_complexity)]
-    pub fn try_slice_with_mut<R: IntoSliceItems + IndexCount>(
+    pub fn try_slice_mut<R: IntoSliceItems + IndexCount>(
         &mut self,
         range: R,
     ) -> Result<TensorBase<ViewMutData<S::Elem>, <L as SliceWith<R, R::Count>>::Layout>, SliceError>
     where
@@ -1533,41 +1456,21 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
         }
     }
 
-    /// Slice this tensor and return a static-rank view. See [AsView::slice].
-    pub fn slice<const M: usize, R: IntoSliceItems>(&self, range: R) -> NdTensorView<'a, T, M> {
-        let range = range.into_slice_items();
-        let (offset_range, sliced_layout) = self.layout.slice(range.as_ref()).unwrap();
-        NdTensorView {
-            data: self.data.slice(offset_range),
-            layout: sliced_layout,
-        }
-    }
-
-    /// Slice this tensor and return a dynamic-rank view. See [AsView::slice_dyn].
-    pub fn slice_dyn<R: IntoSliceItems>(&self, range: R) -> TensorView<'a, T> {
-        let range = range.into_slice_items();
-        let (offset_range, sliced_layout) = self.layout.slice_dyn(range.as_ref()).unwrap();
-        TensorView {
-            data: self.data.slice(offset_range),
-            layout: sliced_layout,
-        }
-    }
-
-    /// Slice this tensor and return a view. See [`AsView::slice_with`].
-    pub fn slice_with<R: IntoSliceItems + IndexCount>(
+    /// Slice this tensor and return a view. See [`AsView::slice`].
+    pub fn slice<R: IntoSliceItems + IndexCount>(
         &self,
         range: R,
     ) -> TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>
     where
         L: SliceWith<R, R::Count>,
     {
-        self.try_slice_with(range).expect("slice failed")
+        self.try_slice(range).expect("slice failed")
     }
 
-    /// A variant of [`slice_with`](Self::slice_with) that returns a result
+    /// A variant of [`slice`](Self::slice) that returns a result
     /// instead of panicking.
     #[allow(clippy::type_complexity)]
-    pub fn try_slice_with<R: IntoSliceItems + IndexCount>(
+    pub fn try_slice<R: IntoSliceItems + IndexCount>(
         &self,
         range: R,
     ) -> Result<TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>, SliceError>
     where
@@ -2496,8 +2399,8 @@ mod tests {
         let mut transposed = tensor.view_mut();
         transposed.permute([1, 0]);
 
-        transposed.slice_with_mut(0).assign_array([1, 2]);
-        transposed.slice_with_mut(1).assign_array([3, 4]);
+        transposed.slice_mut(0).assign_array([1, 2]);
+        transposed.slice_mut(1).assign_array([3, 4]);
 
         assert_eq!(tensor.iter().copied().collect::<Vec<_>>(), [1, 3, 2, 4]);
     }
 
@@ -3495,117 +3398,32 @@ mod tests {
     }
 
     #[test]
-    fn test_slice_on_ndlayout() {
-        let data = vec![1., 2., 3., 4.];
-        let tensor = NdTensor::from_data([2, 2], data);
-
-        let row_one = tensor.slice(0);
-        assert_eq!(row_one[[0]], 1.);
-        assert_eq!(row_one[[1]], 2.);
-
-        let row_two = tensor.slice(1);
-        assert_eq!(row_two[[0]], 3.);
-        assert_eq!(row_two[[1]], 4.);
-
-        // Slice empty tensor
-        let empty = NdTensor::<f32>::zeros([0, 10]);
-        let col_one = empty.slice((.., 2..3));
-        assert_eq!(col_one.shape(), [0, 1]);
-    }
-
-    #[test]
-    fn test_slice_dyn_on_ndlayout() {
-        let data = vec![1., 2., 3., 4.];
-        let tensor = NdTensor::from_data([2, 2], data);
-
-        let row_one = tensor.slice_dyn(0);
-        assert_eq!(row_one[[0]], 1.);
-        assert_eq!(row_one[[1]], 2.);
-
-        let row_two = tensor.slice_dyn(1);
-        assert_eq!(row_two[[0]], 3.);
-        assert_eq!(row_two[[1]], 4.);
-    }
-
-    #[test]
-    fn test_slice_on_dynlayout() {
-        let data = vec![1., 2., 3., 4.];
-        let tensor = Tensor::from_data(&[2, 2], data);
-
-        let row_one = tensor.slice(0);
-        assert_eq!(row_one[[0]], 1.);
-        assert_eq!(row_one[[1]], 2.);
-
-        let row_two = tensor.slice(1);
-        assert_eq!(row_two[[0]], 3.);
-        assert_eq!(row_two[[1]], 4.);
-    }
-
-    #[test]
-    fn test_slice_dyn_on_dynlayout() {
-        let data = vec![1., 2., 3., 4.];
-        let tensor = Tensor::from_data(&[2, 2], data);
-
-        let row_one = tensor.slice_dyn(0);
-        assert_eq!(row_one[[0]], 1.);
-        assert_eq!(row_one[[1]], 2.);
-
-        let row_two = tensor.slice_dyn(1);
-        assert_eq!(row_two[[0]], 3.);
-        assert_eq!(row_two[[1]], 4.);
-    }
-
-    #[test]
-    fn test_slice_mut() {
-        let data = vec![1., 2., 3., 4.];
-        let mut tensor = NdTensor::from_data([2, 2], data);
-
-        let mut row = tensor.slice_mut(1);
-        row[[0]] = 8.;
-        row[[1]] = 9.;
-
-        assert_eq!(tensor.to_vec(), &[1., 2., 8., 9.]);
-    }
-
-    #[test]
-    fn test_slice_mut_dyn() {
-        let data = vec![1., 2., 3., 4.];
-        let mut tensor = NdTensor::from_data([2, 2], data);
-
-        let mut row = tensor.slice_mut_dyn(1);
-        row[[0]] = 8.;
-        row[[1]] = 9.;
-
-        assert_eq!(tensor.to_vec(), &[1., 2., 8., 9.]);
-    }
-
-    #[test]
-    fn test_slice_with() {
+    fn test_slice() {
         // Slice static-rank array. The rank of the slice is inferred.
         let data = NdTensor::from([[[1, 2, 3], [4, 5, 6]]]);
-        let row = data.slice_with((0, 0));
+        let row = data.slice((0, 0));
         assert_eq!(row.shape(), [3usize]);
         assert_eq!(row.data().unwrap(), &[1, 2, 3]);
 
         // Slice dynamic-rank array. The rank of the slice is dynamic.
         let data = Tensor::from([[[1, 2, 3], [4, 5, 6]]]);
-        let row = data.slice_with((0, 0));
+        let row = data.slice((0, 0));
         assert_eq!(row.shape(), [3usize]);
         assert_eq!(row.data().unwrap(), &[1, 2, 3]);
     }
 
     #[test]
-    fn test_slice_with_mut() {
+    fn test_slice_mut() {
         // Slice static-rank array. The rank of the slice is inferred.
         let mut data = NdTensor::from([[[1, 2, 3], [4, 5, 6]]]);
-        let mut row = data.slice_with_mut((0, 0));
+        let mut row = data.slice_mut((0, 0));
         row[[0usize]] = 5;
         assert_eq!(row.shape(), [3usize]);
         assert_eq!(row.data().unwrap(), &[5, 2, 3]);
 
         // Slice dynamic-rank array. The rank of the slice is dynamic.
         let mut data = Tensor::from([[[1, 2, 3], [4, 5, 6]]]);
-        let mut row = data.slice_with_mut((0, 0));
+        let mut row = data.slice_mut((0, 0));
         row[[0usize]] = 10;
         assert_eq!(row.shape(), [3usize]);
         assert_eq!(row.data().unwrap(), &[10, 2, 3]);
@@ -3719,8 +3537,8 @@ mod tests {
     #[test]
     fn test_to_array() {
         let tensor = NdTensor::arange(1., 5., None).into_shape([2, 2]);
-        let col0: [f32; 2] = tensor.view().transposed().slice_with(0).to_array();
-        let col1: [f32; 2] = tensor.view().transposed().slice_with(1).to_array();
+        let col0: [f32; 2] = tensor.view().transposed().slice(0).to_array();
+        let col1: [f32; 2] = tensor.view().transposed().slice(1).to_array();
         assert_eq!(col0, [1., 3.]);
         assert_eq!(col1, [2., 4.]);
     }
@@ -3857,30 +3675,14 @@ mod tests {
         let data = vec![1., 2., 3., 4.];
         let tensor = Tensor::from_data(&[2, 2], data);
 
-        let row = tensor.try_slice_dyn(0);
+        let row = tensor.try_slice(0);
         assert!(row.is_ok());
         assert_eq!(row.unwrap().data(), Some([1., 2.].as_slice()));
 
-        let row = tensor.try_slice_dyn(1);
+        let row = tensor.try_slice(1);
         assert!(row.is_ok());
 
-        let row = tensor.try_slice_dyn(2);
-        assert!(row.is_err());
-    }
-
-    #[test]
-    fn test_try_slice_with() {
-        let data = vec![1., 2., 3., 4.];
-        let tensor = Tensor::from_data(&[2, 2], data);
-
-        let row = tensor.try_slice_with(0);
-        assert!(row.is_ok());
-        assert_eq!(row.unwrap().data(), Some([1., 2.].as_slice()));
-
-        let row = tensor.try_slice_with(1);
-        assert!(row.is_ok());
-
-        let row = tensor.try_slice_with(2);
+        let row = tensor.try_slice(2);
         assert!(row.is_err());
     }
 
@@ -3897,24 +3699,7 @@ mod tests {
         let row = tensor.try_slice_mut(1);
         assert!(row.is_ok());
 
-        let row = tensor.try_slice_dyn(2);
-        assert!(row.is_err());
-    }
-
-    #[test]
-    fn test_try_slice_with_mut() {
-        let data = vec![1., 2., 3., 4.];
-        let mut tensor = Tensor::from_data(&[2, 2], data);
-
-        let mut row = tensor.try_slice_with_mut(0).unwrap();
-        row[[0]] += 1.;
-        row[[1]] += 1.;
-        assert_eq!(row.data(), Some([2., 3.].as_slice()));
-
-        let row = tensor.try_slice_with_mut(1);
-        assert!(row.is_ok());
-
-        let row = tensor.try_slice_with(2);
+        let row = tensor.try_slice(2);
         assert!(row.is_err());
     }
diff --git a/rten-tensor/src/type_num.rs b/rten-tensor/src/type_num.rs
index 9030279c..d0dbff48 100644
--- a/rten-tensor/src/type_num.rs
+++ b/rten-tensor/src/type_num.rs
@@ -1,8 +1,8 @@
 //! Traits and types for compile-time arithmetic.
 //!
 //! These types are used in various tensor methods, such as
-//! [`slice_with`](crate::TensorBase::slice_with), as part of computing the
-//! layout of the result at compile time.
+//! [`slice`](crate::TensorBase::slice), as part of computing the layout of the
+//! result at compile time.
 
 use std::ops::{Range, RangeFrom, RangeFull, RangeTo};
diff --git a/src/gemm.rs b/src/gemm.rs
index cb0be8cf..0c7e32db 100644
--- a/src/gemm.rs
+++ b/src/gemm.rs
@@ -722,7 +722,7 @@ fn gemv(
     for (k_block, a_block) in
         range_chunks(0..a_cols, k_block_size).zip(a_data.chunks(k_block_size))
     {
-        let b_block = b.slice_with((k_block, col_block.clone()));
+        let b_block = b.slice((k_block, col_block.clone()));
         kernel.gemv_kernel(out_chunk, a_block, b_block, alpha, effective_beta);
 
         // Reset `beta` so that subsequent updates for each column
@@ -815,7 +815,7 @@ fn gemm_impl(
     if let (1, GemmInputA::Unpacked(a), GemmInputB::Unpacked(b)) = (a.rows(), a, b) {
         gemv(
             kernel,
-            a.slice_with(0),
+            a.slice(0),
             b,
             output_mat.view_mut(),
             alpha,
diff --git a/src/gemm/kernels/simd_generic.rs b/src/gemm/kernels/simd_generic.rs
index d126c911..92546334 100644
--- a/src/gemm/kernels/simd_generic.rs
+++ b/src/gemm/kernels/simd_generic.rs
@@ -148,7 +148,7 @@ unsafe fn simd_gemv_transposed(
         simd_gemv_fallback(
             &mut out[last_col_tile.clone()],
             a,
-            b.slice_with((.., last_col_tile)),
+            b.slice((.., last_col_tile)),
             alpha,
             beta,
         );
diff --git a/src/ops/conv.rs b/src/ops/conv.rs
index a595e64a..247be292 100644
--- a/src/ops/conv.rs
+++ b/src/ops/conv.rs
@@ -46,10 +46,10 @@ where
     let mut n_init = 0;
 
     for n in 0..batch {
-        let mut out_item = output.slice_with_mut([n]);
+        let mut out_item = output.slice_mut([n]);
         let out_row_stride = out_item.stride(0);
 
-        let in_mat = input.slice_with([n]).reshaped([in_c, in_h * in_w]);
+        let in_mat = input.slice([n]).reshaped([in_c, in_h * in_w]);
 
         gemm.gemm_uninit_bias(
             out_item.data_mut().unwrap(),
@@ -237,12 +237,12 @@ where
             let out_chan_start = group * out_channels_per_group;
             let out_chans = out_chan_start..out_chan_start + out_channels_per_group;
 
-            let in_group = input.slice_with((.., in_chan_start..in_chan_end));
-            let mut out_group = output.slice_with_mut((.., out_chans.clone()));
+            let in_group = input.slice((.., in_chan_start..in_chan_end));
+            let mut out_group = output.slice_mut((.., out_chans.clone()));
 
             let kernel = kernel.to_contiguous_in(pool);
             let kernel_mat = kernel
-                .slice_with([out_chans.clone()])
+                .slice([out_chans.clone()])
                 .reshaped([out_channels_per_group, in_channels_per_group * k_h * k_w]);
 
             // Prepack kernel if we'll be able to reuse packed weights.
@@ -356,7 +356,7 @@ fn col2im(
     for out_c in 0..out_chans {
         // Initialize each output channel just before we accumulate into it.
-        let mut out_img = output.slice_with_mut([out_c]);
+        let mut out_img = output.slice_mut([out_c]);
         out_img.fill(MaybeUninit::new(bias.map(|b| b[[out_c]]).unwrap_or(0.)));
 
         // Safety: We just initialized all elements of `out_img`.
@@ -364,7 +364,7 @@ fn col2im(
         for k_y in 0..kernel_h {
             for k_x in 0..kernel_w {
-                let in_img = columns.slice_with([out_c, k_y, k_x]);
+                let in_img = columns.slice([out_c, k_y, k_x]);
                 let [img_h, img_w] = in_img.shape();
 
                 for y in 0..img_h {
@@ -536,7 +536,7 @@ pub fn conv_transpose(
     // The implementation here is the inverse of the im2col-based convolution.
     let mut n_init = 0;
     for n in 0..batch {
-        let input_mat = input.slice_with([n]).reshaped([in_c, in_h * in_w]);
+        let input_mat = input.slice([n]).reshaped([in_c, in_h * in_w]);
         let col2im_row_stride = col2im_mat.stride(0);
 
         gemm.gemm_uninit(
@@ -549,7 +549,7 @@ pub fn conv_transpose(
         // Safety: `gemm_uninit` initialized col2im_mat.
         let col2im_mat = unsafe { col2im_mat.view().assume_init() };
-        let mut out_img = output.slice_with_mut(n);
+        let mut out_img = output.slice_mut(n);
 
         col2im(
             &mut out_img,
@@ -1635,7 +1635,7 @@ mod tests {
         // With padding.
         run_bench(100, Some("col2im"), || {
             col2im(
-                &mut output.slice_with_mut((.., 2.., 2..)),
+                &mut output.slice_mut((.., 2.., 2..)),
                 &columns.view(),
                 [1, 1, 1, 1], // Padding
                 [stride_y, stride_x],
diff --git a/src/ops/conv/depthwise.rs b/src/ops/conv/depthwise.rs
index a12c64b0..6006a57d 100644
--- a/src/ops/conv/depthwise.rs
+++ b/src/ops/conv/depthwise.rs
@@ -83,16 +83,16 @@ fn conv_2d_depthwise_block(
     let [dilation_y, _dilation_x] = dilations;
 
     for c in chan_range.clone() {
-        let kernel_view = kernel.slice_with([c, 0]).weakly_checked_view();
+        let kernel_view = kernel.slice([c, 0]).weakly_checked_view();
 
         // For efficiency, use manual slicing in the inner loops to extract
         // input/output rows.
-        let mut out_chan = output.slice_with_mut([c - chan_range.start]);
+        let mut out_chan = output.slice_mut([c - chan_range.start]);
         let out_row_stride = out_chan.stride(0);
         let out_row_len = out_chan.size(1);
         let out_chan_data = out_chan.data_mut().unwrap();
 
-        let in_chan = input.slice_with([c]);
+        let in_chan = input.slice([c]);
         let in_row_stride = in_chan.stride(0);
         let in_row_len = in_chan.size(1);
         let in_chan_data = in_chan.data().unwrap();
@@ -204,8 +204,8 @@ where
     let n_init = AtomicUsize::new(0);
 
     for n in 0..batch {
-        let mut out_chans = output.slice_with_mut(n);
-        let input = input.slice_with(n);
+        let mut out_chans = output.slice_mut(n);
+        let input = input.slice(n);
 
         out_chans
             .axis_chunks_mut(0, channel_chunk_size)
diff --git a/src/ops/einsum.rs b/src/ops/einsum.rs
index ccdc3699..19c7cf26 100644
--- a/src/ops/einsum.rs
+++ b/src/ops/einsum.rs
@@ -866,16 +866,16 @@ mod tests {
             // Matrix-vector product
             Case {
                 equation: "ij,j->i",
-                inputs: vec![mat_a.view(), mat_b.slice_dyn((.., 0))],
-                expected: Ok(matmul(&pool, mat_a.view(), mat_b.slice_dyn((.., ..1)))
+                inputs: vec![mat_a.view(), mat_b.slice((.., 0))],
+                expected: Ok(matmul(&pool, mat_a.view(), mat_b.slice((.., ..1)))
                     .unwrap()
                     .into_shape([mat_a.size(0)].as_slice())),
             },
             // Vector-matrix product
             Case {
                 equation: "j,jk->k",
-                inputs: vec![mat_a.slice_dyn(0), mat_b.view()],
-                expected: Ok(matmul(&pool, mat_a.slice_dyn((..1, ..)), mat_b.view())
+                inputs: vec![mat_a.slice(0), mat_b.view()],
+                expected: Ok(matmul(&pool, mat_a.slice((..1, ..)), mat_b.view())
                     .unwrap()
                     .into_shape([mat_b.size(1)].as_slice())),
             },
@@ -895,11 +895,11 @@ mod tests {
             // are not present in all tensors.
             Case {
                 equation: "ij,j->",
-                inputs: vec![mat_a.view(), mat_b.slice_dyn((.., 0))],
+                inputs: vec![mat_a.view(), mat_b.slice((.., 0))],
                 expected: Ok(Tensor::from(
                     mat_a
                         .iter()
-                        .zip(mat_b.slice_dyn((.., 0)).broadcast(mat_a.shape()).iter())
+                        .zip(mat_b.slice((.., 0)).broadcast(mat_a.shape()).iter())
                         .map(|(x, y)| x * y)
                         .sum::<f32>(),
                 )),
diff --git a/src/ops/gather.rs b/src/ops/gather.rs
index ec96043b..51d1e564 100644
--- a/src/ops/gather.rs
+++ b/src/ops/gather.rs
@@ -40,7 +40,7 @@ pub fn gather(
         let mut slice_range = full_range(input.ndim());
         slice_range[axis] = SliceItem::Index(*index as isize);
         let slice = input
-            .try_slice_with(slice_range.as_slice())
+            .try_slice(slice_range.as_slice())
             .map_err(|_| INVALID_INDEX_ERR)?;
         slice.to_tensor_in(pool)
     };
@@ -64,9 +64,9 @@
             out_range[axis + i] = SliceItem::Index(index_val as isize);
         }
         let in_slice = input
-            .try_slice_with(in_range.as_slice())
+            .try_slice(in_range.as_slice())
             .map_err(|_| INVALID_INDEX_ERR)?;
-        let mut out_slice = output.slice_mut_dyn(out_range.as_slice());
+        let mut out_slice = output.slice_mut(out_range.as_slice());
         out_slice.copy_from(&in_slice);
     }
 
@@ -304,7 +304,7 @@ pub fn gather_nd(
     for (out_slice, idx) in out_slices.zip(idx_slices) {
         let slice_items = to_slice_items(idx);
         let in_slice = input
-            .try_slice_with(slice_items.as_slice())
+            .try_slice(slice_items.as_slice())
            .map_err(|_| OpError::InvalidValue("Invalid index"))?;
 
         for (out, x) in out_slice.iter_mut().zip(in_slice.iter()) {
diff --git a/src/ops/non_max_suppression.rs b/src/ops/non_max_suppression.rs
index 15cf6dd4..eefc032d 100644
--- a/src/ops/non_max_suppression.rs
+++ b/src/ops/non_max_suppression.rs
@@ -93,7 +93,7 @@ pub fn non_max_suppression(
     for n in 0..batch {
         for b in 0..n_boxes {
             let (max_score_cls, max_score) = scores
-                .slice_with((n, .., b))
+                .slice((n, .., b))
                 .iter()
                 .copied()
                 .enumerate()
@@ -104,7 +104,7 @@
                 continue;
             }
 
-            let [c0, c1, c2, c3] = boxes.slice_with((n, b)).to_array();
+            let [c0, c1, c2, c3] = boxes.slice((n, b)).to_array();
             let [top, left, bottom, right] = match box_order {
                 BoxOrder::TopLeftBottomRight => [c0, c1, c2, c3],
                 BoxOrder::CenterWidthHeight => {
@@ -172,7 +172,7 @@
     let mut selected_indices = NdTensor::zeros_in(pool, [selected.len(), 3]);
     for (i, nms_box) in selected.into_iter().enumerate() {
-        selected_indices.slice_with_mut(i).assign_array([
+        selected_indices.slice_mut(i).assign_array([
             nms_box.batch_index as i32,
             nms_box.class as i32,
             nms_box.box_index as i32,
@@ -255,7 +255,7 @@ mod tests {
                     [cx, cy, w, h]
                 }
             };
-            out_boxes.slice_with_mut((0, i)).assign_array(coords);
+            out_boxes.slice_mut((0, i)).assign_array(coords);
 
             out_scores[[0, nms_box.class, i]] = nms_box.score;
        }
@@ -309,10 +309,10 @@ mod tests {
 
         assert_eq!(selected.size(0), 2);
 
-        let [batch, class, box_idx] = selected.slice_with(0).to_array();
+        let [batch, class, box_idx] = selected.slice(0).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 0]);
 
-        let [batch, class, box_idx] = selected.slice_with(1).to_array();
+        let [batch, class, box_idx] = selected.slice(1).to_array();
         assert_eq!([batch, class, box_idx], [0, 1, 2]);
     }
 
@@ -371,10 +371,10 @@ mod tests {
        // returned.
         assert_eq!(selected.size(0), 3);
 
-        let [batch, class, box_idx] = selected.slice_with(0).to_array();
+        let [batch, class, box_idx] = selected.slice(0).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 0]);
 
-        let [batch, class, box_idx] = selected.slice_with(1).to_array();
+        let [batch, class, box_idx] = selected.slice(1).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 1]);
     }
 
@@ -400,10 +400,10 @@ mod tests {
         // be returned from each class.
         assert!(selected.size(0) == 2);
 
-        let [batch, class, box_idx] = selected.slice_with(0).to_array();
+        let [batch, class, box_idx] = selected.slice(0).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 0]);
 
-        let [batch, class, box_idx] = selected.slice_with(1).to_array();
+        let [batch, class, box_idx] = selected.slice(1).to_array();
         assert_eq!([batch, class, box_idx], [0, 1, 2]);
     }
 
@@ -428,7 +428,7 @@ mod tests {
         // Only the box with score exceeding `score_threshold` will be returned.
         assert!(selected.size(0) == 1);
 
-        let [batch, class, box_idx] = selected.slice_with(0).to_array();
+        let [batch, class, box_idx] = selected.slice(0).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 0]);
     }
diff --git a/src/ops/norm.rs b/src/ops/norm.rs
index b4ac0469..d32a1787 100644
--- a/src/ops/norm.rs
+++ b/src/ops/norm.rs
@@ -45,7 +45,7 @@ pub fn batch_norm_in_place(
             let scaled_std_dev_reciprocal = chan_scale / (chan_var + epsilon).sqrt();
 
             input
-                .slice_mut_dyn([n, c])
+                .slice_mut([n, c])
                 .apply(|el| (*el - chan_mean) * scaled_std_dev_reciprocal + chan_bias);
         }
     }
@@ -169,7 +169,7 @@ pub fn instance_normalization_in_place(
     for n in 0..batch {
         for c in 0..chans {
-            let mut slice = input.slice_mut_dyn([n, c]);
+            let mut slice = input.slice_mut([n, c]);
             let chan_scale = scale[[c]];
             let chan_bias = bias[[c]];
             let chan_mean = slice_sum(slice.data().unwrap()) / slice.len() as f32;
diff --git a/src/ops/pad.rs b/src/ops/pad.rs
index dcb990bc..26efe2cc 100644
--- a/src/ops/pad.rs
+++ b/src/ops/pad.rs
@@ -57,7 +57,7 @@ pub fn pad(
     let mut output = Tensor::full_in(pool, &out_shape, const_val);
     output
-        .slice_mut_dyn(non_pad_region.as_slice())
+        .slice_mut(non_pad_region.as_slice())
         .copy_from(&input);
     output
 }
diff --git a/src/ops/pooling.rs b/src/ops/pooling.rs
index 823554be..e864ae7d 100644
--- a/src/ops/pooling.rs
+++ b/src/ops/pooling.rs
@@ -303,9 +303,9 @@ pub fn global_average_pool(pool: &TensorPool, input: TensorView) -> Result
                     SliceItem::Index(idx as isize)
                 }
             }));
-            let slice = input.slice_with(inner_range.as_slice());
+            let slice = input.slice(inner_range.as_slice());
             let reduced = reducer.reduce(slice.iter().copied());
             reduced_data.push(reduced);
         }
diff --git a/src/ops/resize.rs b/src/ops/resize.rs
index df900fe8..13e1ccf4 100644
--- a/src/ops/resize.rs
+++ b/src/ops/resize.rs
@@ -300,9 +300,9 @@ pub fn resize(
     let n_init = AtomicUsize::new(0);
 
     for n in 0..batch {
-        let in_image = input.slice_with([n]);
+        let in_image = input.slice([n]);
         let mut out_batch = output.nd_view_mut::<4>();
-        let mut out_image = out_batch.slice_with_mut([n]);
+        let mut out_image = out_batch.slice_mut([n]);
 
         out_image
             .axis_chunks_mut(0, CHAN_GROUP_SIZE)
diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs
index 178e5920..5f6ce525 100644
--- a/src/ops/rnn.rs
+++ b/src/ops/rnn.rs
@@ -180,7 +180,7 @@ pub fn gru(
     for dir in 0..num_directions {
         let prepack = seq_len >= PREPACK_MIN_SEQ_LEN;
 
-        let input_weights = weights.slice_with(dir).transposed();
+        let input_weights = weights.slice(dir).transposed();
         let packed_input_weights =
             prepack.then(|| gemm.prepack_b_in(pool, input_weights).auto_return(pool));
         let input_weights = packed_input_weights
@@ -188,7 +188,7 @@
             .map(|packed| GemmInputB::Packed(packed))
            .unwrap_or(GemmInputB::Unpacked(input_weights));
 
-        let hidden_weights = recurrent_weights.slice_with(dir).transposed();
+        let hidden_weights = recurrent_weights.slice(dir).transposed();
         let packed_hidden_weights =
             prepack.then(|| gemm.prepack_b_in(pool, hidden_weights).auto_return(pool));
         let hidden_weights = packed_hidden_weights
@@ -198,14 +198,14 @@
 
         let input_bias = bias
             .as_ref()
-            .map(|b| b.slice_with((dir, ..(n_gates * hidden_size))));
+            .map(|b| b.slice((dir, ..(n_gates * hidden_size))));
         let hidden_bias = bias
             .as_ref()
-            .map(|b| b.slice_with((dir, (n_gates * hidden_size)..)));
+            .map(|b| b.slice((dir, (n_gates * hidden_size)..)));
 
         for seq in sequence_for_dir(direction, dir, seq_len) {
-            let in_item = input.slice_with([seq]);
-            let hidden_item = hidden.slice_with([dir]);
+            let in_item = input.slice([seq]);
+            let hidden_item = hidden.slice([dir]);
 
             // From the ONNX spec, the intermediate values are computed as:
             //
@@ -261,11 +261,11 @@
             }
 
             // Combine inputs for reset and update gates and apply activation.
-            let mut update_reset_gates = gates.slice_with_mut((
+            let mut update_reset_gates = gates.slice_mut((
                 ..,
                 gate_range(UPDATE_GATE).start..gate_range(RESET_GATE).end,
             ));
-            let hidden_scratch_reset_update_gates = hidden_scratch.slice_with((
+            let hidden_scratch_reset_update_gates = hidden_scratch.slice((
                 ..,
                 gate_range(UPDATE_GATE).start..gate_range(RESET_GATE).end,
             ));
@@ -284,22 +284,21 @@
             // as `gates`.
             let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool);
             let update_reset_gates = update_reset_gates.nd_view::<2>();
-            let update_gate = update_reset_gates.slice_with((.., gate_range(UPDATE_GATE)));
-            let reset_gate = update_reset_gates.slice_with((.., gate_range(RESET_GATE)));
+            let update_gate = update_reset_gates.slice((.., gate_range(UPDATE_GATE)));
+            let reset_gate = update_reset_gates.slice((.., gate_range(RESET_GATE)));
 
             // Combine inputs for hidden gate and apply activation.
-            let mut hidden_gate_recurrent =
-                hidden_scratch.slice_with_mut((.., gate_range(HIDDEN_GATE)));
+            let mut hidden_gate_recurrent = hidden_scratch.slice_mut((.., gate_range(HIDDEN_GATE)));
             mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn());
 
-            let mut hidden_gate = gates.slice_with_mut((.., gate_range(HIDDEN_GATE)));
+            let mut hidden_gate = gates.slice_mut((.., gate_range(HIDDEN_GATE)));
             add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn());
 
             // See note above about `sigmoid_in_place`.
             let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool);
 
             // Compute next hidden state
-            let mut hidden_item = hidden.slice_with_mut([dir]);
+            let mut hidden_item = hidden.slice_mut([dir]);
 
             for (hidden, update, hidden_gate) in zip3(
                 hidden_item.iter_mut(),
@@ -309,9 +308,7 @@
                 *hidden = (1. - update) * hidden_gate + update * (*hidden);
             }
 
-            hidden_seq
-                .slice_with_mut([seq, dir])
-                .copy_from(&hidden_item);
+            hidden_seq.slice_mut([seq, dir]).copy_from(&hidden_item);
         }
     }
 
@@ -442,7 +439,7 @@ pub fn lstm(
     for dir in 0..num_directions {
         let prepack = seq_len >= PREPACK_MIN_SEQ_LEN;
 
-        let input_weights = weights.slice_with(dir).transposed();
+        let input_weights = weights.slice(dir).transposed();
         let packed_input_weights =
             prepack.then(|| gemm.prepack_b_in(pool, input_weights).auto_return(pool));
         let input_weights = packed_input_weights
@@ -450,7 +447,7 @@
             .map(|packed| GemmInputB::Packed(packed))
             .unwrap_or(GemmInputB::Unpacked(input_weights));
 
-        let hidden_weights = recurrent_weights.slice_with(dir).transposed();
+        let hidden_weights = recurrent_weights.slice(dir).transposed();
         let packed_hidden_weights =
             prepack.then(|| gemm.prepack_b_in(pool, hidden_weights).auto_return(pool));
         let hidden_weights = packed_hidden_weights
@@ -460,10 +457,10 @@
 
         let input_bias = bias
             .as_ref()
-            .map(|b| b.slice_with((dir, ..(n_gates * hidden_size))));
+            .map(|b| b.slice((dir, ..(n_gates * hidden_size))));
         let hidden_bias = bias
            .as_ref()
-            .map(|b| b.slice_with((dir, (n_gates * hidden_size)..)));
+            .map(|b| b.slice((dir, (n_gates * hidden_size)..)));
 
         for seq in sequence_for_dir(direction, dir, seq_len) {
             // From the ONNX spec, the intermediate values are computed as:
@@ -485,8 +482,8 @@
             //   supported.
             // - `f`, `g` and `h` are activations. `f`=sigmoid, `g` and `h`
             //   are tanh.
-            let in_item = input.slice_with([seq]);
-            let hidden_item = hidden.slice_with([dir]);
+            let in_item = input.slice([seq]);
+            let hidden_item = hidden.slice([dir]);
 
             // Update input, output, forget and cell gates.
             let gates_row_stride = gates.stride(gates.ndim() - 2);
@@ -516,22 +513,22 @@
 
             // Copy gates to work around `tanh_in_place` and `sigmoid_in_place`
             // being slow for non-contiguous inputs. See notes in GRU op.
-            let iof_gates = gates.slice_with((
+            let iof_gates = gates.slice((
                 ..,
                 gate_range(INPUT_GATE).start..gate_range(FORGET_GATE).end,
             ));
             let iof_gates = sigmoid(pool, iof_gates.as_dyn()).auto_return(pool);
             let iof_gates = iof_gates.nd_view::<2>();
-            let input_gate = iof_gates.slice_with((.., gate_range(INPUT_GATE)));
-            let out_gate = iof_gates.slice_with((.., gate_range(OUTPUT_GATE)));
-            let forget_gate = iof_gates.slice_with((.., gate_range(FORGET_GATE)));
+            let input_gate = iof_gates.slice((.., gate_range(INPUT_GATE)));
+            let out_gate = iof_gates.slice((.., gate_range(OUTPUT_GATE)));
+            let forget_gate = iof_gates.slice((.., gate_range(FORGET_GATE)));
 
-            let cell_gate = gates.slice_with((.., gate_range(CELL_GATE)));
+            let cell_gate = gates.slice((.., gate_range(CELL_GATE)));
             let cell_gate = tanh(pool, cell_gate.as_dyn()).auto_return(pool);
 
             // Update cell and hidden state
-            let mut cell_item = cell.slice_with_mut([dir]);
+            let mut cell_item = cell.slice_mut([dir]);
 
             for (cell, forget_gate, input_gate, cell_gate) in zip4(
                 cell_item.iter_mut(),
@@ -542,16 +539,14 @@
                 *cell = forget_gate * *cell + input_gate * cell_gate;
             }
 
-            let mut hidden_item = hidden.slice_with_mut([dir]);
+            let mut hidden_item = hidden.slice_mut([dir]);
 
             for (hidden, out_gate, cell) in
                 zip3(hidden_item.iter_mut(), out_gate.iter(), cell_item.iter())
             {
                 *hidden = out_gate * cell.tanh()
             }
 
-            hidden_seq
-                .slice_with_mut([seq, dir])
-                .copy_from(&hidden_item);
+            hidden_seq.slice_mut([seq, dir]).copy_from(&hidden_item);
         }
     }
 
@@ -771,18 +766,18 @@ mod tests {
             // The last hidden state should match the end of the hidden sequence
            // for the forwards direction, and the start of the hidden sequence
            // for the reverse direction.
-            let hidden_seq_fwd = hidden_seq.slice_with((
+            let hidden_seq_fwd = hidden_seq.slice((
                 -1, // seq
                 0,  // direction
             ));
-            let last_hidden_fwd = last_hidden.slice_with(0);
+            let last_hidden_fwd = last_hidden.slice(0);
             assert_eq!(hidden_seq_fwd, last_hidden_fwd);
 
-            let hidden_seq_rev = hidden_seq.slice_with((
+            let hidden_seq_rev = hidden_seq.slice((
                 0, // seq
                 1, // direction
             ));
-            let last_hidden_rev = last_hidden.slice_with(1);
+            let last_hidden_rev = last_hidden.slice(1);
             assert_eq!(hidden_seq_rev, last_hidden_rev);
         }
     }
diff --git a/src/ops/split.rs b/src/ops/split.rs
index c6790e87..bac24b9b 100644
--- a/src/ops/split.rs
+++ b/src/ops/split.rs
@@ -40,7 +40,7 @@ pub fn split(
             split_start += split_size;
 
-            input.slice_with(slice_range.as_slice()).to_tensor_in(pool)
+            input.slice(slice_range.as_slice()).to_tensor_in(pool)
         })
         .collect();
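
The net effect of this patch on callers: the old `slice` / `slice_dyn` / `slice_with` family (and their `_mut` and `try_` variants) collapses into a single `slice` / `slice_mut` / `try_slice` / `try_slice_mut` set, with the rank of the returned view computed at compile time by the `SliceWith` machinery referenced in `type_num.rs`. The snippet below is a minimal usage sketch assembled from the doctests and unit tests in this diff; it is illustrative only and not part of the patch.

```rust
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

fn main() {
    // Static-rank tensor sliced with a statically-sized range: the result
    // rank is inferred at compile time (a one-dimensional view here).
    let x = NdTensor::from([[1, 2], [3, 4]]);
    let col = x.slice((.., 1));
    assert_eq!(col.to_vec(), [2, 4]);

    // Dynamic-rank tensor: the sliced view also has a dynamic rank.
    let y = Tensor::from([[1, 2], [3, 4]]);
    let row = y.slice(1);
    assert_eq!(row.to_vec(), [3, 4]);

    // Mutable and fallible variants follow the same naming scheme.
    let mut z = NdTensor::from([[1, 2], [3, 4]]);
    z.slice_mut(1).fill(0);
    assert_eq!(z.to_vec(), [1, 2, 0, 0]);
    assert!(y.try_slice(2).is_err()); // Out-of-bounds index returns an error.
}
```

Call sites that previously had to choose between `slice::<N, _>(...)` and `slice_dyn(...)` by hand, as in the examples touched above, now compile with the single `slice(...)` call.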