Skip to content

Commit

Permalink
Replace old tensor slicing methods with new API
Browse files Browse the repository at this point in the history
Remove the old `TensorBase::{slice, slice_mut, slice_dyn, slice_mut_dyn}`
implementations and replace them with the `slice_with` APIs that infer the
result layout type based on the input layout and slice range.

This now leaves the API with a smaller and easier to use set of methods for
slicing, which automatically infer the output view's layout.
  • Loading branch information
robertknight committed Sep 20, 2024
1 parent 8af2ddf commit 18869f9
Show file tree
Hide file tree
Showing 30 changed files with 145 additions and 367 deletions.
4 changes: 2 additions & 2 deletions rten-examples/src/bert_qa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ fn extract_nbest_answers<'a>(
let min_start = 1; // Ignore [CLS] token at start.
let max_end = end_probs.size(1) - 1; // Ignore [SEP] token at end.
let mut span_scores: Vec<(usize, usize, f32)> = start_probs
.slice::<1, _>((0, min_start..max_end))
.slice((0, min_start..max_end))
.iter()
.enumerate()
.map(|(start_pos, start_score)| {
let start_pos = start_pos + min_start;
let (relative_end_pos, end_score) = end_probs
.slice::<1, _>((0, start_pos..(start_pos + max_answer_len).min(max_end)))
.slice((0, start_pos..(start_pos + max_answer_len).min(max_end)))
.iter()
.enumerate()
.max_by(|(_pos_a, score_a), (_pos_b, score_b)| score_a.total_cmp(score_b))
Expand Down
2 changes: 1 addition & 1 deletion rten-examples/src/deeplab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ fn main() -> Result<(), Box<dyn Error>> {
output.permute(&[0, 2, 3, 1]); // (N,class,H,W) => (N,H,W,class)

let seg_classes: NdTensor<i32, 2> = output
.slice_dyn(0)
.slice(0)
.arg_max(-1, false /* keep_dims */)?
.try_into()?;
let [out_height, out_width] = seg_classes.shape();
Expand Down
2 changes: 1 addition & 1 deletion rten-examples/src/depth_anything.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn main() -> Result<(), Box<dyn Error>> {

// Resize output map back to original input size and write to file.
let resized = output.resize_image([orig_height, orig_width])?;
let resized = resized.slice::<3, _>(0);
let resized = resized.nd_view::<4>().slice(0);
write_image(&args.output, resized)?;

Ok(())
Expand Down
16 changes: 8 additions & 8 deletions rten-examples/src/jina_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,18 @@ fn embed_sentence_batch(
let token_ids = encoded.token_ids();
for (tid, input_id) in token_ids
.iter()
.zip(input_ids.slice_mut_dyn((i, ..token_ids.len())).iter_mut())
.zip(input_ids.slice_mut((i, ..token_ids.len())).iter_mut())
{
*input_id = *tid as i32;
}
}

// Generate attention mask, set to 1 for non-padding tokens and 0 for
// padding tokens.
let mut attention_mask = Tensor::zeros(&[batch, max_sequence_len]);
let mut attention_mask = NdTensor::zeros([batch, max_sequence_len]);
for (i, encoded) in encoded.iter().enumerate() {
attention_mask
.slice_mut::<1, _>((i, ..encoded.token_ids().len()))
.slice_mut((i, ..encoded.token_ids().len()))
.fill(1i32);
}

Expand All @@ -127,9 +127,9 @@ fn embed_sentence_batch(

// Generate token type IDs if this model needs them. These are all zeros
// since each item has just one sequence.
let type_ids: Tensor<i32>;
let type_ids: NdTensor<i32, 2>;
if let Some(type_ids_id) = model.find_node("token_type_ids") {
type_ids = Tensor::zeros(&[batch, max_sequence_len]);
type_ids = NdTensor::zeros([batch, max_sequence_len]);
inputs.push((type_ids_id, type_ids.view().into()));
}

Expand All @@ -146,7 +146,7 @@ fn embed_sentence_batch(
// Take the mean of the non-padding elements along the sequence
// dimension.
let seq_len = input.token_ids().len();
item.slice_dyn(..seq_len)
item.slice(..seq_len)
.reduce_mean(Some(&[0]), false /* keep_dims */)
.unwrap()
})
Expand Down Expand Up @@ -231,7 +231,7 @@ fn main() -> Result<(), Box<dyn Error>> {

// (1, embed) @ (embed, batch) => (1, batch)
let similarities = embeddings
.slice::<2, _>(..1)
.slice(..1)
.matmul(embeddings.transposed().into())?;

// Sort results by similarity to the query.
Expand All @@ -240,7 +240,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// all be "high" values (close to 1.0). They should be used only for
// comparison with other scores.
let mut scores: Vec<(usize, f32)> = similarities
.slice_dyn(0)
.slice(0)
.iter()
.copied()
.enumerate()
Expand Down
2 changes: 1 addition & 1 deletion rten-examples/src/piper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ fn main() -> Result<(), Box<dyn Error>> {

// Convert audio samples from float to 16-bit ints and write to output .wav
// file.
let int_samples = audio_float_to_int16(samples.slice::<1, _>((0, 0, 0)), None);
let int_samples = audio_float_to_int16(samples.slice((0, 0, 0)), None);
let wav_file = BufWriter::new(File::create("output.wav")?);

let mut wav_writer = WavWriter::new(
Expand Down
2 changes: 1 addition & 1 deletion rten-examples/src/rmbg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let bg_color = [0., 1., 0.]; // RGB
fill_mask(
image.view_mut(),
background_mask.slice::<2, _>([0, 0]), // Extract first mask and channel
background_mask.slice([0, 0]), // Extract first mask and channel
bg_color,
);

Expand Down
6 changes: 2 additions & 4 deletions rten-examples/src/segment_anything.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,9 @@ fn main() -> Result<(), Box<dyn Error>> {
// Resize the output mask to match the original image and save to disk.
let pred_masks: NdTensor<f32, 5> = pred_masks.try_into()?;
let [_batch, _point_batch, _mask, mask_h, mask_w] = pred_masks.shape();
let best_mask = pred_masks
.slice::<2, _>((0, 0, 0))
.reshaped([1, 1, mask_h, mask_w]);
let best_mask = pred_masks.slice((0, 0, 0)).reshaped([1, 1, mask_h, mask_w]);
let resized_mask = best_mask.resize_image([image_h, image_w])?;
write_image("segmented.png", resized_mask.slice::<3, _>(0).nd_view())?;
write_image("segmented.png", resized_mask.nd_view::<4>().slice(0))?;

Ok(())
}
2 changes: 1 addition & 1 deletion rten-examples/src/trocr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ fn main() -> Result<(), Box<dyn Error>> {
image.insert_axis(0); // Add batch dim

// From `image_size` in config.json.
let mut image = image.resize_image([384, 384])?;
let mut image: NdTensor<_, 4> = image.resize_image([384, 384])?.try_into()?;

// Values taken from `preprocessor_config.json`.
let mean = [0.5, 0.5, 0.5];
Expand Down
4 changes: 2 additions & 2 deletions rten-examples/src/yolo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ fn main() -> Result<(), Box<dyn Error>> {
let scale_x = image_width as f32 / model_in_w as f32;

// [batch, n_boxes, coord]
let boxes = output.slice::<3, _>((.., ..4, ..)).permuted([0, 2, 1]);
let boxes = output.slice((.., ..4, ..)).permuted([0, 2, 1]);

// [batch, n_classes, n_boxes]. The `n_boxes` coord is last because that
// is what `non_max_suppression` requires.
let scores = output.slice::<3, _>((.., 4.., ..));
let scores = output.slice((.., 4.., ..));

let iou_threshold = 0.3;
let score_threshold = 0.25;
Expand Down
2 changes: 1 addition & 1 deletion rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ impl<'a> Generator<'a> {

// Sample output token.
let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
let next_id = self.sampler.sample(logits.slice_with((0, -1)));
let next_id = self.sampler.sample(logits.slice((0, -1)));

// Update the self-attention key-value cache.
//
Expand Down
2 changes: 1 addition & 1 deletion rten-generate/src/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl Sampler for TopKSampler {
let topk_index = multinomial(&mut self.rng.borrow_mut(), probs.nd_view())
.expect("probs should be non-empty and sum to 1");

let token_id = topk_indices.slice_with(topk_index).item().copied().unwrap();
let token_id = topk_indices.slice(topk_index).item().copied().unwrap();
token_id as TokenId
}
}
Expand Down
8 changes: 4 additions & 4 deletions rten-imageproc/src/drawing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ impl<'a, T: Copy + Default> Painter<'a, T> {
pub fn draw_polygon(&mut self, points: &[Point]) {
for i in 0..3 {
draw_polygon(
self.surface.slice_with_mut([i]),
self.surface.slice_mut([i]),
points,
self.state.stroke[i],
self.state.stroke_width,
Expand Down Expand Up @@ -600,9 +600,9 @@ mod tests {
let expected_g = expected_r.map(|&x| if x == r { g } else { 0 });
let expected_b = expected_r.map(|&x| if x == r { b } else { 0 });

compare_images(img.slice_with([0]), expected_r.view());
compare_images(img.slice_with([1]), expected_g.view());
compare_images(img.slice_with([2]), expected_b.view());
compare_images(img.slice([0]), expected_r.view());
compare_images(img.slice([1]), expected_g.view());
compare_images(img.slice([2]), expected_b.view());
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion rten-imageproc/src/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub fn normalize_image<const C: usize>(

for chan in 0..n_chans {
let inv_std_dev = 1. / std_dev[chan];
img.slice_with_mut(chan)
img.slice_mut(chan)
.apply(|x| (x - mean[chan]) * inv_std_dev);
}
}
6 changes: 3 additions & 3 deletions rten-tensor/src/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ pub fn copy_into_slice<'a, T: Clone>(
let mut dest = NdTensorViewMut::from_data(src.shape(), dest);
for i0 in 0..src.size(0) {
for i1 in 0..src.size(1) {
let src = src.slice_with([i0, i1]);
let dest = dest.slice_with_mut([i0, i1]);
let src = src.slice([i0, i1]);
let dest = dest.slice_mut([i0, i1]);
copy_blocked(src, dest);
}
}
Expand Down Expand Up @@ -437,7 +437,7 @@ fn copy_range_into_slice_inner<T: Clone>(
} else {
// Iterate over views of outermost dimension and recurse.
for i0 in ranges[0] {
let src_slice = src.slice_dyn(i0);
let src_slice = src.slice(i0);
let (dest_slice, dest_tail) = dest.split_at_mut(src_slice.len());

copy_range_into_slice_inner(src_slice, dest_slice, &ranges[1..]);
Expand Down
4 changes: 2 additions & 2 deletions rten-tensor/src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ impl<'a, T, L: MutLayout> Iterator for InnerIterDyn<'a, T, L> {
fn next(&mut self) -> Option<Self::Item> {
self.outer_indices.next().map(|idx| {
let slice_items = to_slice_items(&idx);
self.view.slice_with(slice_items.as_slice())
self.view.slice(slice_items.as_slice())
})
}

Expand Down Expand Up @@ -855,7 +855,7 @@ impl<'a, T, L: MutLayout> Iterator for InnerIterDynMut<'a, T, L> {
fn next(&mut self) -> Option<Self::Item> {
self.outer_indices.next().map(|idx| {
let slice_items = to_slice_items(&idx);
let view: TensorViewMut<'_, T> = self.view.slice_mut_dyn(slice_items.as_slice());
let view: TensorViewMut<'_, T> = self.view.slice_mut(slice_items.as_slice());
unsafe {
// Safety: Outer view is non-broadcasting, and we increment the
// outer index each time, so returned views will not overlap.
Expand Down
Loading

0 comments on commit 18869f9

Please sign in to comment.