Replace most remaining uses of slice with slice_with #362

Merged: 1 commit, Sep 17, 2024
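For context: `slice_with` behaves like `slice` but infers the rank of the returned view from the index expression, so call sites can drop explicit rank annotations such as `slice_mut::<2, _>(...)`. A minimal sketch of the pattern, assuming the rten-tensor API as it appears in the hunks below (this sketch is not part of the diff):

    use rten_tensor::prelude::*;
    use rten_tensor::NdTensor;

    fn main() {
        // A 2x3 matrix with static rank 2.
        let mut mat = NdTensor::from_data([2, 3], vec![1, 2, 3, 4, 5, 6]);

        // Before: the rank of the returned view had to be written out,
        // e.g. `mat.slice::<1, _>(0)`. After: `slice_with` infers it
        // from the index.
        let row = mat.slice_with(0);
        assert_eq!(row.iter().copied().collect::<Vec<_>>(), [1, 2, 3]);

        // Mutable variant, as used throughout this PR.
        mat.slice_with_mut(1).apply(|x| x * 2);
        assert_eq!(mat.iter().copied().collect::<Vec<_>>(), [1, 2, 3, 8, 10, 12]);
    }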
8 changes: 4 additions & 4 deletions rten-imageproc/src/drawing.rs
@@ -465,7 +465,7 @@ impl<'a, T: Copy + Default> Painter<'a, T> {
pub fn draw_polygon(&mut self, points: &[Point]) {
for i in 0..3 {
draw_polygon(
- self.surface.slice_mut([i]),
+ self.surface.slice_with_mut([i]),
points,
self.state.stroke[i],
self.state.stroke_width,
@@ -600,9 +600,9 @@ mod tests {
let expected_g = expected_r.map(|&x| if x == r { g } else { 0 });
let expected_b = expected_r.map(|&x| if x == r { b } else { 0 });

- compare_images(img.slice([0]), expected_r.view());
- compare_images(img.slice([1]), expected_g.view());
- compare_images(img.slice([2]), expected_b.view());
+ compare_images(img.slice_with([0]), expected_r.view());
+ compare_images(img.slice_with([1]), expected_g.view());
+ compare_images(img.slice_with([2]), expected_b.view());
}

#[test]
2 changes: 1 addition & 1 deletion rten-imageproc/src/normalize.rs
@@ -32,7 +32,7 @@ pub fn normalize_image<const C: usize>(

for chan in 0..n_chans {
let inv_std_dev = 1. / std_dev[chan];
- img.slice_mut::<2, _>(chan)
+ img.slice_with_mut(chan)
.apply(|x| (x - mean[chan]) * inv_std_dev);
}
}
4 changes: 2 additions & 2 deletions rten-tensor/src/tensor.rs
@@ -2434,8 +2434,8 @@ mod tests {
let mut transposed = tensor.view_mut();

transposed.permute([1, 0]);
- transposed.slice_mut(0).assign_array([1, 2]);
- transposed.slice_mut(1).assign_array([3, 4]);
+ transposed.slice_with_mut(0).assign_array([1, 2]);
+ transposed.slice_with_mut(1).assign_array([3, 4]);

assert_eq!(tensor.iter().copied().collect::<Vec<_>>(), [1, 3, 2, 4]);
}
20 changes: 10 additions & 10 deletions src/ops/conv.rs
@@ -32,7 +32,7 @@ where
{
let [batch, _, in_h, in_w]: [usize; 4] = input.shape();
let [out_c, in_c, _, _]: [usize; 4] = kernel.shape();
- let mut output = Tensor::uninit_in(pool, &[batch, out_c, in_h * in_w]);
+ let mut output = NdTensor::uninit_in(pool, [batch, out_c, in_h * in_w]);

// Get input and kernel as contiguous tensors so we can create reshaped
// views.
@@ -47,7 +47,7 @@
let mut n_init = 0;

for n in 0..batch {
- let mut out_item = output.slice_mut::<2, _>([n]);
+ let mut out_item = output.slice_with_mut([n]);
let out_row_stride = out_item.stride(0);

let in_mat = input.slice_with([n]).reshaped([in_c, in_h * in_w]);
@@ -63,11 +63,11 @@ where
n_init += out_item.len();
}

- output.reshape(&[batch, out_c, in_h, in_w]);
+ let output = output.into_shape([batch, out_c, in_h, in_w]);

// Safety: We used `gemm_uninit_bias` to initialize all elements.
assert!(n_init == output.len());
- unsafe { output.assume_init() }
+ unsafe { output.assume_init().into_dyn() }
}

/// Perform a convolution of `input` with `kernel`.
@@ -355,15 +355,15 @@ fn col2im(

for out_c in 0..out_chans {
// Initialize each output channel just before we accumulate into it.
- let mut out_img = output.slice_mut([out_c]);
+ let mut out_img = output.slice_with_mut([out_c]);
out_img.fill(MaybeUninit::new(bias.map(|b| b[[out_c]]).unwrap_or(0.)));

// Safety: We just initialized all elements of `out_img`.
let mut out_img = unsafe { out_img.assume_init() };

for k_y in 0..kernel_h {
for k_x in 0..kernel_w {
- let in_img = columns.slice([out_c, k_y, k_x]);
+ let in_img = columns.slice_with([out_c, k_y, k_x]);
let [img_h, img_w] = in_img.shape();

for y in 0..img_h {
@@ -521,7 +521,7 @@ pub fn conv_transpose(
let [out_h, out_w] = out_shape;
let [pad_top, pad_left, pad_bottom, pad_right] = fixed_padding;

- let mut output = Tensor::uninit_in(pool, [batch, out_c, out_h, out_w].as_slice());
+ let mut output = NdTensor::uninit_in(pool, [batch, out_c, out_h, out_w]);

// Ensure input and kernel are contiguous to support reshaping.
let input = input.to_contiguous_in(pool).auto_return(pool);
@@ -548,7 +548,7 @@

// Safety: `gemm_uninit` initialized col2im_mat.
let col2im_mat = unsafe { col2im_mat.view().assume_init() };
- let mut out_img = output.slice_mut(n);
+ let mut out_img = output.slice_with_mut(n);

col2im(
&mut out_img,
@@ -562,7 +562,7 @@

assert!(n_init == output.len());
let output = unsafe { output.assume_init() };
- Ok(output)
+ Ok(output.into_dyn())
}

#[derive(Debug)]
@@ -1634,7 +1634,7 @@ mod tests {
// With padding.
run_bench(100, Some("col2im"), || {
col2im(
- &mut output.slice_mut((.., 2.., 2..)),
+ &mut output.slice_with_mut((.., 2.., 2..)),
&columns.view(),
[1, 1, 1, 1], // Padding
[stride_y, stride_x],
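Besides the `slice_with` renames, this file's hunks move the output buffer from dynamic-rank `Tensor` to static-rank `NdTensor`: the shape becomes a fixed-size array, the in-place `reshape` becomes `into_shape`, and `into_dyn` converts back at the boundary. A simplified sketch of that pattern, assuming the API as shown above (pool allocation and deferred initialization omitted):

    use rten_tensor::prelude::*;
    use rten_tensor::{NdTensor, Tensor};

    fn flatten_then_restore(batch: usize, chans: usize, h: usize, w: usize) -> Tensor<f32> {
        // Static rank-3 allocation: the shape is an array, not a `&[usize]` slice.
        let flat = NdTensor::<f32, 3>::zeros([batch, chans, h * w]);

        // (The matmul into the flat buffer would happen here.)

        // `into_shape` consumes the tensor and yields a rank-4 tensor,
        // replacing the in-place `output.reshape(&[...])` call.
        let restored = flat.into_shape([batch, chans, h, w]);

        // `into_dyn` erases the static rank for callers expecting `Tensor`.
        restored.into_dyn()
    }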
4 changes: 2 additions & 2 deletions src/ops/conv/depthwise.rs
@@ -83,11 +83,11 @@ fn conv_2d_depthwise_block<X, W, Y>(
let [dilation_y, _dilation_x] = dilations;

for c in chan_range.clone() {
- let kernel_view = kernel.slice([c, 0]).weakly_checked_view();
+ let kernel_view = kernel.slice_with([c, 0]).weakly_checked_view();

// For efficiency, use manual slicing in the inner loops to extract
// input/output rows.
- let mut out_chan = output.slice_mut::<2, _>([c - chan_range.start]);
+ let mut out_chan = output.slice_with_mut([c - chan_range.start]);
let out_row_stride = out_chan.stride(0);
let out_row_len = out_chan.size(1);
let out_chan_data = out_chan.data_mut().unwrap();
20 changes: 10 additions & 10 deletions src/ops/non_max_suppression.rs
@@ -104,7 +104,7 @@ pub fn non_max_suppression(
continue;
}

- let [c0, c1, c2, c3] = boxes.slice((n, b)).to_array();
+ let [c0, c1, c2, c3] = boxes.slice_with((n, b)).to_array();
let [top, left, bottom, right] = match box_order {
BoxOrder::TopLeftBottomRight => [c0, c1, c2, c3],
BoxOrder::CenterWidthHeight => {
@@ -172,7 +172,7 @@ pub fn non_max_suppression(

let mut selected_indices = NdTensor::zeros_in(pool, [selected.len(), 3]);
for (i, nms_box) in selected.into_iter().enumerate() {
- selected_indices.slice_mut(i).assign_array([
+ selected_indices.slice_with_mut(i).assign_array([
nms_box.batch_index as i32,
nms_box.class as i32,
nms_box.box_index as i32,
@@ -255,7 +255,7 @@ mod tests {
[cx, cy, w, h]
}
};
- out_boxes.slice_mut((0, i)).assign_array(coords);
+ out_boxes.slice_with_mut((0, i)).assign_array(coords);
out_scores[[0, nms_box.class, i]] = nms_box.score;
}

@@ -309,10 +309,10 @@

assert_eq!(selected.size(0), 2);

- let [batch, class, box_idx] = selected.slice(0).to_array();
+ let [batch, class, box_idx] = selected.slice_with(0).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 0]);

- let [batch, class, box_idx] = selected.slice(1).to_array();
+ let [batch, class, box_idx] = selected.slice_with(1).to_array();
assert_eq!([batch, class, box_idx], [0, 1, 2]);
}

@@ -371,10 +371,10 @@
// returned.
assert_eq!(selected.size(0), 3);

- let [batch, class, box_idx] = selected.slice(0).to_array();
+ let [batch, class, box_idx] = selected.slice_with(0).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 0]);

- let [batch, class, box_idx] = selected.slice(1).to_array();
+ let [batch, class, box_idx] = selected.slice_with(1).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 1]);
}

@@ -400,10 +400,10 @@
// be returned from each class.
assert!(selected.size(0) == 2);

- let [batch, class, box_idx] = selected.slice(0).to_array();
+ let [batch, class, box_idx] = selected.slice_with(0).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 0]);

- let [batch, class, box_idx] = selected.slice(1).to_array();
+ let [batch, class, box_idx] = selected.slice_with(1).to_array();
assert_eq!([batch, class, box_idx], [0, 1, 2]);
}

@@ -428,7 +428,7 @@
// Only the box with score exceeding `score_threshold` will be returned.
assert!(selected.size(0) == 1);

- let [batch, class, box_idx] = selected.slice(0).to_array();
+ let [batch, class, box_idx] = selected.slice_with(0).to_array();
assert_eq!([batch, class, box_idx], [0, 0, 0]);
}

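For reference, a small sketch of the `assign_array`/`to_array` round-trip these tests rely on, using the API names from the diff (illustrative only, not part of the change):

    use rten_tensor::prelude::*;
    use rten_tensor::NdTensor;

    fn main() {
        let mut selected = NdTensor::<i32, 2>::zeros([2, 3]);

        // Write a fixed-size array into row 0. `slice_with_mut` yields
        // a rank-1 mutable view, with the rank inferred from the index.
        selected.slice_with_mut(0).assign_array([0, 1, 2]);

        // Read the row back as a fixed-size array and destructure it.
        let [batch, class, box_idx] = selected.slice_with(0).to_array();
        assert_eq!([batch, class, box_idx], [0, 1, 2]);
    }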
51 changes: 26 additions & 25 deletions src/ops/rnn.rs
@@ -595,7 +595,7 @@ mod tests {
use rten_tensor::prelude::*;
use rten_tensor::rng::XorShiftRng;
use rten_tensor::test_util::expect_equal;
- use rten_tensor::Tensor;
+ use rten_tensor::{NdTensor, Tensor};
use serde_json::Value;

use crate::ops::tests::new_pool;
@@ -695,45 +695,46 @@
Op::Lstm => 4,
};

- let input = Tensor::<f32>::rand(&[seq_len, batch, features], &mut rng).map(|x| x - 0.5);
- let weights = Tensor::<f32>::rand(
- &[dir.num_directions(), num_gates * hidden_size, features],
+ let input =
+ NdTensor::<f32, 3>::rand([seq_len, batch, features], &mut rng).map(|x| x - 0.5);
+ let weights = NdTensor::<f32, 3>::rand(
+ [dir.num_directions(), num_gates * hidden_size, features],
&mut rng,
)
.map(|x| x - 0.5);
- let recurrent_weights = Tensor::<f32>::rand(
- &[dir.num_directions(), num_gates * hidden_size, hidden_size],
+ let recurrent_weights = NdTensor::<f32, 3>::rand(
+ [dir.num_directions(), num_gates * hidden_size, hidden_size],
&mut rng,
)
.map(|x| x - 0.5);
- let bias = Tensor::rand(
- &[dir.num_directions(), 2 * num_gates * hidden_size],
+ let bias = NdTensor::rand(
+ [dir.num_directions(), 2 * num_gates * hidden_size],
&mut rng,
);
let initial_hidden =
- Tensor::rand(&[dir.num_directions(), batch, hidden_size], &mut rng);
- let initial_cell = Tensor::rand(&[dir.num_directions(), batch, hidden_size], &mut rng);
+ NdTensor::rand([dir.num_directions(), batch, hidden_size], &mut rng);
+ let initial_cell = NdTensor::rand([dir.num_directions(), batch, hidden_size], &mut rng);

let result = match case.op {
Op::Lstm => lstm(
&pool,
dir,
- input.view(),
- weights.view(),
- recurrent_weights.view(),
- case.with_bias.then_some(bias.view()),
- case.with_hidden_init.then_some(initial_hidden.view()),
- case.with_initial_cell.then_some(initial_cell.view()),
+ input.as_dyn(),
+ weights.as_dyn(),
+ recurrent_weights.as_dyn(),
+ case.with_bias.then_some(bias.as_dyn()),
+ case.with_hidden_init.then_some(initial_hidden.as_dyn()),
+ case.with_initial_cell.then_some(initial_cell.as_dyn()),
)
.expect("lstm op failed"),
Op::Gru => gru(
&pool,
dir,
- input.view(),
- weights.view(),
- recurrent_weights.view(),
- case.with_bias.then_some(bias.view()),
- case.with_hidden_init.then_some(initial_hidden.view()),
+ input.as_dyn(),
+ weights.as_dyn(),
+ recurrent_weights.as_dyn(),
+ case.with_bias.then_some(bias.as_dyn()),
+ case.with_hidden_init.then_some(initial_hidden.as_dyn()),
true, /* linear_before_reset */
)
.expect("gru op failed"),
@@ -770,18 +771,18 @@ mod tests {
// The last hidden state should match the end of the hidden sequence
// for the forwards direction, and the start of the hidden sequence
// for the reverse direction.
- let hidden_seq_fwd = hidden_seq.slice::<2, _>((
+ let hidden_seq_fwd = hidden_seq.slice_with((
-1, // seq
0, // direction
));
- let last_hidden_fwd = last_hidden.slice::<2, _>(0);
+ let last_hidden_fwd = last_hidden.slice_with(0);
assert_eq!(hidden_seq_fwd, last_hidden_fwd);

- let hidden_seq_rev = hidden_seq.slice::<2, _>((
+ let hidden_seq_rev = hidden_seq.slice_with((
0, // seq
1, // direction
));
- let last_hidden_rev = last_hidden.slice::<2, _>(1);
+ let last_hidden_rev = last_hidden.slice_with(1);
assert_eq!(hidden_seq_rev, last_hidden_rev);
}
}
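Finally, the RNN test inputs above become static-rank `NdTensor`s, so the call sites pass `as_dyn()` views where they previously called `view()` on dynamic-rank `Tensor`s. A hedged sketch of that conversion (`total_len` below is a hypothetical stand-in, not the real `lstm`/`gru` signature):

    use rten_tensor::prelude::*;
    use rten_tensor::{NdTensor, TensorView};

    // Stand-in for an op that, like `lstm`/`gru`, takes dynamic-rank views.
    fn total_len(input: TensorView<f32>) -> usize {
        input.len()
    }

    fn main() {
        let input = NdTensor::<f32, 3>::zeros([4, 2, 8]);
        // `as_dyn` borrows the static-rank tensor as a dynamic-rank view.
        assert_eq!(total_len(input.as_dyn()), 4 * 2 * 8);
    }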