diff --git a/rten-imageproc/src/drawing.rs b/rten-imageproc/src/drawing.rs
index 47e34b3a..d3755c1b 100644
--- a/rten-imageproc/src/drawing.rs
+++ b/rten-imageproc/src/drawing.rs
@@ -465,7 +465,7 @@ impl<'a, T: Copy + Default> Painter<'a, T> {
     pub fn draw_polygon(&mut self, points: &[Point]) {
         for i in 0..3 {
             draw_polygon(
-                self.surface.slice_mut([i]),
+                self.surface.slice_with_mut([i]),
                 points,
                 self.state.stroke[i],
                 self.state.stroke_width,
@@ -600,9 +600,9 @@ mod tests {
         let expected_g = expected_r.map(|&x| if x == r { g } else { 0 });
         let expected_b = expected_r.map(|&x| if x == r { b } else { 0 });
 
-        compare_images(img.slice([0]), expected_r.view());
-        compare_images(img.slice([1]), expected_g.view());
-        compare_images(img.slice([2]), expected_b.view());
+        compare_images(img.slice_with([0]), expected_r.view());
+        compare_images(img.slice_with([1]), expected_g.view());
+        compare_images(img.slice_with([2]), expected_b.view());
     }
 
     #[test]
diff --git a/rten-imageproc/src/normalize.rs b/rten-imageproc/src/normalize.rs
index ce64a644..624c93b6 100644
--- a/rten-imageproc/src/normalize.rs
+++ b/rten-imageproc/src/normalize.rs
@@ -32,7 +32,7 @@ pub fn normalize_image(
 
     for chan in 0..n_chans {
         let inv_std_dev = 1. / std_dev[chan];
-        img.slice_mut::<2, _>(chan)
+        img.slice_with_mut(chan)
             .apply(|x| (x - mean[chan]) * inv_std_dev);
     }
 }
diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs
index 4ffc1ed2..1196b0f0 100644
--- a/rten-tensor/src/tensor.rs
+++ b/rten-tensor/src/tensor.rs
@@ -2434,8 +2434,8 @@ mod tests {
         let mut transposed = tensor.view_mut();
         transposed.permute([1, 0]);
 
-        transposed.slice_mut(0).assign_array([1, 2]);
-        transposed.slice_mut(1).assign_array([3, 4]);
+        transposed.slice_with_mut(0).assign_array([1, 2]);
+        transposed.slice_with_mut(1).assign_array([3, 4]);
 
         assert_eq!(tensor.iter().copied().collect::<Vec<_>>(), [1, 3, 2, 4]);
     }
diff --git a/src/ops/conv.rs b/src/ops/conv.rs
index 90de21e9..51eaf82c 100644
--- a/src/ops/conv.rs
+++ b/src/ops/conv.rs
@@ -32,7 +32,7 @@ where
 {
     let [batch, _, in_h, in_w]: [usize; 4] = input.shape();
     let [out_c, in_c, _, _]: [usize; 4] = kernel.shape();
-    let mut output = Tensor::uninit_in(pool, &[batch, out_c, in_h * in_w]);
+    let mut output = NdTensor::uninit_in(pool, [batch, out_c, in_h * in_w]);
 
     // Get input and kernel as contiguous tensors so we can create reshaped
     // views.
@@ -47,7 +47,7 @@ where
     let mut n_init = 0;
 
     for n in 0..batch {
-        let mut out_item = output.slice_mut::<2, _>([n]);
+        let mut out_item = output.slice_with_mut([n]);
         let out_row_stride = out_item.stride(0);
 
         let in_mat = input.slice_with([n]).reshaped([in_c, in_h * in_w]);
@@ -63,11 +63,11 @@ where
         n_init += out_item.len();
     }
 
-    output.reshape(&[batch, out_c, in_h, in_w]);
+    let output = output.into_shape([batch, out_c, in_h, in_w]);
 
     // Safety: We used `gemm_uninit_bias` to initialize all elements.
     assert!(n_init == output.len());
-    unsafe { output.assume_init() }
+    unsafe { output.assume_init().into_dyn() }
 }
 
 /// Perform a convolution of `input` with `kernel`.
@@ -355,7 +355,7 @@ fn col2im(
 
     for out_c in 0..out_chans {
         // Initialize each output channel just before we accumulate into it.
-        let mut out_img = output.slice_mut([out_c]);
+        let mut out_img = output.slice_with_mut([out_c]);
         out_img.fill(MaybeUninit::new(bias.map(|b| b[[out_c]]).unwrap_or(0.)));
 
         // Safety: We just initialized all elements of `out_img`.
@@ -363,7 +363,7 @@ fn col2im(
 
         for k_y in 0..kernel_h {
             for k_x in 0..kernel_w {
-                let in_img = columns.slice([out_c, k_y, k_x]);
+                let in_img = columns.slice_with([out_c, k_y, k_x]);
                 let [img_h, img_w] = in_img.shape();
 
                 for y in 0..img_h {
@@ -521,7 +521,7 @@ pub fn conv_transpose(
     let [out_h, out_w] = out_shape;
     let [pad_top, pad_left, pad_bottom, pad_right] = fixed_padding;
 
-    let mut output = Tensor::uninit_in(pool, [batch, out_c, out_h, out_w].as_slice());
+    let mut output = NdTensor::uninit_in(pool, [batch, out_c, out_h, out_w]);
 
     // Ensure input and kernel are contiguous to support reshaping.
     let input = input.to_contiguous_in(pool).auto_return(pool);
@@ -548,7 +548,7 @@ pub fn conv_transpose(
             // Safety: `gemm_uninit` initialized col2im_mat.
             let col2im_mat = unsafe { col2im_mat.view().assume_init() };
 
-            let mut out_img = output.slice_mut(n);
+            let mut out_img = output.slice_with_mut(n);
 
             col2im(
                 &mut out_img,
@@ -562,7 +562,7 @@ pub fn conv_transpose(
     assert!(n_init == output.len());
     let output = unsafe { output.assume_init() };
 
-    Ok(output)
+    Ok(output.into_dyn())
 }
 
 #[derive(Debug)]
@@ -1634,7 +1634,7 @@ mod tests {
         // With padding.
         run_bench(100, Some("col2im"), || {
             col2im(
-                &mut output.slice_mut((.., 2.., 2..)),
+                &mut output.slice_with_mut((.., 2.., 2..)),
                 &columns.view(),
                 [1, 1, 1, 1], // Padding
                 [stride_y, stride_x],
diff --git a/src/ops/conv/depthwise.rs b/src/ops/conv/depthwise.rs
index 8ecedb3e..a12c64b0 100644
--- a/src/ops/conv/depthwise.rs
+++ b/src/ops/conv/depthwise.rs
@@ -83,11 +83,11 @@ fn conv_2d_depthwise_block(
     let [dilation_y, _dilation_x] = dilations;
 
     for c in chan_range.clone() {
-        let kernel_view = kernel.slice([c, 0]).weakly_checked_view();
+        let kernel_view = kernel.slice_with([c, 0]).weakly_checked_view();
 
         // For efficiency, use manual slicing in the inner loops to extract
         // input/output rows.
-        let mut out_chan = output.slice_mut::<2, _>([c - chan_range.start]);
+        let mut out_chan = output.slice_with_mut([c - chan_range.start]);
         let out_row_stride = out_chan.stride(0);
         let out_row_len = out_chan.size(1);
         let out_chan_data = out_chan.data_mut().unwrap();
diff --git a/src/ops/non_max_suppression.rs b/src/ops/non_max_suppression.rs
index 40c7d249..15cf6dd4 100644
--- a/src/ops/non_max_suppression.rs
+++ b/src/ops/non_max_suppression.rs
@@ -104,7 +104,7 @@ pub fn non_max_suppression(
                 continue;
             }
 
-            let [c0, c1, c2, c3] = boxes.slice((n, b)).to_array();
+            let [c0, c1, c2, c3] = boxes.slice_with((n, b)).to_array();
             let [top, left, bottom, right] = match box_order {
                 BoxOrder::TopLeftBottomRight => [c0, c1, c2, c3],
                 BoxOrder::CenterWidthHeight => {
@@ -172,7 +172,7 @@ pub fn non_max_suppression(
 
     let mut selected_indices = NdTensor::zeros_in(pool, [selected.len(), 3]);
     for (i, nms_box) in selected.into_iter().enumerate() {
-        selected_indices.slice_mut(i).assign_array([
+        selected_indices.slice_with_mut(i).assign_array([
             nms_box.batch_index as i32,
             nms_box.class as i32,
             nms_box.box_index as i32,
@@ -255,7 +255,7 @@ mod tests {
                     [cx, cy, w, h]
                 }
             };
-            out_boxes.slice_mut((0, i)).assign_array(coords);
+            out_boxes.slice_with_mut((0, i)).assign_array(coords);
             out_scores[[0, nms_box.class, i]] = nms_box.score;
         }
 
@@ -309,10 +309,10 @@ mod tests {
 
         assert_eq!(selected.size(0), 2);
 
-        let [batch, class, box_idx] = selected.slice(0).to_array();
+        let [batch, class, box_idx] = selected.slice_with(0).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 0]);
 
-        let [batch, class, box_idx] = selected.slice(1).to_array();
+        let [batch, class, box_idx] = selected.slice_with(1).to_array();
         assert_eq!([batch, class, box_idx], [0, 1, 2]);
     }
 
@@ -371,10 +371,10 @@ mod tests {
         // returned.
         assert_eq!(selected.size(0), 3);
 
-        let [batch, class, box_idx] = selected.slice(0).to_array();
+        let [batch, class, box_idx] = selected.slice_with(0).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 0]);
 
-        let [batch, class, box_idx] = selected.slice(1).to_array();
+        let [batch, class, box_idx] = selected.slice_with(1).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 1]);
     }
 
@@ -400,10 +400,10 @@ mod tests {
         // be returned from each class.
         assert!(selected.size(0) == 2);
 
-        let [batch, class, box_idx] = selected.slice(0).to_array();
+        let [batch, class, box_idx] = selected.slice_with(0).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 0]);
 
-        let [batch, class, box_idx] = selected.slice(1).to_array();
+        let [batch, class, box_idx] = selected.slice_with(1).to_array();
         assert_eq!([batch, class, box_idx], [0, 1, 2]);
     }
 
@@ -428,7 +428,7 @@ mod tests {
 
        // Only the box with score exceeding `score_threshold` will be returned.
         assert!(selected.size(0) == 1);
 
-        let [batch, class, box_idx] = selected.slice(0).to_array();
+        let [batch, class, box_idx] = selected.slice_with(0).to_array();
         assert_eq!([batch, class, box_idx], [0, 0, 0]);
     }
diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs
index d9371b7c..ffbe66ad 100644
--- a/src/ops/rnn.rs
+++ b/src/ops/rnn.rs
@@ -595,7 +595,7 @@ mod tests {
     use rten_tensor::prelude::*;
     use rten_tensor::rng::XorShiftRng;
     use rten_tensor::test_util::expect_equal;
-    use rten_tensor::Tensor;
+    use rten_tensor::{NdTensor, Tensor};
     use serde_json::Value;
 
     use crate::ops::tests::new_pool;
@@ -695,45 +695,46 @@ mod tests {
             Op::Lstm => 4,
         };
 
-        let input = Tensor::<f32>::rand(&[seq_len, batch, features], &mut rng).map(|x| x - 0.5);
-        let weights = Tensor::<f32>::rand(
-            &[dir.num_directions(), num_gates * hidden_size, features],
+        let input =
+            NdTensor::<f32>::rand([seq_len, batch, features], &mut rng).map(|x| x - 0.5);
+        let weights = NdTensor::<f32>::rand(
+            [dir.num_directions(), num_gates * hidden_size, features],
             &mut rng,
         )
         .map(|x| x - 0.5);
-        let recurrent_weights = Tensor::<f32>::rand(
-            &[dir.num_directions(), num_gates * hidden_size, hidden_size],
+        let recurrent_weights = NdTensor::<f32>::rand(
+            [dir.num_directions(), num_gates * hidden_size, hidden_size],
             &mut rng,
         )
         .map(|x| x - 0.5);
-        let bias = Tensor::rand(
-            &[dir.num_directions(), 2 * num_gates * hidden_size],
+        let bias = NdTensor::rand(
+            [dir.num_directions(), 2 * num_gates * hidden_size],
             &mut rng,
         );
         let initial_hidden =
-            Tensor::rand(&[dir.num_directions(), batch, hidden_size], &mut rng);
-        let initial_cell = Tensor::rand(&[dir.num_directions(), batch, hidden_size], &mut rng);
+            NdTensor::rand([dir.num_directions(), batch, hidden_size], &mut rng);
+        let initial_cell = NdTensor::rand([dir.num_directions(), batch, hidden_size], &mut rng);
 
         let result = match case.op {
             Op::Lstm => lstm(
                 &pool,
                 dir,
-                input.view(),
-                weights.view(),
-                recurrent_weights.view(),
-                case.with_bias.then_some(bias.view()),
-                case.with_hidden_init.then_some(initial_hidden.view()),
-                case.with_initial_cell.then_some(initial_cell.view()),
+                input.as_dyn(),
+                weights.as_dyn(),
+                recurrent_weights.as_dyn(),
+                case.with_bias.then_some(bias.as_dyn()),
+                case.with_hidden_init.then_some(initial_hidden.as_dyn()),
+                case.with_initial_cell.then_some(initial_cell.as_dyn()),
             )
             .expect("lstm op failed"),
             Op::Gru => gru(
                 &pool,
                 dir,
-                input.view(),
-                weights.view(),
-                recurrent_weights.view(),
-                case.with_bias.then_some(bias.view()),
-                case.with_hidden_init.then_some(initial_hidden.view()),
+                input.as_dyn(),
+                weights.as_dyn(),
+                recurrent_weights.as_dyn(),
+                case.with_bias.then_some(bias.as_dyn()),
+                case.with_hidden_init.then_some(initial_hidden.as_dyn()),
                 true, /* linear_before_reset */
             )
             .expect("gru op failed"),
@@ -770,18 +771,18 @@ mod tests {
         // The last hidden state should match the end of the hidden sequence
        // for the forwards direction, and the start of the hidden sequence
        // for the reverse direction.
-        let hidden_seq_fwd = hidden_seq.slice::<2, _>((
+        let hidden_seq_fwd = hidden_seq.slice_with((
             -1, // seq
             0,  // direction
         ));
-        let last_hidden_fwd = last_hidden.slice::<2, _>(0);
+        let last_hidden_fwd = last_hidden.slice_with(0);
         assert_eq!(hidden_seq_fwd, last_hidden_fwd);
 
-        let hidden_seq_rev = hidden_seq.slice::<2, _>((
+        let hidden_seq_rev = hidden_seq.slice_with((
             0, // seq
             1, // direction
         ));
-        let last_hidden_rev = last_hidden.slice::<2, _>(1);
+        let last_hidden_rev = last_hidden.slice_with(1);
         assert_eq!(hidden_seq_rev, last_hidden_rev);
     }
 }
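Note on the slicing API change (reviewer commentary, not part of the patch): `slice::<M, _>` and `slice_mut::<M, _>` take the rank `M` of the resulting view as an explicit const generic, whereas `slice_with`/`slice_with_mut` infer it from the index argument, so call sites no longer spell out the output rank. A minimal sketch of the difference, assuming a transitional rten-tensor build where both method families coexist; the tensor shape and variable names below are invented for illustration:

    use rten_tensor::prelude::*;
    use rten_tensor::NdTensor;

    fn main() {
        // Hypothetical rank-3 "image" tensor: 3 channels of 4x4 pixels.
        let img = NdTensor::<f32, 3>::zeros([3, 4, 4]);

        // Old style: the result rank (2) is written out as a const generic.
        let chan_explicit = img.slice::<2, _>([0]);

        // New style adopted by this patch: one index applied to a rank-3
        // tensor is inferred to yield a rank-2 view.
        let chan_inferred = img.slice_with([0]);

        assert_eq!(chan_explicit, chan_inferred);
    }

The `Tensor` → `NdTensor` switches in conv.rs and rnn.rs follow the same theme: intermediate results are built as static-rank tensors and converted with `into_dyn()`/`as_dyn()` only at the boundaries that still expect dynamic-rank `Tensor` values.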