Replace old tensor slicing methods with new API #367

Merged (1 commit, Sep 20, 2024)
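This PR replaces the older rank-specific slicing methods (`slice::<N, _>`, `slice_dyn`, `slice_with`, and their `_mut` variants) with the unified `slice` / `slice_mut` methods, which infer the rank of the resulting view from the index argument. Below is a minimal sketch of the before and after, built only from calls that appear in the diff; the imports, shapes, and the `demo` function are illustrative assumptions rather than code from this PR.

```rust
use rten_tensor::prelude::*; // assumed import path
use rten_tensor::NdTensor;

fn demo() {
    // A small 2D integer tensor, as in the jina_similarity.rs changes.
    let mut mask: NdTensor<i32, 2> = NdTensor::zeros([4, 8]);

    // Old API (removed by this PR): the result rank was spelled out with a
    // turbofish or chosen via `slice_dyn` / `slice_with` variants, e.g.
    //     mask.slice_mut::<1, _>((0, ..3)).fill(1i32);
    //     let row = mask.slice_dyn(0);

    // New API: a single `slice` / `slice_mut` call infers the result rank
    // from the index argument.
    mask.slice_mut((0, ..3)).fill(1i32); // fill the first 3 columns of row 0
    let row = mask.slice(0); // rank-1 view of row 0
    assert_eq!(row.iter().filter(|&&x| x == 1).count(), 3);
}
```

The diffs below apply this pattern across the examples, rten-generate, rten-imageproc, and rten-tensor.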
4 changes: 2 additions & 2 deletions rten-examples/src/bert_qa.rs
@@ -159,13 +159,13 @@ fn extract_nbest_answers<'a>(
let min_start = 1; // Ignore [CLS] token at start.
let max_end = end_probs.size(1) - 1; // Ignore [SEP] token at end.
let mut span_scores: Vec<(usize, usize, f32)> = start_probs
- .slice::<1, _>((0, min_start..max_end))
+ .slice((0, min_start..max_end))
.iter()
.enumerate()
.map(|(start_pos, start_score)| {
let start_pos = start_pos + min_start;
let (relative_end_pos, end_score) = end_probs
- .slice::<1, _>((0, start_pos..(start_pos + max_answer_len).min(max_end)))
+ .slice((0, start_pos..(start_pos + max_answer_len).min(max_end)))
.iter()
.enumerate()
.max_by(|(_pos_a, score_a), (_pos_b, score_b)| score_a.total_cmp(score_b))
2 changes: 1 addition & 1 deletion rten-examples/src/deeplab.rs
@@ -131,7 +131,7 @@ fn main() -> Result<(), Box<dyn Error>> {
output.permute(&[0, 2, 3, 1]); // (N,class,H,W) => (N,H,W,class)

let seg_classes: NdTensor<i32, 2> = output
- .slice_dyn(0)
+ .slice(0)
.arg_max(-1, false /* keep_dims */)?
.try_into()?;
let [out_height, out_width] = seg_classes.shape();
2 changes: 1 addition & 1 deletion rten-examples/src/depth_anything.rs
@@ -105,7 +105,7 @@ fn main() -> Result<(), Box<dyn Error>> {

// Resize output map back to original input size and write to file.
let resized = output.resize_image([orig_height, orig_width])?;
- let resized = resized.slice::<3, _>(0);
+ let resized = resized.nd_view::<4>().slice(0);
write_image(&args.output, resized)?;

Ok(())
16 changes: 8 additions & 8 deletions rten-examples/src/jina_similarity.rs
@@ -102,18 +102,18 @@ fn embed_sentence_batch(
let token_ids = encoded.token_ids();
for (tid, input_id) in token_ids
.iter()
- .zip(input_ids.slice_mut_dyn((i, ..token_ids.len())).iter_mut())
+ .zip(input_ids.slice_mut((i, ..token_ids.len())).iter_mut())
{
*input_id = *tid as i32;
}
}

// Generate attention mask, set to 1 for non-padding tokens and 0 for
// padding tokens.
- let mut attention_mask = Tensor::zeros(&[batch, max_sequence_len]);
+ let mut attention_mask = NdTensor::zeros([batch, max_sequence_len]);
for (i, encoded) in encoded.iter().enumerate() {
attention_mask
- .slice_mut::<1, _>((i, ..encoded.token_ids().len()))
+ .slice_mut((i, ..encoded.token_ids().len()))
.fill(1i32);
}

@@ -127,9 +127,9 @@ fn embed_sentence_batch(

// Generate token type IDs if this model needs them. These are all zeros
// since each item has just one sequence.
- let type_ids: Tensor<i32>;
+ let type_ids: NdTensor<i32, 2>;
if let Some(type_ids_id) = model.find_node("token_type_ids") {
- type_ids = Tensor::zeros(&[batch, max_sequence_len]);
+ type_ids = NdTensor::zeros([batch, max_sequence_len]);
inputs.push((type_ids_id, type_ids.view().into()));
}

@@ -146,7 +146,7 @@ fn embed_sentence_batch(
// Take the mean of the non-padding elements along the sequence
// dimension.
let seq_len = input.token_ids().len();
- item.slice_dyn(..seq_len)
+ item.slice(..seq_len)
.reduce_mean(Some(&[0]), false /* keep_dims */)
.unwrap()
})
@@ -231,7 +231,7 @@ fn main() -> Result<(), Box<dyn Error>> {

// (1, embed) @ (embed, batch) => (1, batch)
let similarities = embeddings
- .slice::<2, _>(..1)
+ .slice(..1)
.matmul(embeddings.transposed().into())?;

// Sort results by similarity to the query.
@@ -240,7 +240,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// all be "high" values (close to 1.0). They should be used only for
// comparison with other scores.
let mut scores: Vec<(usize, f32)> = similarities
- .slice_dyn(0)
+ .slice(0)
.iter()
.copied()
.enumerate()
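Alongside the slicing calls, this file also swaps its dynamically-ranked `Tensor` buffers for statically-ranked `NdTensor` ones. A hypothetical side-by-side of the two constructors, with the shape `[2, 4]` and the import path chosen purely for illustration:

```rust
use rten_tensor::{NdTensor, Tensor}; // assumed import path

fn zeros_demo() {
    // Old: rank is only known at runtime; `zeros` takes the shape as a slice.
    let _dynamic: Tensor<i32> = Tensor::zeros(&[2, 4]);

    // New: the rank (2) is part of the type; `zeros` takes a fixed-size array.
    let _static_rank: NdTensor<i32, 2> = NdTensor::zeros([2, 4]);
}
```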
2 changes: 1 addition & 1 deletion rten-examples/src/piper.rs
@@ -223,7 +223,7 @@ fn main() -> Result<(), Box<dyn Error>> {

// Convert audio samples from float to 16-bit ints and write to output .wav
// file.
- let int_samples = audio_float_to_int16(samples.slice::<1, _>((0, 0, 0)), None);
+ let int_samples = audio_float_to_int16(samples.slice((0, 0, 0)), None);
let wav_file = BufWriter::new(File::create("output.wav")?);

let mut wav_writer = WavWriter::new(
2 changes: 1 addition & 1 deletion rten-examples/src/rmbg.rs
@@ -119,7 +119,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let bg_color = [0., 1., 0.]; // RGB
fill_mask(
image.view_mut(),
- background_mask.slice::<2, _>([0, 0]), // Extract first mask and channel
+ background_mask.slice([0, 0]), // Extract first mask and channel
bg_color,
);

6 changes: 2 additions & 4 deletions rten-examples/src/segment_anything.rs
@@ -207,11 +207,9 @@ fn main() -> Result<(), Box<dyn Error>> {
// Resize the output mask to match the original image and save to disk.
let pred_masks: NdTensor<f32, 5> = pred_masks.try_into()?;
let [_batch, _point_batch, _mask, mask_h, mask_w] = pred_masks.shape();
- let best_mask = pred_masks
-     .slice::<2, _>((0, 0, 0))
-     .reshaped([1, 1, mask_h, mask_w]);
+ let best_mask = pred_masks.slice((0, 0, 0)).reshaped([1, 1, mask_h, mask_w]);
let resized_mask = best_mask.resize_image([image_h, image_w])?;
write_image("segmented.png", resized_mask.slice::<3, _>(0).nd_view())?;
write_image("segmented.png", resized_mask.nd_view::<4>().slice(0))?;

Ok(())
}
2 changes: 1 addition & 1 deletion rten-examples/src/trocr.rs
@@ -97,7 +97,7 @@ fn main() -> Result<(), Box<dyn Error>> {
image.insert_axis(0); // Add batch dim

// From `image_size` in config.json.
- let mut image = image.resize_image([384, 384])?;
+ let mut image: NdTensor<_, 4> = image.resize_image([384, 384])?.try_into()?;

// Values taken from `preprocessor_config.json`.
let mean = [0.5, 0.5, 0.5];
4 changes: 2 additions & 2 deletions rten-examples/src/yolo.rs
@@ -146,11 +146,11 @@ fn main() -> Result<(), Box<dyn Error>> {
let scale_x = image_width as f32 / model_in_w as f32;

// [batch, n_boxes, coord]
- let boxes = output.slice::<3, _>((.., ..4, ..)).permuted([0, 2, 1]);
+ let boxes = output.slice((.., ..4, ..)).permuted([0, 2, 1]);

// [batch, n_classes, n_boxes]. The `n_boxes` coord is last because that
// is what `non_max_suppression` requires.
- let scores = output.slice::<3, _>((.., 4.., ..));
+ let scores = output.slice((.., 4.., ..));

let iou_threshold = 0.3;
let score_threshold = 0.25;
2 changes: 1 addition & 1 deletion rten-generate/src/generator.rs
@@ -602,7 +602,7 @@ impl<'a> Generator<'a> {

// Sample output token.
let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
- let next_id = self.sampler.sample(logits.slice_with((0, -1)));
+ let next_id = self.sampler.sample(logits.slice((0, -1)));

// Update the self-attention key-value cache.
//
2 changes: 1 addition & 1 deletion rten-generate/src/sampler.rs
@@ -97,7 +97,7 @@ impl Sampler for TopKSampler {
let topk_index = multinomial(&mut self.rng.borrow_mut(), probs.nd_view())
.expect("probs should be non-empty and sum to 1");

- let token_id = topk_indices.slice_with(topk_index).item().copied().unwrap();
+ let token_id = topk_indices.slice(topk_index).item().copied().unwrap();
token_id as TokenId
}
}
8 changes: 4 additions & 4 deletions rten-imageproc/src/drawing.rs
@@ -465,7 +465,7 @@ impl<'a, T: Copy + Default> Painter<'a, T> {
pub fn draw_polygon(&mut self, points: &[Point]) {
for i in 0..3 {
draw_polygon(
- self.surface.slice_with_mut([i]),
+ self.surface.slice_mut([i]),
points,
self.state.stroke[i],
self.state.stroke_width,
@@ -600,9 +600,9 @@ mod tests {
let expected_g = expected_r.map(|&x| if x == r { g } else { 0 });
let expected_b = expected_r.map(|&x| if x == r { b } else { 0 });

- compare_images(img.slice_with([0]), expected_r.view());
- compare_images(img.slice_with([1]), expected_g.view());
- compare_images(img.slice_with([2]), expected_b.view());
+ compare_images(img.slice([0]), expected_r.view());
+ compare_images(img.slice([1]), expected_g.view());
+ compare_images(img.slice([2]), expected_b.view());
}

#[test]
2 changes: 1 addition & 1 deletion rten-imageproc/src/normalize.rs
@@ -32,7 +32,7 @@ pub fn normalize_image<const C: usize>(

for chan in 0..n_chans {
let inv_std_dev = 1. / std_dev[chan];
- img.slice_with_mut(chan)
+ img.slice_mut(chan)
.apply(|x| (x - mean[chan]) * inv_std_dev);
}
}
6 changes: 3 additions & 3 deletions rten-tensor/src/copy.rs
@@ -205,8 +205,8 @@ pub fn copy_into_slice<'a, T: Clone>(
let mut dest = NdTensorViewMut::from_data(src.shape(), dest);
for i0 in 0..src.size(0) {
for i1 in 0..src.size(1) {
- let src = src.slice_with([i0, i1]);
- let dest = dest.slice_with_mut([i0, i1]);
+ let src = src.slice([i0, i1]);
+ let dest = dest.slice_mut([i0, i1]);
copy_blocked(src, dest);
}
}
@@ -437,7 +437,7 @@ fn copy_range_into_slice_inner<T: Clone>(
} else {
// Iterate over views of outermost dimension and recurse.
for i0 in ranges[0] {
- let src_slice = src.slice_dyn(i0);
+ let src_slice = src.slice(i0);
let (dest_slice, dest_tail) = dest.split_at_mut(src_slice.len());

copy_range_into_slice_inner(src_slice, dest_slice, &ranges[1..]);
4 changes: 2 additions & 2 deletions rten-tensor/src/iterators.rs
@@ -779,7 +779,7 @@ impl<'a, T, L: MutLayout> Iterator for InnerIterDyn<'a, T, L> {
fn next(&mut self) -> Option<Self::Item> {
self.outer_indices.next().map(|idx| {
let slice_items = to_slice_items(&idx);
- self.view.slice_with(slice_items.as_slice())
+ self.view.slice(slice_items.as_slice())
})
}

@@ -855,7 +855,7 @@ impl<'a, T, L: MutLayout> Iterator for InnerIterDynMut<'a, T, L> {
fn next(&mut self) -> Option<Self::Item> {
self.outer_indices.next().map(|idx| {
let slice_items = to_slice_items(&idx);
- let view: TensorViewMut<'_, T> = self.view.slice_mut_dyn(slice_items.as_slice());
+ let view: TensorViewMut<'_, T> = self.view.slice_mut(slice_items.as_slice());
unsafe {
// Safety: Outer view is non-broadcasting, and we increment the
// outer index each time, so returned views will not overlap.