Skip to content

Commit

Permalink
Merge pull request #362 from robertknight/use-slice-with
Browse files Browse the repository at this point in the history
Replace remaining uses of `slice` with `slice_with`
  • Loading branch information
robertknight authored Sep 17, 2024
2 parents 94d5196 + 1014bca commit a57f594
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 54 deletions.
8 changes: 4 additions & 4 deletions rten-imageproc/src/drawing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ impl<'a, T: Copy + Default> Painter<'a, T> {
pub fn draw_polygon(&mut self, points: &[Point]) {
for i in 0..3 {
draw_polygon(
self.surface.slice_mut([i]),
self.surface.slice_with_mut([i]),
points,
self.state.stroke[i],
self.state.stroke_width,
Expand Down Expand Up @@ -600,9 +600,9 @@ mod tests {
let expected_g = expected_r.map(|&x| if x == r { g } else { 0 });
let expected_b = expected_r.map(|&x| if x == r { b } else { 0 });

compare_images(img.slice([0]), expected_r.view());
compare_images(img.slice([1]), expected_g.view());
compare_images(img.slice([2]), expected_b.view());
compare_images(img.slice_with([0]), expected_r.view());
compare_images(img.slice_with([1]), expected_g.view());
compare_images(img.slice_with([2]), expected_b.view());
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion rten-imageproc/src/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub fn normalize_image<const C: usize>(

for chan in 0..n_chans {
let inv_std_dev = 1. / std_dev[chan];
img.slice_mut::<2, _>(chan)
img.slice_with_mut(chan)
.apply(|x| (x - mean[chan]) * inv_std_dev);
}
}
4 changes: 2 additions & 2 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2434,8 +2434,8 @@ mod tests {
let mut transposed = tensor.view_mut();

transposed.permute([1, 0]);
transposed.slice_mut(0).assign_array([1, 2]);
transposed.slice_mut(1).assign_array([3, 4]);
transposed.slice_with_mut(0).assign_array([1, 2]);
transposed.slice_with_mut(1).assign_array([3, 4]);

assert_eq!(tensor.iter().copied().collect::<Vec<_>>(), [1, 3, 2, 4]);
}
Expand Down
20 changes: 10 additions & 10 deletions src/ops/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ where
{
let [batch, _, in_h, in_w]: [usize; 4] = input.shape();
let [out_c, in_c, _, _]: [usize; 4] = kernel.shape();
let mut output = Tensor::uninit_in(pool, &[batch, out_c, in_h * in_w]);
let mut output = NdTensor::uninit_in(pool, [batch, out_c, in_h * in_w]);

// Get input and kernel as contiguous tensors so we can create reshaped
// views.
Expand All @@ -47,7 +47,7 @@ where
let mut n_init = 0;

for n in 0..batch {
let mut out_item = output.slice_mut::<2, _>([n]);
let mut out_item = output.slice_with_mut([n]);
let out_row_stride = out_item.stride(0);

let in_mat = input.slice_with([n]).reshaped([in_c, in_h * in_w]);
Expand All @@ -63,11 +63,11 @@ where
n_init += out_item.len();
}

output.reshape(&[batch, out_c, in_h, in_w]);
let output = output.into_shape([batch, out_c, in_h, in_w]);

// Safety: We used `gemm_uninit_bias` to initialize all elements.
assert!(n_init == output.len());
unsafe { output.assume_init() }
unsafe { output.assume_init().into_dyn() }
}

/// Perform a convolution of `input` with `kernel`.
Expand Down Expand Up @@ -355,15 +355,15 @@ fn col2im(

for out_c in 0..out_chans {
// Initialize each output channel just before we accumulate into it.
let mut out_img = output.slice_mut([out_c]);
let mut out_img = output.slice_with_mut([out_c]);
out_img.fill(MaybeUninit::new(bias.map(|b| b[[out_c]]).unwrap_or(0.)));

// Safety: We just initialized all elements of `out_img`.
let mut out_img = unsafe { out_img.assume_init() };

for k_y in 0..kernel_h {
for k_x in 0..kernel_w {
let in_img = columns.slice([out_c, k_y, k_x]);
let in_img = columns.slice_with([out_c, k_y, k_x]);
let [img_h, img_w] = in_img.shape();

for y in 0..img_h {
Expand Down Expand Up @@ -521,7 +521,7 @@ pub fn conv_transpose(
let [out_h, out_w] = out_shape;
let [pad_top, pad_left, pad_bottom, pad_right] = fixed_padding;

let mut output = Tensor::uninit_in(pool, [batch, out_c, out_h, out_w].as_slice());
let mut output = NdTensor::uninit_in(pool, [batch, out_c, out_h, out_w]);

// Ensure input and kernel are contiguous to support reshaping.
let input = input.to_contiguous_in(pool).auto_return(pool);
Expand All @@ -548,7 +548,7 @@ pub fn conv_transpose(

// Safety: `gemm_uninit` initialized col2im_mat.
let col2im_mat = unsafe { col2im_mat.view().assume_init() };
let mut out_img = output.slice_mut(n);
let mut out_img = output.slice_with_mut(n);

col2im(
&mut out_img,
Expand All @@ -562,7 +562,7 @@ pub fn conv_transpose(

assert!(n_init == output.len());
let output = unsafe { output.assume_init() };
Ok(output)
Ok(output.into_dyn())
}

#[derive(Debug)]
Expand Down Expand Up @@ -1634,7 +1634,7 @@ mod tests {
// With padding.
run_bench(100, Some("col2im"), || {
col2im(
&mut output.slice_mut((.., 2.., 2..)),
&mut output.slice_with_mut((.., 2.., 2..)),
&columns.view(),
[1, 1, 1, 1], // Padding
[stride_y, stride_x],
Expand Down
4 changes: 2 additions & 2 deletions src/ops/conv/depthwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ fn conv_2d_depthwise_block<X, W, Y>(
let [dilation_y, _dilation_x] = dilations;

for c in chan_range.clone() {
let kernel_view = kernel.slice([c, 0]).weakly_checked_view();
let kernel_view = kernel.slice_with([c, 0]).weakly_checked_view();

// For efficiency, use manual slicing in the inner loops to extract
// input/output rows.
let mut out_chan = output.slice_mut::<2, _>([c - chan_range.start]);
let mut out_chan = output.slice_with_mut([c - chan_range.start]);
let out_row_stride = out_chan.stride(0);
let out_row_len = out_chan.size(1);
let out_chan_data = out_chan.data_mut().unwrap();
Expand Down
20 changes: 10 additions & 10 deletions src/ops/non_max_suppression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub fn non_max_suppression(
continue;
}

let [c0, c1, c2, c3] = boxes.slice((n, b)).to_array();
let [c0, c1, c2, c3] = boxes.slice_with((n, b)).to_array();
let [top, left, bottom, right] = match box_order {
BoxOrder::TopLeftBottomRight => [c0, c1, c2, c3],
BoxOrder::CenterWidthHeight => {
Expand Down Expand Up @@ -172,7 +172,7 @@ pub fn non_max_suppression(

let mut selected_indices = NdTensor::zeros_in(pool, [selected.len(), 3]);
for (i, nms_box) in selected.into_iter().enumerate() {
selected_indices.slice_mut(i).assign_array([
selected_indices.slice_with_mut(i).assign_array([
nms_box.batch_index as i32,
nms_box.class as i32,
nms_box.box_index as i32,
Expand Down Expand Up @@ -255,7 +255,7 @@ mod tests {
[cx, cy, w, h]
}
};
out_boxes.slice_mut((0, i)).assign_array(coords);
out_boxes.slice_with_mut((0, i)).assign_array(coords);
out_scores[[0, nms_box.class, i]] = nms_box.score;
}

Expand Down Expand Up @@ -309,10 +309,10 @@ mod tests {

assert_eq!(selected.size(0), 2);

let [batch, class, box_idx] = selected.slice(0).to_array();
let [batch, class, box_idx] = selected.slice_with(0).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 0]);

let [batch, class, box_idx] = selected.slice(1).to_array();
let [batch, class, box_idx] = selected.slice_with(1).to_array();
assert_eq!([batch, class, box_idx], [0, 1, 2]);
}

Expand Down Expand Up @@ -371,10 +371,10 @@ mod tests {
// returned.
assert_eq!(selected.size(0), 3);

let [batch, class, box_idx] = selected.slice(0).to_array();
let [batch, class, box_idx] = selected.slice_with(0).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 0]);

let [batch, class, box_idx] = selected.slice(1).to_array();
let [batch, class, box_idx] = selected.slice_with(1).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 1]);
}

Expand All @@ -400,10 +400,10 @@ mod tests {
// be returned from each class.
assert!(selected.size(0) == 2);

let [batch, class, box_idx] = selected.slice(0).to_array();
let [batch, class, box_idx] = selected.slice_with(0).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 0]);

let [batch, class, box_idx] = selected.slice(1).to_array();
let [batch, class, box_idx] = selected.slice_with(1).to_array();
assert_eq!([batch, class, box_idx], [0, 1, 2]);
}

Expand All @@ -428,7 +428,7 @@ mod tests {
// Only the box with score exceeding `score_threshold` will be returned.
assert!(selected.size(0) == 1);

let [batch, class, box_idx] = selected.slice(0).to_array();
let [batch, class, box_idx] = selected.slice_with(0).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 0]);
}

Expand Down
51 changes: 26 additions & 25 deletions src/ops/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ mod tests {
use rten_tensor::prelude::*;
use rten_tensor::rng::XorShiftRng;
use rten_tensor::test_util::expect_equal;
use rten_tensor::Tensor;
use rten_tensor::{NdTensor, Tensor};
use serde_json::Value;

use crate::ops::tests::new_pool;
Expand Down Expand Up @@ -695,45 +695,46 @@ mod tests {
Op::Lstm => 4,
};

let input = Tensor::<f32>::rand(&[seq_len, batch, features], &mut rng).map(|x| x - 0.5);
let weights = Tensor::<f32>::rand(
&[dir.num_directions(), num_gates * hidden_size, features],
let input =
NdTensor::<f32, 3>::rand([seq_len, batch, features], &mut rng).map(|x| x - 0.5);
let weights = NdTensor::<f32, 3>::rand(
[dir.num_directions(), num_gates * hidden_size, features],
&mut rng,
)
.map(|x| x - 0.5);
let recurrent_weights = Tensor::<f32>::rand(
&[dir.num_directions(), num_gates * hidden_size, hidden_size],
let recurrent_weights = NdTensor::<f32, 3>::rand(
[dir.num_directions(), num_gates * hidden_size, hidden_size],
&mut rng,
)
.map(|x| x - 0.5);
let bias = Tensor::rand(
&[dir.num_directions(), 2 * num_gates * hidden_size],
let bias = NdTensor::rand(
[dir.num_directions(), 2 * num_gates * hidden_size],
&mut rng,
);
let initial_hidden =
Tensor::rand(&[dir.num_directions(), batch, hidden_size], &mut rng);
let initial_cell = Tensor::rand(&[dir.num_directions(), batch, hidden_size], &mut rng);
NdTensor::rand([dir.num_directions(), batch, hidden_size], &mut rng);
let initial_cell = NdTensor::rand([dir.num_directions(), batch, hidden_size], &mut rng);

let result = match case.op {
Op::Lstm => lstm(
&pool,
dir,
input.view(),
weights.view(),
recurrent_weights.view(),
case.with_bias.then_some(bias.view()),
case.with_hidden_init.then_some(initial_hidden.view()),
case.with_initial_cell.then_some(initial_cell.view()),
input.as_dyn(),
weights.as_dyn(),
recurrent_weights.as_dyn(),
case.with_bias.then_some(bias.as_dyn()),
case.with_hidden_init.then_some(initial_hidden.as_dyn()),
case.with_initial_cell.then_some(initial_cell.as_dyn()),
)
.expect("lstm op failed"),
Op::Gru => gru(
&pool,
dir,
input.view(),
weights.view(),
recurrent_weights.view(),
case.with_bias.then_some(bias.view()),
case.with_hidden_init.then_some(initial_hidden.view()),
input.as_dyn(),
weights.as_dyn(),
recurrent_weights.as_dyn(),
case.with_bias.then_some(bias.as_dyn()),
case.with_hidden_init.then_some(initial_hidden.as_dyn()),
true, /* linear_before_reset */
)
.expect("gru op failed"),
Expand Down Expand Up @@ -770,18 +771,18 @@ mod tests {
// The last hidden state should match the end of the hidden sequence
// for the forwards direction, and the start of the hidden sequence
// for the reverse direction.
let hidden_seq_fwd = hidden_seq.slice::<2, _>((
let hidden_seq_fwd = hidden_seq.slice_with((
-1, // seq
0, // direction
));
let last_hidden_fwd = last_hidden.slice::<2, _>(0);
let last_hidden_fwd = last_hidden.slice_with(0);
assert_eq!(hidden_seq_fwd, last_hidden_fwd);

let hidden_seq_rev = hidden_seq.slice::<2, _>((
let hidden_seq_rev = hidden_seq.slice_with((
0, // seq
1, // direction
));
let last_hidden_rev = last_hidden.slice::<2, _>(1);
let last_hidden_rev = last_hidden.slice_with(1);
assert_eq!(hidden_seq_rev, last_hidden_rev);
}
}
Expand Down

0 comments on commit a57f594

Please sign in to comment.