Use TensorBase::slice_with in rten crates
Using `slice_with` ensures at compile time that the returned view has the
correct rank. Since it requires the input tensor to have a static rank to infer
the output rank, it also encourages the use of static-rank tensors, which is
generally good practice.
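
As a concrete illustration — a minimal sketch adapted from the `test_to_array`
case updated in `rten-tensor/src/tensor.rs` below; both forms compile, the
difference is where the rank check happens:

    use rten_tensor::prelude::*;
    use rten_tensor::NdTensor;

    fn main() {
        // [[1., 2.], [3., 4.]] with a static rank of 2.
        let tensor = NdTensor::arange(1., 5., None).into_shape([2, 2]);

        // Before: the output rank is a turbofish parameter supplied by the
        // caller and only validated at run time.
        let col0: [f32; 2] = tensor.view().transposed().slice::<1, _>(0).to_array();

        // After: indexing one of two static dims must leave a rank-1 view,
        // so the output rank is inferred and checked at compile time.
        let col1: [f32; 2] = tensor.view().transposed().slice_with(1).to_array();

        assert_eq!(col0, [1., 3.]);
        assert_eq!(col1, [2., 4.]);
    }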
robertknight committed Sep 17, 2024
1 parent f5a24f6 commit 1ca14e0
Showing 12 changed files with 56 additions and 50 deletions.
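
Several of the operator files below (conv.rs, pooling.rs, resize.rs, rnn.rs)
also swap `check_dims!` for `static_dims!` so that the checked views gain the
static rank that `slice_with` needs. A sketch of the before/after pattern,
lifted from the conv.rs hunk (both macros are internal to the rten crate, so
this is illustrative rather than a standalone program):

    // Before: the macro validated the rank of the dynamic-rank TensorView
    // and extracted the sizes, but `input` itself stayed dynamic-rank.
    let [batch, in_c, in_h, in_w] = check_dims!(input, 4, "NCHW");

    // After: the view itself becomes a static rank-4 view (an OpError is
    // returned if the input is not 4D), `shape()` yields a fixed-size
    // array, and later calls such as `input.slice_with([n])` can infer
    // their output ranks.
    let input = static_dims!(input, 4, "NCHW")?;
    let [batch, in_c, in_h, in_w] = input.shape();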
2 changes: 1 addition & 1 deletion rten-generate/src/generator.rs
@@ -602,7 +602,7 @@ impl<'a> Generator<'a> {

// Sample output token.
let logits: NdTensor<f32, 3> = outputs.remove(0).try_into().map_err(wrap_error)?;
let next_id = self.sampler.sample(logits.slice::<1, _>((0, -1)));
let next_id = self.sampler.sample(logits.slice_with((0, -1)));

// Update the self-attention key-value cache.
//
6 changes: 1 addition & 5 deletions rten-generate/src/sampler.rs
@@ -97,11 +97,7 @@ impl Sampler for TopKSampler {
let topk_index = multinomial(&mut self.rng.borrow_mut(), probs.nd_view())
.expect("probs should be non-empty and sum to 1");

let token_id = topk_indices
.slice::<0, _>(topk_index)
.item()
.copied()
.unwrap();
let token_id = topk_indices.slice_with(topk_index).item().copied().unwrap();
token_id as TokenId
}
}
4 changes: 2 additions & 2 deletions rten-tensor/src/copy.rs
@@ -205,8 +205,8 @@ pub fn copy_into_slice<'a, T: Clone>(
let mut dest = NdTensorViewMut::from_data(src.shape(), dest);
for i0 in 0..src.size(0) {
for i1 in 0..src.size(1) {
let src = src.slice::<2, _>([i0, i1]);
let dest = dest.slice_mut::<2, _>([i0, i1]);
let src = src.slice_with([i0, i1]);
let dest = dest.slice_with_mut([i0, i1]);
copy_blocked(src, dest);
}
}
4 changes: 2 additions & 2 deletions rten-tensor/src/tensor.rs
@@ -3657,8 +3657,8 @@ mod tests {
#[test]
fn test_to_array() {
let tensor = NdTensor::arange(1., 5., None).into_shape([2, 2]);
let col0: [f32; 2] = tensor.view().transposed().slice::<1, _>(0).to_array();
let col1: [f32; 2] = tensor.view().transposed().slice::<1, _>(1).to_array();
let col0: [f32; 2] = tensor.view().transposed().slice_with(0).to_array();
let col1: [f32; 2] = tensor.view().transposed().slice_with(1).to_array();
assert_eq!(col0, [1., 3.]);
assert_eq!(col1, [2., 4.]);
}
4 changes: 2 additions & 2 deletions src/gemm.rs
@@ -722,7 +722,7 @@ fn gemv<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
for (k_block, a_block) in
range_chunks(0..a_cols, k_block_size).zip(a_data.chunks(k_block_size))
{
let b_block = b.slice::<2, _>((k_block, col_block.clone()));
let b_block = b.slice_with((k_block, col_block.clone()));
kernel.gemv_kernel(out_chunk, a_block, b_block, alpha, effective_beta);

// Reset `beta` so that subsequent updates for each column
@@ -815,7 +815,7 @@ fn gemm_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
if let (1, GemmInputA::Unpacked(a), GemmInputB::Unpacked(b)) = (a.rows(), a, b) {
gemv(
kernel,
a.slice::<1, _>(0),
a.slice_with(0),
b,
output_mat.view_mut(),
alpha,
2 changes: 1 addition & 1 deletion src/gemm/kernels/simd_generic.rs
@@ -148,7 +148,7 @@ unsafe fn simd_gemv_transposed<S: SimdFloat>(
simd_gemv_fallback(
&mut out[last_col_tile.clone()],
a,
b.slice::<2, _>((.., last_col_tile)),
b.slice_with((.., last_col_tile)),
alpha,
beta,
);
25 changes: 15 additions & 10 deletions src/ops/conv.rs
@@ -7,11 +7,11 @@ use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};

use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, VirtualMatrix};
use crate::ops::pooling::calc_output_size_and_padding;
use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
use crate::tensor_pool::{AutoReturn, TensorPool};
use crate::{check_dims, static_dims};

mod depthwise;
mod im2col;
@@ -50,7 +50,7 @@
let mut out_item = output.slice_mut::<2, _>([n]);
let out_row_stride = out_item.stride(0);

let in_mat = input.slice::<3, _>([n]).reshaped([in_c, in_h * in_w]);
let in_mat = input.slice_with([n]).reshaped([in_c, in_h * in_w]);

gemm.gemm_uninit_bias(
out_item.data_mut().unwrap(),
@@ -148,8 +148,11 @@ where
});
}

let [batch, in_c, in_h, in_w] = check_dims!(input, 4, "NCHW");
let [out_c, k_in_c, k_h, k_w] = check_dims!(kernel, 4, "OCHW");
let input = static_dims!(input, 4, "NCHW")?;
let [batch, in_c, in_h, in_w] = input.shape();

let kernel = static_dims!(kernel, 4, "OCHW")?;
let [out_c, k_in_c, k_h, k_w] = kernel.shape();
check_dims!(bias?, 1);

let input = input.view();
@@ -235,12 +238,12 @@ where
let out_chan_start = group * out_channels_per_group;
let out_chans = out_chan_start..out_chan_start + out_channels_per_group;

let in_group = input.slice::<4, _>((.., in_chan_start..in_chan_end));
let mut out_group = output.slice_mut::<3, _>((.., out_chans.clone()));
let in_group = input.slice_with((.., in_chan_start..in_chan_end));
let mut out_group = output.slice_with_mut((.., out_chans.clone()));

let kernel = kernel.to_contiguous_in(pool);
let kernel_mat = kernel
.slice::<4, _>([out_chans.clone()])
.slice_with([out_chans.clone()])
.reshaped([out_channels_per_group, in_channels_per_group * k_h * k_w]);

// Prepack kernel if we'll be able to reuse packed weights.
@@ -491,8 +494,10 @@ pub fn conv_transpose(
});
}

let [batch, in_c, in_h, in_w] = check_dims!(input, 4, "NCHW");
let [k_in_c, out_c, k_h, k_w] = check_dims!(kernel, 4, "OCHW");
let input = static_dims!(input, 4, "NCHW")?;
let [batch, in_c, in_h, in_w] = input.shape();
let kernel = static_dims!(kernel, 4, "OCHW")?;
let [k_in_c, out_c, k_h, k_w] = kernel.shape();
check_dims!(bias?, 1);

let bias = bias.map(|b| b.nd_view());
@@ -530,7 +535,7 @@ pub fn conv_transpose(
// The implementation here is the inverse of the im2col-based convolution.
let mut n_init = 0;
for n in 0..batch {
let input_mat = input.slice::<3, _>([n]).reshaped([in_c, in_h * in_w]);
let input_mat = input.slice_with([n]).reshaped([in_c, in_h * in_w]);

let col2im_row_stride = col2im_mat.stride(0);
gemm.gemm_uninit(
6 changes: 3 additions & 3 deletions src/ops/conv/depthwise.rs
@@ -92,7 +92,7 @@ fn conv_2d_depthwise_block<X, W, Y>(
let out_row_len = out_chan.size(1);
let out_chan_data = out_chan.data_mut().unwrap();

let in_chan = input.slice::<2, _>([c]);
let in_chan = input.slice_with([c]);
let in_row_stride = in_chan.stride(0);
let in_row_len = in_chan.size(1);
let in_chan_data = in_chan.data().unwrap();
@@ -204,8 +204,8 @@ where

let n_init = AtomicUsize::new(0);
for n in 0..batch {
let mut out_chans = output.slice_mut::<3, _>(n);
let input = input.slice::<3, _>(n);
let mut out_chans = output.slice_with_mut(n);
let input = input.slice_with(n);

out_chans
.axis_chunks_mut(0, channel_chunk_size)
2 changes: 1 addition & 1 deletion src/ops/non_max_suppression.rs
@@ -93,7 +93,7 @@ pub fn non_max_suppression(
for n in 0..batch {
for b in 0..n_boxes {
let (max_score_cls, max_score) = scores
.slice::<1, _>((n, .., b))
.slice_with((n, .., b))
.iter()
.copied()
.enumerate()
13 changes: 6 additions & 7 deletions src/ops/pooling.rs
@@ -7,8 +7,8 @@ use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView, TensorViewMut};

use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
use crate::static_dims;
use crate::tensor_pool::TensorPool;
use crate::{check_dims, static_dims};

/// Calculate the output size and padding for a convolution or pooling operation.
///
@@ -292,7 +292,7 @@ impl Operator for AveragePool {
}

pub fn global_average_pool(pool: &TensorPool, input: TensorView) -> Result<Tensor, OpError> {
let [batch, chans, in_h, in_w] = check_dims!(input, 4, "NCHW");
let input = static_dims!(input, 4, "NCHW")?;
let [batch, chans, in_h, in_w] = input.shape();

let mut output = NdTensor::uninit_in(pool, [batch, chans, 1, 1]);
let mut n_init = 0;
@@ -301,10 +302,8 @@ pub fn global_average_pool(pool: &TensorPool, input: TensorView) -> Result<Tensor, OpError> {
const N: usize = 4;

for (chan_group, mut out_group) in zip(
input.slice::<3, _>(n).axis_chunks(0, N),
output
.slice_mut::<1, _>((n, .., 0, 0))
.axis_chunks_mut(0, N),
input.slice_with(n).axis_chunks(0, N),
output.slice_with_mut((n, .., 0, 0)).axis_chunks_mut(0, N),
) {
if chan_group.size(0) == N {
// Compute average over batch of N channels in parallel.
@@ -327,7 +326,7 @@ pub fn global_average_pool(pool: &TensorPool, input: TensorView) -> Result<Tensor, OpError> {
} else {
// Compute average over remaining channels.
for i in 0..chan_group.size(0) {
let sum: f32 = chan_group.slice::<2, _>([i]).iter().sum();
let sum: f32 = chan_group.slice_with([i]).iter().sum();
out_group[[i]].write(sum / (in_h * in_w) as f32);
n_init += 1;
}
8 changes: 5 additions & 3 deletions src/ops/resize.rs
@@ -276,7 +276,7 @@ pub fn resize(

// The current implementation only supports NCHW tensors with scale factors
// other than 1.0 for the H and W dims.
let [batch, _chans, _height, _width] = check_dims!(input, 4, "NCHW");
let input = static_dims!(input, 4, "NCHW")?;
let [batch, _chans, _height, _width] = input.shape();
let sizes_valid = zip(0..input.ndim(), input.shape().iter()).all(|(dim, &in_size)| {
dim == input.ndim() - 1 || dim == input.ndim() - 2 || sizes[[dim]] == in_size as i32
});
@@ -297,8 +298,9 @@

let n_init = AtomicUsize::new(0);
for n in 0..batch {
let in_image = input.slice::<3, _>([n]);
let mut out_image = output.slice_mut::<3, _>([n]);
let in_image = input.slice_with([n]);
let mut out_batch = output.nd_view_mut::<4>();
let mut out_image = out_batch.slice_with_mut([n]);

out_image
.axis_chunks_mut(0, CHAN_GROUP_SIZE)
30 changes: 17 additions & 13 deletions src/ops/rnn.rs
@@ -2,15 +2,15 @@ use std::iter::{zip, Rev};
use std::ops::Range;

use rten_tensor::prelude::*;
use rten_tensor::{Tensor, TensorView};
use rten_tensor::{NdTensor, Tensor, TensorView};

use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
use crate::ops::{
add_in_place, mul_in_place, sigmoid, tanh, InputList, IntoOpResult, OpError, Operator,
OutputList,
};
use crate::tensor_pool::{AutoReturn, TensorPool};
use crate::{check_dims, static_dims};

/// Direction that an RNN operator will traverse the input sequence in.
#[derive(Copy, Clone, Debug)]
@@ -141,18 +141,22 @@ pub fn gru(
));
}

let [seq_len, batch, _input_size] = check_dims!(input, 3, "seq, batch, input");
let [_directions, hidden_x3, _input_size] = check_dims!(weights, 3, "dir, hidden x 3, input");
check_dims!(recurrent_weights, 3);
check_dims!(initial_hidden?, 3);
let input = static_dims!(input, 3, "seq, batch, input")?;
let weights = static_dims!(weights, 3, "dir, hidden x 3, input")?;
let recurrent_weights = static_dims!(recurrent_weights, 3)?;

let [seq_len, batch, _input_size] = input.shape();
let [_directions, hidden_x3, _input_size] = weights.shape();

let initial_hidden = initial_hidden.map(|ih| static_dims!(ih, 3)).transpose()?;

let num_directions = direction.num_directions();
let hidden_size = hidden_x3 / 3;

let mut hidden = initial_hidden
.map(|t| t.to_tensor_in(pool))
.unwrap_or_else(|| Tensor::zeros_in(pool, &[num_directions, batch, hidden_size]));
let mut hidden_seq = Tensor::zeros_in(pool, &[seq_len, num_directions, batch, hidden_size]);
.unwrap_or_else(|| NdTensor::zeros_in(pool, [num_directions, batch, hidden_size]));
let mut hidden_seq = NdTensor::zeros_in(pool, [seq_len, num_directions, batch, hidden_size]);

// Indices of gates in the concatenated weight and bias tensors.
const UPDATE_GATE: usize = 0;
@@ -171,15 +175,15 @@
for dir in 0..num_directions {
let prepack = seq_len >= PREPACK_MIN_SEQ_LEN;

let input_weights = weights.slice::<2, _>(dir).transposed();
let input_weights = weights.slice_with(dir).transposed();
let packed_input_weights =
prepack.then(|| gemm.prepack_b_in(pool, input_weights).auto_return(pool));
let input_weights = packed_input_weights
.as_ref()
.map(|packed| GemmInputB::Packed(packed))
.unwrap_or(GemmInputB::Unpacked(input_weights));

let hidden_weights = recurrent_weights.slice::<2, _>(dir).transposed();
let hidden_weights = recurrent_weights.slice_with(dir).transposed();
let packed_hidden_weights =
prepack.then(|| gemm.prepack_b_in(pool, hidden_weights).auto_return(pool));
let hidden_weights = packed_hidden_weights
@@ -195,8 +199,8 @@
.map(|b| b.slice::<1, _>((dir, (n_gates * hidden_size)..)));

for seq in sequence_for_dir(direction, dir, seq_len) {
let in_item = input.slice::<2, _>([seq]);
let hidden_item = hidden.slice::<2, _>([dir]);
let in_item = input.slice_with([seq]);
let hidden_item = hidden.slice_with([dir]);

// From the ONNX spec, the intermediate values are computed as:
//
@@ -305,7 +309,7 @@ pub fn gru(
}
}

Ok([hidden_seq, hidden].into())
Ok([hidden_seq.into_dyn(), hidden.into_dyn()].into())
}

impl Operator for GRU {

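The rnn.rs hunks above also show where the static ranks are erased again:
operator outputs remain dynamic-rank, so `gru` now ends with `into_dyn()`.
A small sketch of that boundary (assuming `NdTensor::zeros` and the prelude's
`ndim` with their usual rten-tensor signatures):

    use rten_tensor::prelude::*;
    use rten_tensor::{NdTensor, Tensor};

    fn main() {
        // Computed internally with a static rank of 4...
        let hidden_seq = NdTensor::<f32, 4>::zeros([5, 1, 2, 8]);

        // ...and converted to a dynamic-rank Tensor at the operator
        // boundary, as in `Ok([hidden_seq.into_dyn(), hidden.into_dyn()].into())`.
        let out: Tensor<f32> = hidden_seq.into_dyn();
        assert_eq!(out.ndim(), 4);
    }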