From 34859a50adc0214bec8ab5712976f41ac71fc0b7 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sat, 14 Sep 2024 12:45:03 -0400
Subject: [PATCH] Make the bulk of the `Conv` implementation generic

This is a step towards supporting other float types and integer types
in convolution operations.
---
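Editorial note, not part of the commit message: the patch threads three type
parameters through the convolution code, X for input elements, W for weights,
and Y for the output/accumulator type. The sketch below is a minimal,
standalone illustration of that bound scheme; it mirrors the bounds this
patch adds to `reference_conv`, but the `dot` helper is hypothetical and not
code from this repository.

    use std::ops::{AddAssign, Mul};

    // Multiply-accumulate in the style the patched functions require:
    // X * W produces Y, and Y supplies `+=` plus a `Default` zero value.
    fn dot<X, W, Y>(input: &[X], kernel: &[W]) -> Y
    where
        X: Copy + Mul<W, Output = Y>,
        W: Copy,
        Y: Copy + Default + AddAssign,
    {
        let mut accum = Y::default(); // replaces the old `0.0` literal
        for (&x, &w) in input.iter().zip(kernel) {
            accum += x * w;
        }
        accum
    }

    fn main() {
        // f32 behaves exactly as before.
        let y: f32 = dot(&[1.0f32, 2.0], &[0.5f32, 0.25]);
        assert_eq!(y, 1.0);
        // Integer support (e.g. i8 * i8 -> i32) additionally needs a
        // widening `Mul` impl, which std does not provide directly.
    }

The `GemmExecutor<W, X, Y>: Default` bound on `conv` then gates the operator
on a GEMM implementation existing for that type combination, so only
supported combinations compile.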
 src/ops/conv.rs           | 55 +++++++++++++++++++++++++--------------
 src/ops/conv/depthwise.rs | 39 ++++++++++++++++++---------
 src/ops/conv/im2col.rs    | 14 +++++-----
 3 files changed, 69 insertions(+), 39 deletions(-)

diff --git a/src/ops/conv.rs b/src/ops/conv.rs
index 81eac754..46872a8c 100644
--- a/src/ops/conv.rs
+++ b/src/ops/conv.rs
@@ -8,7 +8,7 @@ use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};
 
 use crate::check_dims;
-use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
+use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, VirtualMatrix};
 use crate::ops::pooling::calc_output_size_and_padding;
 use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
 use crate::tensor_pool::{AutoReturn, TensorPool};
@@ -21,12 +21,15 @@ use im2col::VirtualIm2Col;
 
 /// Specialization of conv_2d for pointwise convolutions over one image. This
 /// can be reduced to tensor reshaping and matrix multiplication.
-fn conv_2d_pointwise(
+fn conv_2d_pointwise<X: GemmInT, W: GemmInT, Y: GemmOutT>(
     pool: &TensorPool,
-    input: &NdTensorView<f32, 4>,
-    kernel: &NdTensorView<f32, 4>,
-    bias: Option<NdTensorView<f32, 1>>,
-) -> Tensor {
+    input: &NdTensorView<X, 4>,
+    kernel: &NdTensorView<W, 4>,
+    bias: Option<NdTensorView<Y, 1>>,
+) -> Tensor<Y>
+where
+    GemmExecutor<W, X, Y>: Default,
+{
     let [batch, _, in_h, in_w]: [usize; 4] = input.shape();
     let [out_c, in_c, _, _]: [usize; 4] = kernel.shape();
     let mut output = Tensor::uninit_in(pool, &[batch, out_c, in_h * in_w]);
@@ -40,7 +43,7 @@ fn conv_2d_pointwise(
     // Bias must be contiguous for use with `gemm_bias`.
     let bias = bias.as_ref().map(|b| b.to_contiguous());
 
-    let gemm = GemmExecutor::new();
+    let gemm = GemmExecutor::<W, X, Y>::default();
 
     let mut n_init = 0;
     for n in 0..batch {
@@ -83,16 +86,23 @@
 /// A value equal to the input channel count convolves each input channel
 /// separately with `output_channels / groups` outputs. This is known as
 /// depthwise convolution.
-pub fn conv(
+pub fn conv<X, W, Y>(
     pool: &TensorPool,
-    input: TensorView,
-    kernel: TensorView,
-    bias: Option<TensorView>,
+    input: TensorView<X>,
+    kernel: TensorView<W>,
+    bias: Option<TensorView<Y>>,
     padding: Padding,
     groups: usize,
     strides: &[usize],
     dilations: &[usize],
-) -> Result<Tensor, OpError> {
+) -> Result<Tensor<Y>, OpError>
+where
+    X: std::ops::Mul<W, Output = Y> + GemmInT,
+    W: GemmInT,
+    Y: Default + std::ops::AddAssign + GemmOutT,
+    GemmExecutor<W, X, Y>: Default,
+    for<'a> VirtualIm2Col<'a, X>: VirtualMatrix,
+{
     // Handle 1D convolution by expanding to 2D and then removing the extra
     // dimension from the result.
     if let &[_n, _c, _w] = input.shape() {
@@ -211,7 +221,7 @@
     let n_patches = out_h * out_w;
     let mut output = NdTensor::uninit_in(pool, [batch, out_c, n_patches]);
 
-    let gemm = GemmExecutor::new();
+    let gemm = GemmExecutor::<W, X, Y>::default();
 
     // Bias must be contiguous for use with `gemm_bias`.
     let bias = bias.map(|b| b.to_contiguous());
@@ -597,15 +607,20 @@ mod tests {
     /// Un-optimized reference implementation of convolution.
     ///
    /// This has the same interface as [conv].
-    fn reference_conv(
-        input: TensorView,
-        kernel: TensorView,
-        bias: Option<TensorView>,
+    fn reference_conv<X, W, Y>(
+        input: TensorView<X>,
+        kernel: TensorView<W>,
+        bias: Option<TensorView<Y>>,
         padding: Padding,
         groups: usize,
         strides: &[usize],
         dilations: &[usize],
-    ) -> Tensor {
+    ) -> Tensor<Y>
+    where
+        X: Copy + std::ops::Mul<W, Output = Y>,
+        W: Copy,
+        Y: Copy + Default + std::ops::Add<Output = Y> + std::ops::AddAssign,
+    {
         // If this is a 1D conv, insert a dummy H axis, perform a 2D convolution
         // and then remove the H axis from the result.
         if input.ndim() == 3 {
@@ -666,11 +681,11 @@ mod tests {
             let chan_bias = if let Some(ref bias) = bias {
                 bias[[out_chan]]
             } else {
-                0.0
+                Y::default()
             };
             for out_y in 0..out_h {
                 for out_x in 0..out_w {
-                    let mut accum = 0.0;
+                    let mut accum = Y::default();
                     for in_chan in in_chan_start..in_chan_end {
                         for k_y in 0..k_h {
                             for k_x in 0..k_w {
diff --git a/src/ops/conv/depthwise.rs b/src/ops/conv/depthwise.rs
index 962498b0..62930367 100644
--- a/src/ops/conv/depthwise.rs
+++ b/src/ops/conv/depthwise.rs
@@ -57,17 +57,21 @@ fn min_max_out_x_coords(
 ///
 /// When this function returns, all elements of `output` will have been
 /// initialized.
-fn conv_2d_depthwise_block(
-    mut output: NdTensorViewMut<MaybeUninit<f32>, 3>, // C, H, W
+fn conv_2d_depthwise_block<X, W, Y>(
+    mut output: NdTensorViewMut<MaybeUninit<Y>, 3>, // C, H, W
     chan_range: Range<usize>,
-    input: NdTensorView<f32, 3>,  // C, H, W
-    kernel: NdTensorView<f32, 4>, // C, _, Kh, Kw
-    bias: Option<NdTensorView<f32, 1>>,
+    input: NdTensorView<X, 3>,  // C, H, W
+    kernel: NdTensorView<W, 4>, // C, _, Kh, Kw
+    bias: Option<NdTensorView<Y, 1>>,
     padding: [usize; 4],
     strides: [usize; 2],
     dilations: [usize; 2],
     col_range_for_kernel_x: &[(Range<usize>, Range<usize>)],
-) {
+) where
+    X: Copy + std::ops::Mul<W, Output = Y>,
+    W: Copy,
+    Y: Copy + Default + std::ops::AddAssign,
+{
     debug_assert_eq!(input.stride(2), 1, "last dim of input is not contiguous");
     debug_assert_eq!(output.stride(2), 1, "last dim of output is not contiguous");
 
@@ -93,7 +97,11 @@ fn conv_2d_depthwise_block(
         let in_row_len = in_chan.size(1);
         let in_chan_data = in_chan.data().unwrap();
 
-        let init_value = if let Some(bias) = bias { bias[[c]] } else { 0. };
+        let init_value = if let Some(bias) = bias {
+            bias[[c]]
+        } else {
+            Y::default()
+        };
 
         // The loops here are ordered so that the inner-most loop is as
         // efficient as possible and runs for as long as possible over a
@@ -105,7 +113,7 @@ fn conv_2d_depthwise_block(
             for x in out_row.iter_mut() {
                 x.write(init_value);
            }
-            let out_row: &mut [f32] = unsafe { std::mem::transmute(out_row) };
+            let out_row: &mut [Y] = unsafe { std::mem::transmute(out_row) };
 
             for k_y in 0..k_h {
                 let in_y = out_y * stride_h + k_y * dilation_y;
@@ -140,16 +148,21 @@ fn conv_2d_depthwise_block(
 /// Depthwise convolutions operate over a single input/output channel at
 /// a time and hence the transformation of convolution to matrix multiplication
 /// doesn't pay off. An optimized direct method works better.
-pub fn conv_2d_depthwise(
+pub fn conv_2d_depthwise<X, W, Y>(
     pool: &TensorPool,
-    input: &NdTensorView<f32, 4>,
-    kernel: &NdTensorView<f32, 4>,
-    bias: Option<NdTensorView<f32, 1>>,
+    input: &NdTensorView<X, 4>,
+    kernel: &NdTensorView<W, 4>,
+    bias: Option<NdTensorView<Y, 1>>,
     padding: [usize; 4],
     strides: [usize; 2],
     dilations: [usize; 2],
     out_hw: [usize; 2],
-) -> Tensor {
+) -> Tensor<Y>
+where
+    X: Copy + std::ops::Mul<W, Output = Y>,
+    W: Copy,
+    Y: Copy + Default + std::ops::AddAssign,
+{
     let [batch, _in_c, _in_h, in_w]: [usize; 4] = input.shape();
     let [out_c, _, _k_h, k_w]: [usize; 4] = kernel.shape();
     let [_pad_top, pad_left, _pad_bottom, _pad_right] = padding;
diff --git a/src/ops/conv/im2col.rs b/src/ops/conv/im2col.rs
index fc03eaa1..86cf31e4 100644
--- a/src/ops/conv/im2col.rs
+++ b/src/ops/conv/im2col.rs
@@ -41,8 +41,8 @@ struct ColOffsets {
 /// The transform is virtual because the matrix is not actually materialized
 /// in memory. Instead blocks of it are produced on-demand during a matrix
 /// multiplication operation.
-pub struct VirtualIm2Col<'a> {
-    image: NdTensorView<'a, f32, 3>,
+pub struct VirtualIm2Col<'a, T> {
+    image: NdTensorView<'a, T, 3>,
 
     /// Map of im2col row index to input image coordinate, premultiplied with
     /// the corresponding stride.
@@ -69,18 +69,18 @@ pub struct VirtualIm2Col<'a> {
     gemm_kernel: KernelType,
 }
 
-impl<'a> VirtualIm2Col<'a> {
+impl<'a, T> VirtualIm2Col<'a, T> {
     /// Create a virtual im2col matrix from a [C, H, W] input tensor and
     /// convolution parameters.
     pub fn new(
         gemm_kernel: KernelType,
-        image: NdTensorView<'a, f32, 3>,
+        image: NdTensorView<'a, T, 3>,
         kernel: [usize; 2],
         padding: [usize; 4],
         strides: [usize; 2],
         dilations: [usize; 2],
         panel_width: usize,
-    ) -> VirtualIm2Col<'a> {
+    ) -> VirtualIm2Col<'a, T> {
         // Ensure image has at least one cell.
         assert!(image.len() > 0);
 
@@ -179,7 +179,9 @@ impl<'a> VirtualIm2Col<'a> {
             max_x_offset,
         }
     }
+}
 
+impl<'a> VirtualIm2Col<'a, f32> {
     /// Pack part of an image according to the requirements of
     /// [VirtualMatrix::pack_b].
     ///
@@ -310,7 +312,7 @@ const KERNEL_FMA_NR: usize = 16;
 const KERNEL_WASM_NR: usize = 8;
 
 // Safety: `pack_b` initializes the entire buffer passed to it.
-unsafe impl<'a> VirtualMatrix for VirtualIm2Col<'a> {
+unsafe impl<'a> VirtualMatrix for VirtualIm2Col<'a, f32> {
     fn rows(&self) -> usize {
         self.row_offsets.chan.len()
     }
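Editorial note, not part of the patch: `conv_2d_depthwise_block` keeps a
`transmute` from `&mut [MaybeUninit<Y>]` to `&mut [Y]` once every element has
been written. A standalone sketch of that pattern using only the standard
library follows; the `init_row` helper is hypothetical, not code from this
repository.

    use std::mem::MaybeUninit;

    // Fill a row of uninitialized output with `value`, then expose it as
    // an initialized slice, mirroring the write-then-transmute step above.
    fn init_row<Y: Copy>(row: &mut [MaybeUninit<Y>], value: Y) -> &mut [Y] {
        for slot in row.iter_mut() {
            slot.write(value);
        }
        // SAFETY: the loop above initialized every element of `row`.
        unsafe { &mut *(row as *mut [MaybeUninit<Y>] as *mut [Y]) }
    }

    fn main() {
        let mut buf = [MaybeUninit::<i32>::uninit(); 4];
        let row = init_row(&mut buf, 7); // bias value, or Y::default()
        row[0] += 1;
        assert_eq!(&*row, &[8, 7, 7, 7]);
    }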