
Merge pull request #357 from robertknight/generic-conv
Make the bulk of the `Conv` implementation generic
robertknight authored Sep 14, 2024
2 parents bbf8501 + 34859a5 commit b1cf831
Showing 3 changed files with 69 additions and 39 deletions.
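
The change threads three generic element types through the convolution code path: `X` for the input, `W` for the kernel weights, and `Y` for the output, bounded by the crate's `GemmInT`/`GemmOutT` traits in the operator code and by plain `Mul`/`AddAssign` bounds in the test helper. As a rough standalone sketch of the bound shape being introduced (this uses only std traits and made-up names, not the crate's actual GEMM types):

```rust
use std::ops::{AddAssign, Mul};

/// Dot product that is generic over the input (`X`), weight (`W`) and
/// accumulator (`Y`) element types, mirroring the bound shape this diff
/// gives `conv` and `reference_conv`.
fn dot<X, W, Y>(xs: &[X], ws: &[W]) -> Y
where
    X: Copy + Mul<W, Output = Y>,
    W: Copy,
    Y: Default + AddAssign<Y>,
{
    let mut acc = Y::default();
    for (&x, &w) in xs.iter().zip(ws.iter()) {
        acc += x * w;
    }
    acc
}

fn main() {
    // With every parameter set to f32 this resolves to the same arithmetic
    // as the previous non-generic code path.
    let y: f32 = dot(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]);
    assert_eq!(y, 32.0);
}
```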
src/ops/conv.rs: 55 changes (35 additions, 20 deletions)
@@ -8,7 +8,7 @@ use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};

use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, VirtualMatrix};
use crate::ops::pooling::calc_output_size_and_padding;
use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
use crate::tensor_pool::{AutoReturn, TensorPool};
@@ -21,12 +21,15 @@ use im2col::VirtualIm2Col;

/// Specialization of conv_2d for pointwise convolutions over one image. This
/// can be reduced to tensor reshaping and matrix multiplication.
fn conv_2d_pointwise(
fn conv_2d_pointwise<X: GemmInT, W: GemmInT, Y: GemmOutT>(
pool: &TensorPool,
input: &NdTensorView<f32, 4>,
kernel: &NdTensorView<f32, 4>,
bias: Option<NdTensorView<f32, 1>>,
) -> Tensor {
input: &NdTensorView<X, 4>,
kernel: &NdTensorView<W, 4>,
bias: Option<NdTensorView<Y, 1>>,
) -> Tensor<Y>
where
GemmExecutor<W, X, Y>: Default,
{
let [batch, _, in_h, in_w]: [usize; 4] = input.shape();
let [out_c, in_c, _, _]: [usize; 4] = kernel.shape();
let mut output = Tensor::uninit_in(pool, &[batch, out_c, in_h * in_w]);
@@ -40,7 +43,7 @@ fn conv_2d_pointwise(
// Bias must be contiguous for use with `gemm_bias`.
let bias = bias.as_ref().map(|b| b.to_contiguous());

let gemm = GemmExecutor::new();
let gemm = GemmExecutor::<W, X, Y>::default();
let mut n_init = 0;

for n in 0..batch {
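
As the doc comment above says, a pointwise (1x1) convolution reduces to reshaping plus a matrix multiply: view the kernel as an `[out_c, in_c]` matrix and each image as an `[in_c, in_h * in_w]` matrix, and the output for that image is their product. A minimal standalone sketch of that identity with plain loops instead of the `GemmExecutor` (names and layout here are illustrative, not the crate's API):

```rust
/// Pointwise (1x1) convolution over one image, written as a plain matrix
/// multiply: output[o, p] = sum over c of kernel[o, c] * input[c, p],
/// where p indexes the flattened H*W spatial positions.
fn pointwise_conv(
    input: &[f32],  // [in_c, h * w], row-major
    kernel: &[f32], // [out_c, in_c], row-major
    in_c: usize,
    out_c: usize,
    hw: usize,
) -> Vec<f32> {
    let mut output = vec![0.0f32; out_c * hw];
    for o in 0..out_c {
        for c in 0..in_c {
            let k = kernel[o * in_c + c];
            for p in 0..hw {
                output[o * hw + p] += k * input[c * hw + p];
            }
        }
    }
    output
}
```

The generic version in the diff does the same thing per batch image, but dispatches the multiply to `GemmExecutor::<W, X, Y>::default()` and applies the bias through `gemm_bias`.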
@@ -83,16 +86,23 @@
/// A value equal to the input channel count convolves each input channel
/// separately with `output_channels / groups` outputs. This is known as
/// depthwise convolution.
pub fn conv(
pub fn conv<X, W, Y>(
pool: &TensorPool,
input: TensorView,
kernel: TensorView,
bias: Option<TensorView>,
input: TensorView<X>,
kernel: TensorView<W>,
bias: Option<TensorView<Y>>,
padding: Padding,
groups: usize,
strides: &[usize],
dilations: &[usize],
) -> Result<Tensor, OpError> {
) -> Result<Tensor<Y>, OpError>
where
X: std::ops::Mul<W, Output = Y> + GemmInT,
W: GemmInT,
Y: Default + std::ops::AddAssign<Y> + GemmOutT,
GemmExecutor<W, X, Y>: Default,
for<'a> VirtualIm2Col<'a, X>: VirtualMatrix<X>,
{
// Handle 1D convolution by expanding to 2D and then removing the extra
// dimension from the result.
if let &[_n, _c, _w] = input.shape() {
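
Two of the new bounds deserve a note. `GemmExecutor<W, X, Y>: Default` is what lets `conv` construct an executor for whatever type combination is in use, and `for<'a> VirtualIm2Col<'a, X>: VirtualMatrix<X>` is a higher-ranked trait bound: the im2col view is built inside `conv` from a locally borrowed input, so the implementation has to exist for every borrow lifetime, not one named by the caller. A small standalone illustration of that HRTB pattern (the types below are invented for the example, not taken from the crate):

```rust
trait Summable {
    fn total(&self) -> i64;
}

/// A borrowed "virtual" view over a slice.
struct View<'a, T> {
    data: &'a [T],
}

impl<'a> Summable for View<'a, i64> {
    fn total(&self) -> i64 {
        self.data.iter().sum()
    }
}

/// Accepts any element type `T` for which the borrowed view implements
/// `Summable` for *every* lifetime, because the view is created locally.
fn sum_all<T>(data: &[T]) -> i64
where
    for<'a> View<'a, T>: Summable,
{
    let view = View { data };
    view.total()
}

fn main() {
    assert_eq!(sum_all(&[1i64, 2, 3]), 6);
}
```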
@@ -211,7 +221,7 @@ pub fn conv(

let n_patches = out_h * out_w;
let mut output = NdTensor::uninit_in(pool, [batch, out_c, n_patches]);
let gemm = GemmExecutor::new();
let gemm = GemmExecutor::<W, X, Y>::default();

// Bias must be contiguous for use with `gemm_bias`.
let bias = bias.map(|b| b.to_contiguous());
@@ -597,15 +607,20 @@ mod tests {
/// Un-optimized reference implementation of convolution.
///
/// This has the same interface as [conv].
fn reference_conv(
input: TensorView,
kernel: TensorView,
bias: Option<TensorView>,
fn reference_conv<X, W, Y>(
input: TensorView<X>,
kernel: TensorView<W>,
bias: Option<TensorView<Y>>,
padding: Padding,
groups: usize,
strides: &[usize],
dilations: &[usize],
) -> Tensor {
) -> Tensor<Y>
where
X: Copy + std::ops::Mul<W, Output = Y>,
W: Copy,
Y: Copy + Default + std::ops::Add<Y, Output = Y> + std::ops::AddAssign<Y>,
{
// If this is a 1D conv, insert a dummy H axis, perform a 2D convolution
// and then remove the H axis from the result.
if input.ndim() == 3 {
@@ -666,11 +681,11 @@ mod tests {
let chan_bias = if let Some(ref bias) = bias {
bias[[out_chan]]
} else {
0.0
Y::default()
};
for out_y in 0..out_h {
for out_x in 0..out_w {
let mut accum = 0.0;
let mut accum = Y::default();
for in_chan in in_chan_start..in_chan_end {
for k_y in 0..k_h {
for k_x in 0..k_w {
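For reference, the computation `reference_conv` spells out with nested loops is the standard convolution sum: for each output element, `out[n][oc][oy][ox] = bias[oc] + sum over ic, ky, kx of in[n][ic][oy*stride + ky*dilation - pad][ox*stride + kx*dilation - pad] * kernel[oc][ic][ky][kx]`, with `ic` restricted to the channel group when `groups > 1` and out-of-range input coordinates contributing zero. A compact standalone sketch of that accumulation for a single output position (illustrative helper, not the crate's test code):

```rust
/// One output element of a 2D convolution for a single output channel.
/// Zero padding is handled by skipping out-of-range input coordinates;
/// wrapping_sub turns "above/left of the image" into a huge index that
/// then fails the bounds check.
fn conv_output_at(
    input: &[Vec<Vec<f32>>],  // [in_c][in_h][in_w]
    kernel: &[Vec<Vec<f32>>], // [in_c][k_h][k_w], one output channel
    bias: f32,
    (out_y, out_x): (usize, usize),
    (stride_y, stride_x): (usize, usize),
    (dil_y, dil_x): (usize, usize),
    (pad_top, pad_left): (usize, usize),
) -> f32 {
    let (in_h, in_w) = (input[0].len(), input[0][0].len());
    let mut accum = bias;
    for (in_chan, k_chan) in input.iter().zip(kernel) {
        for (k_y, k_row) in k_chan.iter().enumerate() {
            for (k_x, &w) in k_row.iter().enumerate() {
                let in_y = (out_y * stride_y + k_y * dil_y).wrapping_sub(pad_top);
                let in_x = (out_x * stride_x + k_x * dil_x).wrapping_sub(pad_left);
                if in_y < in_h && in_x < in_w {
                    accum += in_chan[in_y][in_x] * w;
                }
            }
        }
    }
    accum
}
```

The generic bounds the diff puts on `reference_conv` (`X: Copy + Mul<W, Output = Y>`, `Y: Copy + Default + AddAssign<Y>`) are exactly what this accumulation needs.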
src/ops/conv/depthwise.rs: 39 changes (26 additions, 13 deletions)
@@ -57,17 +57,21 @@ fn min_max_out_x_coords(
///
/// When this function returns, all elements of `output` will have been
/// initialized.
fn conv_2d_depthwise_block(
mut output: NdTensorViewMut<MaybeUninit<f32>, 3>, // C, H, W
fn conv_2d_depthwise_block<X, W, Y>(
mut output: NdTensorViewMut<MaybeUninit<Y>, 3>, // C, H, W
chan_range: Range<usize>,
input: NdTensorView<f32, 3>, // C, H, W
kernel: NdTensorView<f32, 4>, // C, _, Kh, Kw
bias: Option<NdTensorView<f32, 1>>,
input: NdTensorView<X, 3>, // C, H, W
kernel: NdTensorView<W, 4>, // C, _, Kh, Kw
bias: Option<NdTensorView<Y, 1>>,
padding: [usize; 4],
strides: [usize; 2],
dilations: [usize; 2],
col_range_for_kernel_x: &[(Range<usize>, Range<usize>)],
) {
) where
X: Copy + std::ops::Mul<W, Output = Y>,
W: Copy,
Y: Copy + Default + std::ops::AddAssign<Y>,
{
debug_assert_eq!(input.stride(2), 1, "last dim of input is not contiguous");
debug_assert_eq!(output.stride(2), 1, "last dim of output is not contiguous");

@@ -93,7 +97,11 @@ fn conv_2d_depthwise_block(
let in_row_len = in_chan.size(1);
let in_chan_data = in_chan.data().unwrap();

let init_value = if let Some(bias) = bias { bias[[c]] } else { 0. };
let init_value = if let Some(bias) = bias {
bias[[c]]
} else {
Y::default()
};

// The loops here are ordered so that the inner-most loop is as
// efficient as possible and runs for as long as possible over a
@@ -105,7 +113,7 @@
for x in out_row.iter_mut() {
x.write(init_value);
}
let out_row: &mut [f32] = unsafe { std::mem::transmute(out_row) };
let out_row: &mut [Y] = unsafe { std::mem::transmute(out_row) };

for k_y in 0..k_h {
let in_y = out_y * stride_h + k_y * dilation_y;
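
The block above writes the bias (or `Y::default()`) into every `MaybeUninit` slot of the output row and only then reinterprets the row as initialized; the generic change simply swaps `f32` for `Y` in that transmute. A standalone sketch of the same initialize-then-assume-init idiom (hypothetical helper, not the crate's code):

```rust
use std::mem::MaybeUninit;

/// Fill an uninitialized row with `value`, then hand it back as initialized.
fn init_row<T: Copy>(row: &mut [MaybeUninit<T>], value: T) -> &mut [T] {
    for slot in row.iter_mut() {
        slot.write(value);
    }
    // Safety: every element was written in the loop above.
    unsafe { &mut *(row as *mut [MaybeUninit<T>] as *mut [T]) }
}
```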
@@ -140,16 +148,21 @@
/// Depthwise convolutions operate over a single input/output channel at
/// a time and hence the transformation of convolution to matrix multiplication
/// doesn't pay off. An optimized direct method works better.
pub fn conv_2d_depthwise(
pub fn conv_2d_depthwise<X, W, Y>(
pool: &TensorPool,
input: &NdTensorView<f32, 4>,
kernel: &NdTensorView<f32, 4>,
bias: Option<NdTensorView<f32, 1>>,
input: &NdTensorView<X, 4>,
kernel: &NdTensorView<W, 4>,
bias: Option<NdTensorView<Y, 1>>,
padding: [usize; 4],
strides: [usize; 2],
dilations: [usize; 2],
out_hw: [usize; 2],
) -> Tensor {
) -> Tensor<Y>
where
X: Copy + std::ops::Mul<W, Output = Y>,
W: Copy,
Y: Copy + Default + std::ops::AddAssign<Y>,
{
let [batch, _in_c, _in_h, in_w]: [usize; 4] = input.shape();
let [out_c, _, _k_h, k_w]: [usize; 4] = kernel.shape();
let [_pad_top, pad_left, _pad_bottom, _pad_right] = padding;
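Context for why this file keeps a direct implementation: in a depthwise convolution (`groups` equal to the channel count) each input channel is convolved with its own small kernel and written to its own output channel, so there is no sum over input channels to amortize into a GEMM. A minimal standalone sketch of the per-channel computation (stride 1, valid padding, plain nested `Vec`s rather than the crate's tensors):

```rust
/// Depthwise 2D convolution: output channel c depends only on input
/// channel c and kernel c. Each output starts at the bias value and
/// accumulates the kernel window, as in conv_2d_depthwise_block.
fn depthwise_conv(
    input: &[Vec<Vec<f32>>],  // [c][in_h][in_w]
    kernel: &[Vec<Vec<f32>>], // [c][k_h][k_w]
    bias: &[f32],             // [c]
) -> Vec<Vec<Vec<f32>>> {
    let mut output = Vec::with_capacity(input.len());
    for ((in_chan, k_chan), &b) in input.iter().zip(kernel).zip(bias) {
        let (k_h, k_w) = (k_chan.len(), k_chan[0].len());
        let out_h = in_chan.len() - k_h + 1;
        let out_w = in_chan[0].len() - k_w + 1;
        // Initialize every output element with the bias, then accumulate.
        let mut out_chan = vec![vec![b; out_w]; out_h];
        for oy in 0..out_h {
            for ox in 0..out_w {
                for ky in 0..k_h {
                    for kx in 0..k_w {
                        out_chan[oy][ox] += in_chan[oy + ky][ox + kx] * k_chan[ky][kx];
                    }
                }
            }
        }
        output.push(out_chan);
    }
    output
}
```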
src/ops/conv/im2col.rs: 14 changes (8 additions, 6 deletions)
@@ -41,8 +41,8 @@ struct ColOffsets {
/// The transform is virtual because the matrix is not actually materialized
/// in memory. Instead blocks of it are produced on-demand during a matrix
/// multiplication operation.
pub struct VirtualIm2Col<'a> {
image: NdTensorView<'a, f32, 3>,
pub struct VirtualIm2Col<'a, T> {
image: NdTensorView<'a, T, 3>,

/// Map of im2col row index to input image coordinate, premultiplied with
/// the corresponding stride.
@@ -69,18 +69,18 @@ pub struct VirtualIm2Col<'a> {
gemm_kernel: KernelType,
}

impl<'a> VirtualIm2Col<'a> {
impl<'a, T> VirtualIm2Col<'a, T> {
/// Create a virtual im2col matrix from a [C, H, W] input tensor and
/// convolution parameters.
pub fn new(
gemm_kernel: KernelType,
image: NdTensorView<'a, f32, 3>,
image: NdTensorView<'a, T, 3>,
kernel: [usize; 2],
padding: [usize; 4],
strides: [usize; 2],
dilations: [usize; 2],
panel_width: usize,
) -> VirtualIm2Col {
) -> VirtualIm2Col<T> {
// Ensure image has at least one cell.
assert!(image.len() > 0);

@@ -179,7 +179,9 @@ impl<'a> VirtualIm2Col<'a> {
max_x_offset,
}
}
}

impl<'a> VirtualIm2Col<'a, f32> {
/// Pack part of an image according to the requirements of
/// [VirtualMatrix::pack_b].
///
@@ -310,7 +312,7 @@ const KERNEL_FMA_NR: usize = 16;
const KERNEL_WASM_NR: usize = 8;

// Safety: `pack_b` initializes the entire buffer passed to it.
unsafe impl<'a> VirtualMatrix<f32> for VirtualIm2Col<'a> {
unsafe impl<'a> VirtualMatrix<f32> for VirtualIm2Col<'a, f32> {
fn rows(&self) -> usize {
self.row_offsets.chan.len()
}
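For background on what is being virtualized: the classic im2col transform copies every `[in_c, k_h, k_w]` patch of the input into one row (or column) of a large matrix, so the whole convolution becomes a single GEMM against the reshaped kernel. `VirtualIm2Col` avoids materializing that matrix and instead packs the needed panels on demand through the `VirtualMatrix` trait while the GEMM runs. A standalone sketch of the materialized transform, for intuition only (stride 1, no padding, illustrative helper rather than the crate's implementation):

```rust
/// Materialized im2col: each row of the result is one flattened
/// [in_c, k_h, k_w] patch of the input, one row per output position.
fn im2col(
    input: &[Vec<Vec<f32>>], // [in_c][h][w]
    k_h: usize,
    k_w: usize,
) -> Vec<Vec<f32>> {
    let (h, w) = (input[0].len(), input[0][0].len());
    let (out_h, out_w) = (h - k_h + 1, w - k_w + 1);
    let mut rows = Vec::with_capacity(out_h * out_w);
    for oy in 0..out_h {
        for ox in 0..out_w {
            let mut row = Vec::with_capacity(input.len() * k_h * k_w);
            for chan in input {
                for ky in 0..k_h {
                    for kx in 0..k_w {
                        row.push(chan[oy + ky][ox + kx]);
                    }
                }
            }
            rows.push(row);
        }
    }
    rows
}
```

Multiplying the reshaped `[out_c, in_c * k_h * k_w]` kernel matrix by the transpose of this result gives the `[out_c, out_h * out_w]` output; the virtual variant produces the equivalent packed panels for the GEMM's B operand blockwise instead of building the full matrix, which is why only the packing path stays specialized to `f32` in this commit.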
