diff --git a/.gitattributes b/.gitattributes index e5b821f51..dda5bfc74 100644 --- a/.gitattributes +++ b/.gitattributes @@ -60,3 +60,14 @@ c_reference/models/q_scut_head_b_face4_model/mbconv2.h filter=lfs diff=lfs merge c_reference/models/q_scut_head_b_face4_model/mbconv4.h filter=lfs diff=lfs merge=lfs -text c_reference/models/q_scut_head_b_face4_model/rnn2.h filter=lfs diff=lfs merge=lfs -text c_reference/models/q_scut_head_b_face4_model/detection2.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/keyword_spotting_io_1.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/keyword_spotting_io_2.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/keyword_spotting_io_3.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/conv1d/conv1d_regular/conv_param.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/precnn_params.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/postcnn_params.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/rnn_params.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/rnn_bricked/rnn_params.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/rnn_bricked/rnn_bricked_io.h filter=lfs diff=lfs merge=lfs -text diff --git a/c_reference/include/conv1d.h b/c_reference/include/conv1d.h new file mode 100644 index 000000000..e92f78727 --- /dev/null +++ b/c_reference/include/conv1d.h @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#ifndef __CONV1D_H__ +#define __CONV1D_H__ + +/* All the matrices/tensors are stored in the row major format. + + NOTES for the conv layers. +-> The conv1d & conv1d_lr layers work for all cases and can be used unconstrained. + There are no hard constraints for the parallel version, but a few points regarding its optimal usage are given below. +-> Dilation = 1 (no dilation) for all cases. +-> For the non-depthwise cases, store the matrices as described below. Permutation might be necessary. +-> The low-rank decomposition cannot be applied to the depthwise weight matrices. This is due to the out_channels % in_channels == 0 constraint imposed by the depthwise convolution. + For the full-rank case this is satisfied, since out_channels = in_channels. + But when the matrix is decomposed, the constraint is violated (since rank < out_channels and rank is not divisible by in_channels). + Hence, since the decomposition is theoretically impossible, we have not provided support for it. + However, we suggest a less-efficient alternative => first pre-compute the weights W = W2 * W1 and then use a regular conv. +-> For the parallel cases, the non-overlapping cases of the convolution are computed in parallel using MatMul (since the blocked MatMul is faster). + This, however, is only valid when the filter is fully inside the input; the edge cases cannot be batched this way. + Hence the MatVec code (regular code) is used to calculate these cases. + + Important points regarding parallel versions. +-> Due to the above reason, the parallel layers are only recommended for large in_time inputs. + This should typically be for in_time (without the padding) > 2 * num_steps_one_row + stride. Else there would not be enough time-steps to efficiently parallelise. + We need at least 2 rows for good MatMul performance.
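+ (A purely illustrative example: kernel_size = 5 with stride = 2 gives num_steps_one_row = 6, so the parallel path only pays off when in_time > 2 * 6 + 2 = 14.)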
In the worst case the starting time step would be (stride - 1). Hence we choose 2 * num_steps_one_row + stride as the threshold. + For the short input cases, the code will skip the MatMul computation and use MatVec instead (though the overhead of setting up the MatMul variables would remain). + For such cases, the MatVec code (conv1d and conv1d_lr) would work more efficiently due to the lower RAM usage and lack of any major overheads. +-> There is no support for depthwise for conv1d_parallel. + The regular convolution acts on all the channels while the depthwise acts only on one channel at a time. + This results in non-contiguous memory accesses. MatMul would need to process multiple such time-steps, while the MatVec would only need to process one. + Hence, the MatVec would be able to enter the next channel earlier and would work much faster, while the MatMul would suffer cache misses (given the small cache sizes of edge devices). +*/ + +/** + * @brief Model parameters for the 1D Convolution Layer. + * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1]. + * @var B pointer to the bias vector, original shape = [out_channels]. + * @var depthwise flag for deciding between regular(=0) and depthwise(=1) conv. + */ +typedef struct ConvLayers_Params { + const float* const W; + const float* const B; + unsigned depthwise; +} ConvLayers_Params; + +/** + * @brief Model definition for the 1D Convolution Layer. Currently only for dilation = 1. + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels. + * @param[in] out_time number of time steps in the output. + * @param[in] out_channels number of output channels for the output of the conv layer. + * NOTE: out_channels = in_channels for depthwise. This is set manually in the function. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] padding padding applied to the input before the conv is performed. + * NOTE: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the conv filter. + * @param[in] params weights, bias and other essential parameters used to describe the layer. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + */ +int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation); + +/** + * @brief Model parameters for the 1D Parallel Convolution Layer. + * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1]. + * @var B pointer to the bias vector, original shape = [out_channels]. + * @var block_size block/tile size for the cache. Used for tiled MatMul.
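+ * E.g., tile sizes such as 32 or 64 are common starting points; the optimal block_size is hardware-specific and should be tuned per device (these numbers are assumptions, not measurements).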
+ */ +typedef struct ConvLayers_Parallel_Params { + const float* const W; + const float* const B; + unsigned block_size; +} ConvLayers_Parallel_Params; + +/** + * @brief Model definition for the 1D Parallel Convolution Layer. Currently only for dilation = 1. No depthwise. + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels. + * @param[in] out_time number of time steps in the output. + * @param[in] out_channels number of output channels for the output of the conv layer. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the conv filter. + * @param[in] params weights, bias and other essential parameters used to describe the layer. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + */ +int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation); + +/** + * @brief Model parameters for the 1D Low Rank Convolution Layer. + * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels. + * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1]. + * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels]. + * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage. + */ +typedef struct ConvLayers_LR_Params { + const float* const W1; + const float* const W2; + const float* const B; + unsigned rank; +} ConvLayers_LR_Params; + +/** + * @brief Model definition for the 1D Low-Rank Convolution Layer. Currently only for dilation = 1. + * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels). + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels. + * @param[in] out_time number of time steps in the output. + * @param[in] out_channels number of output channels for the output of the conv layer. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis. 
+ * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the conv filter. + * @param[in] params weights, bias and other essential parameters used to describe the layer. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + */ +int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation); + +/** + * @brief Model parameters for the 1D Low Rank Parallel Convolution Layer. + * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels. + * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1]. + * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels]. + * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage. + * @var block_size_to_lr block/tile size for the cache. Used for tiled MatMul. Used for the input -> low-rank computation. + * @var block_size_from_lr block/tile size for the cache. Used for tiled MatMul. Used for the low-rank -> output computation. + */ +typedef struct ConvLayers_LR_Parallel_Params { + const float* const W1; + const float* const W2; + const float* const B; + unsigned rank; + unsigned block_size_to_lr; + unsigned block_size_from_lr; +} ConvLayers_LR_Parallel_Params; + +/** + * @brief Model definition for the 1D Low-Rank Parallel Convolution Layer. Currently only for dilation = 1. + * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels). + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels. + * @param[in] out_time number of time steps in the output. + * @param[in] out_channels number of output channels for the output of the conv layer. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the conv filter. + * @param[in] params weights, bias and other essential parameters used to describe the layer. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. 
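+ * NOTE: standard convolution arithmetic is assumed for sizing: out_time = (in_time + 2 * padding - kernel_size) / stride + 1 (integer division).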
+ */ +int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation); + +// Auxiliary Layers. +/** + * @brief Model definition for the 1D Average Pooling Layer. Currently only for dilation = 1. + * @param[out] output_signal pointer to the output signal, size = out_time * in_channels. Provide Null/0 in case of in-place computation. + * NOTE: out_channels == in_channels for avgpool. + * @param[in] out_time number of time steps in the output. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. The output will have the same number of channels. + * @param[in] padding padding applied to the input before the pool is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the pool filter. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + */ +int avgpool1d(float* output_signal, unsigned out_time, const float* input_signal, + unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, unsigned stride, unsigned activation); + +/** + * @brief Model definition for the 1D Batch Normalization Layer. + * @param[out] output_signal pointer to the output signal, size = in_time * in_channels. Provide Null/0 in case of in-place computation. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. The output will have the same number of channels. + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. if affine_config = 2, then pass a NULL/0. + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. if affine_config = 2, then pass a NULL/0. + * @param[in] affine_config whether the affine operations are applied. + * if affine_config = 0, then only mean and var are used. + * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var). + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed. + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Provide Null/0 if affine_config is 0. + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Provide Null/0 if affine_config is 0. + * @param[in] in_place in-place computation of the batchnorm i.e. the output is stored in place of the input signal. Storage efficient. + * @param[in] eps a very small +ve value to avoid division by 0. For the default value, assign = 0.00001.
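+ * E.g., for affine_config = 2 the folded parameters can be pre-computed offline as:
+ * gamma[c] = original_gamma[c] / sqrt(var[c] + eps) and beta[c] = original_beta[c] - gamma[c] * mean[c].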
+ */ +int batchnorm1d(float* output_signal, float* input_signal, + unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma , const float* const beta, + unsigned in_place, float eps); + +#endif diff --git a/c_reference/include/dscnn.h b/c_reference/include/dscnn.h new file mode 100644 index 000000000..1833d0813 --- /dev/null +++ b/c_reference/include/dscnn.h @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#ifndef __DSCNN_H__ +#define __DSCNN_H__ + +// Function pointer for the Conv layer to be passed as a parameter. (conv1d or conv1d_lr only). +typedef int (*conv_layer)(float*, unsigned, unsigned, const float*, + unsigned, unsigned, unsigned, unsigned, + const void*, unsigned, unsigned); + +/** + * @brief Model definition for the 1D Convolution block applied before the RNN. + * @brief sub-layers : batchnorm1d -> conv1d_lr. + * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] cnn function pointer for the CNN layer. (any of the conv layers can be passed with appropriate params). + * @param[in] in_time number of time steps in the input_signal. + * @param[in] in_channels number of input channels. + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2. + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2. + * @param[in] affine_config whether the affine operations are applied. + * if affine_config = 0, then only mean and var are used. + * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var). + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed. + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0. + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0. + * @param[in] in_place in-place computation check for the batchnorm. Storage efficient. + * @param[in] cnn_hidden hidden state/out_channels dimensions for the low-rank CNN. The final channel size of this block. + * @param[in] cnn_padding padding for the low-rank CNN layer. Note: applied to both sides of the input. + * @param[in] cnn_kernel_size kernel size of the low-rank CNN. + * @param[in] cnn_params weights, bias and other essential parameters for the low-rank CNN. + * @param[in] cnn_stride stride factor for the low-rank CNN. + * @param[in] cnn_activation an integer to choose the type of activation function. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. 
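+ * NOTE: batchnorm1d preserves in_time, so the block's out_time follows the usual conv arithmetic: out_time = (in_time + 2 * cnn_padding - cnn_kernel_size) / cnn_stride + 1.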
+ */ +int phon_pred_lr_cnn(float* output_signal, float* input_signal, + conv_layer cnn, unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, + unsigned cnn_hidden, unsigned cnn_padding, unsigned cnn_kernel_size, + const void* cnn_params, unsigned cnn_stride, unsigned cnn_activation); + +/** + * @brief Model definition for the 1D Convolution block applied after the RNN. + * @brief sub-layers : custom nonlinearity(semi_sigmoid_tanh) -> batchnorm1d -> conv1d_depth -> conv1d_lr -> avgpool1d. + * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] point_cnn function pointer for the point-wise CNN. (any of the conv layers can be passed with appropriate params). + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2. + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2. + * @param[in] affine_config whether the affine operations are applied. + * if affine_config = 0, then only mean and var are used. + * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var). + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed. + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0. + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0. + * @param[in] in_place in-place computation of the batchnorm. Storage efficient. + * @param[in] depth_cnn_padding padding for the depth CNN layer. Note: applied to both sides of the input to the depth CNN. + * @param[in] depth_cnn_kernel_size kernel size of the depth CNN. + * @param[in] depth_cnn_params weights, bias and other essential parameters used to describe the depth CNN. + * @param[in] depth_cnn_stride stride factor for the depth CNN. + * @param[in] depth_cnn_activation an integer to choose the type of activation function. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + * @param[in] point_cnn_hidden hidden state/out_channels dimensions for the point CNN. The final channel size of this block. + * @param[in] point_cnn_padding padding for the point CNN layer. Note: applied to both sides of the input to the point CNN. + * @param[in] point_cnn_kernel_size kernel size of the point CNN. + * @param[in] point_cnn_params weights, bias and other essential parameters used to describe the point CNN. + * @param[in] point_cnn_stride stride factor for the point CNN. + * @param[in] point_cnn_activation an integer to choose the type of activation function. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + * @param[in] pool_padding padding for the pool layer. Note: applied to both sides of the input to the pool. 
+ * @param[in] pool_kernel_size kernel size of the pool. + * @param[in] pool_stride stride factor for the pool. + * @param[in] pool_activation an integer to choose the type of activation function. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + */ +int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, + conv_layer point_cnn, unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, + unsigned depth_cnn_padding, unsigned depth_cnn_kernel_size, + const void* depth_cnn_params, unsigned depth_cnn_stride, unsigned depth_cnn_activation, + unsigned point_cnn_hidden, unsigned point_cnn_padding, unsigned point_cnn_kernel_size, + const void* point_cnn_params, unsigned point_cnn_stride, unsigned point_cnn_activation, + unsigned pool_padding, unsigned pool_kernel_size, unsigned pool_stride, unsigned pool_activation); + +#endif diff --git a/c_reference/include/rnn_bricked.h b/c_reference/include/rnn_bricked.h new file mode 100644 index 000000000..a7e2d2658 --- /dev/null +++ b/c_reference/include/rnn_bricked.h @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#ifndef __RNN_BRICKED_H__ +#define __RNN_BRICKED_H__ + +/* All the matrices are stored in the row major format. + + NOTES for using the layers. +-> Single-directional Computation. + While using the bricked fastgrnn layers, the user needs to adhere to the following two constraints. + 1) in_time % hop = 0. + 2) fwd_window % hop = 0 and bwd_window % hop = 0. + + Violation of the above two constraints (1 & 2) will cause segmentation faults. + The layers first compute all the Wx steps and then compute Uh for all the windows in parallel. + Hence, the user needs to adhere to constraints 1 & 2. + +-> Bi-directional Computation. + For bi-directional cases, there are 2 additional constraints that would need to be followed. + A) sample_first_brick and sample_last_brick = 1. + B) An offset of rnn_hidden would need to be given to the output_signal pointer during the backward function call. + Each function will only process its given context (forward/backward). The other context will need to be called separately. + E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...). + 2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...). + + The two extra constraints (A & B) are only for bi-directional cases and can be ignored if only forward (or only backward) is used. + Violating the conditions would cause index mis-matches or data corruption. + If the first (last) brick is not sampled, the first few (last few) time steps would be missing in the forward (backward) result. + If the offset is not passed during the backward function call, the backward pass will overwrite the forward result (bi-directional case only). +*/ + +/** + * @brief Model parameters for the Bricked FastGRNN Layer (low-rank). + * @var W1 pointer to first low-rank component of W. shape = [wRank * in_dims]. + * @var W2 pointer to second low-rank component of W. shape = [rnn_hidden * wRank]. + * @var wRank rank of W matrix. + * @var U1 pointer to first low-rank component of U. shape = [uRank * rnn_hidden]. + * @var U2 pointer to second low-rank component of U. shape = [rnn_hidden * uRank]. + * @var uRank rank of U matrix. + * @var Bg pointer to bias for sigmoid. + * @var Bh pointer to bias for tanh.
+ * @var sigmoid_zeta first scalar parameter (zeta) of the FastGRNN hidden-state update. + * @var sigmoid_nu second scalar parameter (nu) of the FastGRNN hidden-state update. + * @var block_size_w_to_lr block/tile size for the cache. Used for tiled MatMul. For W1 * x. + * @var block_size_w_from_lr block/tile size for the cache. Used for tiled MatMul. For W2 * result(W1 * x). + * @var block_size_u_to_lr block/tile size for the cache. Used for tiled MatMul. For U1 * h. + * @var block_size_u_from_lr block/tile size for the cache. Used for tiled MatMul. For U2 * result(U1 * h). + */ +typedef struct BrickedFastGRNN_LR_Params { + float* W1; + float* W2; + unsigned wRank; + float* U1; + float* U2; + unsigned uRank; + float* Bg; + float* Bh; + float sigmoid_zeta; + float sigmoid_nu; + unsigned block_size_w_to_lr; + unsigned block_size_w_from_lr; + unsigned block_size_u_to_lr; + unsigned block_size_u_from_lr; +} BrickedFastGRNN_LR_Params; + +/** Forward Bricking and application of the forward RNN for an input signal. + * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden. + * @param[in] rnn_hidden output dimension for the current cell. + * @param[in] input_signal pointer to input signal. size = in_time * in_dims. + * @param[in] in_time number of input time steps. + * @param[in] in_dims input dimensions. + * @param[in] window window length for each brick. For the final brick, the leftover time steps are used (the last brick need not be window in length). + * @param[in] hop hop distance between bricks. + * @param[in] params pointer to the parameters for the RNN. + * @param[in] bi_direction determines if the output is for a bi-directional RNN. + * @param[in] sample_first_brick determines if the 1st brick should also be sampled. + * -> if = 0, only the last hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1. + * -> if = 1, for the 1st brick, we sample at every hop index (similar to ::hop). For all the bricks (including the 1st) we sample the final hidden state. out_time = in_time/hop + 1. + */ +int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_first_brick); + +/** Backward Bricking and application of the backward RNN for an input signal. + * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden. + * @param[in] rnn_hidden output dimension for the current cell. + * @param[in] input_signal pointer to input signal. size = in_time * in_dims. + * @param[in] in_time number of input time steps. + * @param[in] in_dims input dimensions. + * @param[in] window window length for each brick. For the final brick, the leftover time steps are used (the last brick need not be window in length). + * @param[in] hop hop distance between bricks. + * @param[in] params pointer to the parameters for the RNN. + * @param[in] bi_direction determines if the output is for a bi-directional RNN. + * @param[in] sample_last_brick determines if the last brick should also be sampled. + * -> if = 0, only the first (last in reverse) hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1. + * -> if = 1, for the last brick, we sample every hop index in reverse (similar to ::hop in reverse). For all the bricks (including the last) we sample the first hidden state (last in reverse). out_time = in_time/hop + 1.
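+ * E.g. (illustrative numbers): in_time = 8, window = 4, hop = 2 gives out_time = (8 - 4)/2 + 1 = 3 when sampling is off, and out_time = 8/2 + 1 = 5 when it is on.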
+ */ +int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_last_brick); + +#endif diff --git a/c_reference/include/utils.h b/c_reference/include/utils.h index 26d5242ba..37a821a90 100644 --- a/c_reference/include/utils.h +++ b/c_reference/include/utils.h @@ -31,6 +31,80 @@ void matVec(const float* const mat, const float* const vec, float alpha, float beta, float* const ret); +/* + Matrix-vector multiplication with a row offset. + This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axes. + ret is of size nrows, vec is of size ncols. + mat is of size nrows * ncols, stored in row major. + depthwise is to change the matVec to depthwise specific convolutions. + row_stride is the offset factor between two adjacent rows. + Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped. + For a normal matVec case, this value will be ncols. + Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. + For this eg ncols will be 100 and row_stride will be 400. + vec_stride is the offset factor between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals. + For a normal matVec case, this value will be 1. + Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. + So it's possible to enter a 400 length vector and consider every 4th element. + For this ncols will be 100 and vec_stride will be 4. +*/ +void offset_matVec_conv1d(const float* mat, const float* vec, + unsigned nrows, unsigned ncols, + unsigned row_stride, unsigned vec_stride, + unsigned depthwise, float* ret); + +/* + Tiled (cache-blocked) implementation of the Matrix Multiplication. + Note: If only the MatMul output is needed, then please use calloc to initialize the output. + An alternative is to use malloc, followed by memset 0. + There is a second way to use this function: adding the result of the MatMul to a pre-existing matrix. + If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly. + This MatMul adds the result onto the pre-existing values in ret. Hence either a zero-initialized or a pre-existing mat is needed. + matA first matrix; shape = [nrows, ncommon]. + matB second matrix; shape = [ncommon, ncols]. + nrows number of rows in the first matrix. + ncommon number of columns in the first matrix/number of rows in the second matrix. + ncols number of columns in the second matrix. + total_comm_A the actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored. + total_cols_B the actual offset factor between 2 rows for matB.
Used if we need fewer columns than the actual number stored. + ret matrix multiplication output. shape = [nrows, ncols]. + block_size tile/block size for optimal cache performance. A hardware specific parameter. +*/ +void tiledMatMul_float(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_cols_B, + float* const ret, unsigned block_size); + +/* + Tiled (cache-blocked) implementation of the Matrix Multiplication, but with matB stored in the transposed format. + The result will be the same as the regular MatMul, but the matrix B must be pre-transposed (before storage or usage). + Note: If only the MatMul output is needed, then please use calloc to initialize the output. + An alternative is to use malloc, followed by memset 0. + There is a second way to use this function: adding the result of the MatMul to a pre-existing matrix. + If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly. + This MatMul adds the result onto the pre-existing values in ret. Hence either a zero-initialized or a pre-existing mat is needed. + matA first matrix; shape = [nrows, ncommon]. + matB second matrix; shape = [ncols, ncommon]. + nrows number of rows in the first matrix. + ncommon number of columns in the first matrix/number of rows in the second matrix. + ncols number of columns in the second matrix. + total_comm_A the actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored. + total_comm_B the actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + Since matB is transposed, its columns now run along the ncommon axis. + ret matrix multiplication output. shape = [nrows, ncols]. + block_size tile/block size for optimal cache performance. A hardware specific parameter. +*/ +void transposed_tiledMatMul(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_comm_B, + float* const ret, unsigned block_size); + // scaled vector addition: ret = scalar1 * vec1 + scalar2 * vector2 void v_add(float scalar1, const float* const vec1, float scalar2, const float* const vec2, @@ -54,4 +128,14 @@ unsigned argmax(const float* const vec, unsigned len); // ret[i] = exp(input[i]) / \sum_i exp(input[i]) void softmax(const float* const input, unsigned len, float* const ret); +/* Custom non-linear layer for the phoneme detection model. It can be used for other time-series problems if necessary. + output_signal pointer to the output signal, size = in_time * (in_channels / 2). + input_signal pointer to the input signal. size = in_time * in_channels. + in_time number of input time steps. + in_channels number of input channels. The output will have half the number of input channels. + in_channels must be even (in_channels % 2 == 0).
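+ The intended computation is a gated combination of the two channel halves (stated here as an assumption based on the layer's usage):
+ output[t][c] = sigmoid(input[t][c]) * tanh(input[t][c + in_channels / 2]), for 0 <= c < in_channels / 2.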
+ */ +void semi_sigmoid_tanh(float* output_signal, const float* const input_signal, + unsigned in_time, unsigned in_channels); + +#endif diff --git a/c_reference/src/Makefile b/c_reference/src/Makefile index 8fc27bd65..7f7e79941 100644 --- a/c_reference/src/Makefile +++ b/c_reference/src/Makefile @@ -6,7 +6,13 @@ include ../config.mk INCLUDE_DIR=../include IFLAGS = -I $(INCLUDE_DIR) -all: utils.o fastgrnn.o classifier.o rnnpool.o quantized_utils.o quantized_fastgrnn.o quantized_rnnpool.o quantized_mbconv.o +all: dscnn.o conv1d.o utils.o fastgrnn.o classifier.o rnnpool.o quantized_utils.o quantized_fastgrnn.o quantized_rnnpool.o quantized_mbconv.o rnn_bricked.o + +dscnn.o: dscnn.c + $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ + +conv1d.o: conv1d.c + $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ utils.o: utils.c $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ @@ -20,6 +26,9 @@ classifier.o: classifier.c rnnpool.o: rnnpool.c $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ +rnn_bricked.o: rnn_bricked.c + $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ + quantized_utils.o: quantized_utils.c $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ diff --git a/c_reference/src/conv1d.c b/c_reference/src/conv1d.c new file mode 100644 index 000000000..552abb78a --- /dev/null +++ b/c_reference/src/conv1d.c @@ -0,0 +1,610 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include <stdlib.h> +#include <string.h> +#include <math.h> +#include "conv1d.h" +#include "utils.h" + +int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation) { + + const ConvLayers_LR_Params* tparams = (ConvLayers_LR_Params*) params; + + // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding. + unsigned rank = tparams->rank; + // Buffer for W2 out. + float* temp_rank_out = (float*)malloc(rank * sizeof(float)); + // Buffer for W1 out. + float* temp_out = (float*)malloc(out_channels * sizeof(float)); + for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { + unsigned t_index = t_out * out_channels; + + if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { + // Filter fully inside the input. Kept as the initial condition, since this is the most common one. + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, kernel_size * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else if ((t_in_start < padding) && (t_in_end >= padding)) { + // Filter partially entered the input. + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix.
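+ // Advancing W2 by (padding - t_in_start) * in_channels floats skips the filter taps that fall in the left zero-padding; row_stride still walks the full rows of the stored matrix.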
+ offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, + input_signal, rank, + (t_in_end - padding + 1) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { + // Filter partially exited the input. + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, (in_time + padding - t_in_start) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else { + // Filter completely in the padding region. + // The filter is either fully outside the input or has not yet entered the input. + memset(output_signal + t_index, 0, out_channels * sizeof(float)); + } + for (unsigned co = 0; co < out_channels; co++) { + // Post-Conv activation. More activation functions can be added should the necessity arise. + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; + } + } + } + free(temp_out); + free(temp_rank_out); + return 0; +} + +int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation) { + + unsigned ncols = kernel_size * in_channels, num_iter = 0, num_steps_one_row = 0; + // Calculate the number of time steps in one row for the first non-overlapping instance. + while (num_steps_one_row < kernel_size) { + num_steps_one_row += stride; + num_iter++; + } + unsigned total_in_cols = num_steps_one_row * in_channels; + + const ConvLayers_LR_Parallel_Params* tparams = (ConvLayers_LR_Parallel_Params*) params; + // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding. + // Buffer to hold the output. For corner cases, this will be relatively big. + // But will be needed for the central condition (filter inside input). + // If there are not enough time steps to linearise into one row, then allocate only 1 time step. + unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? + in_time / num_steps_one_row : 1; + unsigned rank = tparams->rank; + // Buffer for W2 out. 
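+ // Sized to hold one rank-length intermediate per linearised row (buffer_steps rows in the worst case).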
+ float* temp_rank_out = (float*)malloc(buffer_steps * rank * sizeof(float)); + // Buffer for W1 out. + float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); + + unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here. + for (t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_in_start < padding && t_out < out_time; + t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_end < padding) { + // Filter outside the input region and in the padded region. + memset(output_signal + t_out * out_channels, 0, + out_channels * sizeof(float)); + } + else { //(t_in_end >= padding). + // Filter partially entered the input. + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. + offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, + input_signal, rank, (t_in_end - padding + 1) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + } + // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases. + // Each iteration is for the kernel striding to the next point till the filter is out of the overlapping region. + // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row. + // Using the above logic, we can convert the MatVec operation into a MatMul operation. + // Ideally both implementations would perform the same. However, for edge devices the MatMul was found to be faster than the MatVec (both tiled). + // Skip if at least 2 rows cannot be formed. The condition 2 * num_steps_one_row + stride is the worst case criteria. + // The MatVec will be used for the computation in case the following block is skipped. + if (in_time > ((num_steps_one_row << 1) + stride)) { + t_in_start -= padding; // remove the padding offset temporarily. + t_in_end -= padding; // Used to keep track of the final processed index. + for (unsigned iter = 0; (iter < num_iter) && (t_out < out_time); + iter++, t_in_start += stride, t_out++) { + unsigned in_rows = (in_time - t_in_start) / num_steps_one_row; + memset(temp_rank_out, 0, buffer_steps * rank * sizeof(float)); + memset(temp_out, 0, buffer_steps * out_channels * sizeof(float)); + if (t_in_end < (t_in_start + ((in_rows - 1) * num_steps_one_row))) { + // t_in_end is used to find the furthest time step that was used in the MatMul calculation. + // This value will be used for calculating the index for the final section of the processing. + t_in_end = ((in_rows - 1) * num_steps_one_row) + t_in_start + stride; + } + transposed_tiledMatMul(input_signal + t_in_start * in_channels, tparams->W2, + in_rows, ncols, rank, total_in_cols, ncols, + temp_rank_out, tparams->block_size_to_lr); + transposed_tiledMatMul(temp_rank_out, tparams->W1, + in_rows, rank, out_channels, rank, rank, + temp_out, tparams->block_size_from_lr); + // Copy all the data into the output.
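+ // Row r of temp_out holds the result for output time step t_out + r * num_iter, hence the scatter below with a stride of num_iter * out_channels.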
+ float* output_offset = (float*)output_signal + t_out * out_channels; + float* temp_offset = (float*)temp_out; + unsigned t_iter = in_rows, offset_factor_for_out = num_iter * out_channels; + while (t_iter--) { + memcpy(output_offset, temp_offset, out_channels * sizeof(float)); + output_offset += offset_factor_for_out; + temp_offset += out_channels; + } + } + // Initialize the time iterators. + // Use the stored value in t_in_end to calculate the iterators. + t_in_start = t_in_end + padding; // Add the padding and stride offsets again. + t_in_end = t_in_start + kernel_size - 1; + t_out = t_in_start / stride; + } + for (; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_start < (in_time + padding) && (t_in_end < (in_time + padding))) { + // Filter fully in the input but very close to the edges. + // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped. + // In case the MatMul is skipped, this block will be used to compute the results. + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, kernel_size * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { + // Filter partially exited the input. + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, (in_time + padding - t_in_start) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + else { + // Filter completely outside the input and in the padding region. + memset(output_signal + t_out * out_channels, + 0, out_channels * sizeof(float)); + } + } + // Bias and activation. + for (t_out = 0; t_out < out_time; t_out++) { + unsigned t_index = t_out * out_channels; + for (unsigned co = 0; co < out_channels; co++) { + // Post-Conv activation. More activation functions can be added should the necessity arise.
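+ // activation: 0 = none (bias only), 1 = sigmoid, 2 = tanh, 3 = relu; the bias B[co] is added before the non-linearity in every case.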
+ switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; + } + } + } + free(temp_out); + free(temp_rank_out); + return 0; +} + +int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation) { + + const ConvLayers_Params* tparams = (ConvLayers_Params*) params; + unsigned vec_stride = 1, cols_scale = in_channels; + if (tparams->depthwise) { + vec_stride = in_channels; + out_channels = in_channels; + cols_scale = 1; + } + + // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding. + float* temp_out = (float*)malloc(out_channels * sizeof(float)); + for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { + unsigned t_index = t_out * out_channels; + + if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { + // Filter fully inside the input. Kept as the initial condition, since this is the most common one. + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, kernel_size * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else if ((t_in_start < padding) && (t_in_end >= padding)) { + // Filter partially entered the input. + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. + offset_matVec_conv1d(tparams->W + (padding - t_in_start) * cols_scale, + input_signal, out_channels, (t_in_end - padding + 1) * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { + // Filter partially exited the input. + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, (in_time + padding - t_in_start) * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else { + // Filter completely in the padding region. + // The filter is either fully outside the input or has not yet entered the input. + memset(output_signal + t_index, 0, out_channels * sizeof(float)); + } + for (unsigned co = 0; co < out_channels; co++) { + // Post-Conv activation.
More activation functions can be added should the necessity arise. + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; + } + } + } + free(temp_out); + return 0; +} + +int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation) { + + unsigned ncols = kernel_size * in_channels, num_iter = 0, num_steps_one_row = 0; + // Calculate the number of time steps in one row for the first non-overlapping instance. + while (num_steps_one_row < kernel_size) { + num_steps_one_row += stride; + num_iter++; + } + unsigned total_in_cols = num_steps_one_row * in_channels; + + const ConvLayers_Parallel_Params* tparams = (ConvLayers_Parallel_Params*) params; + // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding. + // Buffer to hold the output. For corner cases, this will be relatively big. + // But will be needed for the central condition (filter inside input). + // If there are not enough time steps to linearise into one row, then allocate only 1 time step. + unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? + in_time / num_steps_one_row : 1; + float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); + unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here. + for (t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_in_start < padding && t_out < out_time; + t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_end < padding) { + // Filter outside the input region and in the padded region. + memset(output_signal + t_out * out_channels, + 0, out_channels * sizeof(float)); + } + else { //(t_in_end >= padding). + // Filter partially entered the input. + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. + offset_matVec_conv1d(tparams->W + (padding - t_in_start) * in_channels, + input_signal, out_channels, (t_in_end - padding + 1) * in_channels, + ncols, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + } + // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases. + // Each iteration is for the kernel striding to the next point till the filter is out of the overlapping region. + // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row. + // Using the above logic, we can convert the MatVec operation into a MatMul operation. + // Ideally both implementations would perform the same. However, for edge devices the MatMul was found to be faster than the MatVec (both tiled). + // Skip if at least 2 rows cannot be formed. The condition 2 * num_steps_one_row + stride is the worst case criteria.
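+ // (Illustrative: kernel_size = 3 with stride = 1 gives num_steps_one_row = 3, so this MatMul path is taken only when in_time > 7.)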
+  // The MatVec will be used for the computation in case the following block is skipped.
+  if (in_time > ((num_steps_one_row << 1) + stride)) {
+    t_in_start -= padding; // Remove the padding offset temporarily.
+    t_in_end -= padding; // Used to keep track of the final processed index.
+    for (unsigned iter = 0; (iter < num_iter) && (t_out < out_time);
+         iter++, t_in_start += stride, t_out++) {
+      unsigned in_rows = (in_time - t_in_start) / num_steps_one_row;
+      memset(temp_out, 0, buffer_steps * out_channels * sizeof(float));
+      if (t_in_end < (t_in_start + ((in_rows - 1) * num_steps_one_row))) {
+        // t_in_end is used to find the furthest time step used in the MatMul calculation.
+        // This value will be used for calculating the index for the final section of the processing.
+        t_in_end = ((in_rows - 1) * num_steps_one_row) + t_in_start + stride;
+      }
+      transposed_tiledMatMul(input_signal + t_in_start * in_channels, tparams->W,
+        in_rows, ncols, out_channels, total_in_cols, ncols,
+        temp_out, tparams->block_size);
+      // Copy all the data into the output.
+      float* output_offset = (float*)output_signal + t_out * out_channels;
+      float* temp_offset = (float*)temp_out;
+      unsigned t_iter = in_rows, offset_factor_for_out = num_iter * out_channels;
+      while (t_iter--) {
+        memcpy(output_offset, temp_offset, out_channels * sizeof(float));
+        output_offset += offset_factor_for_out;
+        temp_offset += out_channels;
+      }
+    }
+    // Re-initialize the time iterators.
+    // Use the value stored in t_in_end to calculate them.
+    t_in_start = t_in_end + padding; // Add the padding offset back.
+    t_in_end = t_in_start + kernel_size - 1;
+    t_out = t_in_start / stride;
+  }
+  for (; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) {
+    if (t_in_start < (in_time + padding) && (t_in_end < (in_time + padding))) {
+      // Filter fully in the input but very close to the edges.
+      // Due to the num_steps_one_row divisibility used in the parallel step, some computations would be skipped.
+      // In case the MatMul is skipped, this block will be used to compute the results.
+      offset_matVec_conv1d(tparams->W,
+        input_signal + (t_in_start - padding) * in_channels,
+        out_channels, kernel_size * in_channels,
+        kernel_size * in_channels, 1, 0, temp_out);
+      memcpy(output_signal + t_out * out_channels,
+             temp_out, out_channels * sizeof(float));
+    }
+    else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) {
+      // Filter partially exited the input.
+      // As a part of the filter is outside the input, we need fewer than "kernel_size" time steps.
+      // We will only be using a part of the weight matrix (assuming shape = out_channels, kernel_size * in_channels).
+      // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix.
+      offset_matVec_conv1d(tparams->W,
+        input_signal + (t_in_start - padding) * in_channels,
+        out_channels, (in_time + padding - t_in_start) * in_channels,
+        ncols, 1, 0, temp_out);
+      memcpy(output_signal + t_out * out_channels,
+             temp_out, out_channels * sizeof(float));
+    }
+    else {
+      // Filter completely outside the input and in the padding region.
+      memset(output_signal + t_out * out_channels,
+             0, out_channels * sizeof(float));
+    }
+  }
+  // Bias and activation.
+  for (t_out = 0; t_out < out_time; t_out++) {
+    unsigned t_index = t_out * out_channels;
+    for (unsigned co = 0; co < out_channels; co++) {
+      // Post-Conv activation. More activation functions can be added should the necessity arise.
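+      // For reference, each branch below computes output = act(conv_output + B[co]),
+      // with act selected by `activation`: 0 = identity, 1 = sigmoid, 2 = tanh, 3 = relu.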
+ switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; + } + } + } + free(temp_out); + return 0; +} + +int avgpool1d(float* output_signal, unsigned out_time, const float* input_signal, + unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, unsigned stride, unsigned activation) { + + // Iterate over the time steps and average them. + float scale = 1.0/(float)kernel_size; // To avoid divisions. + for (unsigned t_in = 0, t_out = 0; t_out < out_time; t_out++, t_in += stride) { + for (unsigned ci = 0; ci < in_channels; ci++) { + float sum = 0; + for (unsigned tf = 0; tf < kernel_size; tf++) { + if (((t_in + tf) < padding) || ((t_in + tf) >= (in_time + padding))) { + continue; + } + else { + sum += (input_signal[((tf + t_in) - padding) * in_channels + ci]); + } + } + switch (activation) { + case 1 : + output_signal[t_out * in_channels + ci] = sigmoid(sum * scale); + break; + + case 2 : + output_signal[t_out * in_channels + ci] = tanh(sum * scale); + break; + + case 3 : + output_signal[t_out * in_channels + ci] = relu(sum * scale); + break; + + default : + output_signal[t_out * in_channels + ci] = sum * scale; + break; + } + } + } + return 0; +} + +int batchnorm1d(float* output_signal, float* input_signal, + unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma , const float* const beta, + unsigned in_place, float eps) { + float* ret = in_place ? (float*)input_signal : (float*)output_signal; + + // Check for affine_config. + // = 1 ; Use gamma, beta, mean and var. + // = 2 ; Use only gamma and beta. + // = 3 ; Use only mean and var. 
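+  // Per element, config 1 computes the full affine form:
+  //   y = gamma * (x - mean) / sqrt(var + eps) + beta.
+  // Config 2 reduces to y = gamma * x + beta; any other value uses only the
+  // normalization y = (x - mean) / sqrt(var + eps).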
+  if (affine_config == 1) {
+    while (in_time--) {
+      float* gamma_offset = (float*)gamma;
+      float* beta_offset = (float*)beta;
+      float* mean_offset = (float*)mean;
+      float* var_offset = (float*)var;
+      unsigned channels = in_channels;
+
+      #ifdef LOOP_UNROLL
+        unsigned len_unroll = channels >> 2;
+        channels %= 4;
+        while (len_unroll--) {
+          *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) /
+                    sqrt((*var_offset++) + eps)) + (*beta_offset++);
+          *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) /
+                    sqrt((*var_offset++) + eps)) + (*beta_offset++);
+          *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) /
+                    sqrt((*var_offset++) + eps)) + (*beta_offset++);
+          *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) /
+                    sqrt((*var_offset++) + eps)) + (*beta_offset++);
+        }
+      #endif
+
+      while (channels--) {
+        *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) /
+                  sqrt((*var_offset++) + eps)) + (*beta_offset++);
+      }
+    }
+  }
+  else if (affine_config == 2) {
+    while (in_time--) {
+      float* gamma_offset = (float*)gamma;
+      float* beta_offset = (float*)beta;
+      unsigned channels = in_channels;
+
+      #ifdef LOOP_UNROLL
+        unsigned len_unroll = channels >> 2;
+        channels %= 4;
+        while (len_unroll--) {
+          *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++);
+          *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++);
+          *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++);
+          *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++);
+        }
+      #endif
+
+      while (channels--) {
+        *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++);
+      }
+    }
+  }
+  else {
+    while (in_time--) {
+      float* mean_offset = (float*)mean;
+      float* var_offset = (float*)var;
+      unsigned channels = in_channels;
+
+      #ifdef LOOP_UNROLL
+        unsigned len_unroll = channels >> 2;
+        channels %= 4;
+        while (len_unroll--) {
+          *ret++ = ((*input_signal++) - (*mean_offset++)) /
+                   sqrt((*var_offset++) + eps);
+          *ret++ = ((*input_signal++) - (*mean_offset++)) /
+                   sqrt((*var_offset++) + eps);
+          *ret++ = ((*input_signal++) - (*mean_offset++)) /
+                   sqrt((*var_offset++) + eps);
+          *ret++ = ((*input_signal++) - (*mean_offset++)) /
+                   sqrt((*var_offset++) + eps);
+        }
+      #endif
+
+      while (channels--) {
+        *ret++ = ((*input_signal++) - (*mean_offset++)) /
+                 sqrt((*var_offset++) + eps);
+      }
+    }
+  }
+  return 0;
+}
diff --git a/c_reference/src/dscnn.c b/c_reference/src/dscnn.c
new file mode 100644
index 000000000..ef245837a
--- /dev/null
+++ b/c_reference/src/dscnn.c
@@ -0,0 +1,111 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "dscnn.h"
+#include "conv1d.h"
+#include "utils.h"
+
+int phon_pred_lr_cnn(float* output_signal, float* input_signal,
+  conv_layer cnn, unsigned in_time, unsigned in_channels,
+  const float* const mean, const float* const var,
+  unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place,
+  unsigned cnn_hidden, unsigned cnn_padding, unsigned cnn_kernel_size,
+  const void* cnn_params, unsigned cnn_stride, unsigned cnn_activation) {
+
+  unsigned out_time = in_time - cnn_kernel_size + 2 * cnn_padding + 1;
+  if (in_place) {
+    // BatchNorm.
+    batchnorm1d(0, input_signal,
+      in_time, in_channels,
+      mean, var, affine_config, gamma, beta,
+      in_place, 0.00001);
+    // CNN.
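+    // `cnn` is a conv_layer function pointer (the KWS test, for instance, passes
+    // conv1d_lr_parallel), so the same wrapper can drive different conv1d variants.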
+ cnn(output_signal, out_time, cnn_hidden, input_signal, + in_time, in_channels, cnn_padding, cnn_kernel_size, + cnn_params, cnn_stride, cnn_activation); + } + else { + // BatchNorm. + float* norm_out = (float*)malloc(in_time * in_channels * sizeof(float)); + batchnorm1d(norm_out, input_signal, + in_time, in_channels, + mean, var, affine_config, gamma, beta, + in_place, 0.00001); + // CNN. + cnn(output_signal, out_time, cnn_hidden, norm_out, + in_time, in_channels, cnn_padding, cnn_kernel_size, + cnn_params, cnn_stride, cnn_activation); + free(norm_out); + } + return 0; +} + +int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, + conv_layer point_cnn, unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, + unsigned depth_cnn_padding, unsigned depth_cnn_kernel_size, + const void* depth_cnn_params, unsigned depth_cnn_stride, unsigned depth_cnn_activation, + unsigned point_cnn_hidden, unsigned point_cnn_padding, unsigned point_cnn_kernel_size, + const void* point_cnn_params, unsigned point_cnn_stride, unsigned point_cnn_activation, + unsigned pool_padding, unsigned pool_kernel_size, unsigned pool_stride, unsigned pool_activation) { + + // Activation. + float* act_out= (float*)malloc(in_time * (in_channels >> 1) * sizeof(float)); + semi_sigmoid_tanh(act_out, input_signal, in_time, in_channels); + + in_channels >>= 1; + float* depth_out; + unsigned out_time = in_time - depth_cnn_kernel_size + 2 * depth_cnn_padding + 1; + if (in_place) { + // Norm. + batchnorm1d(0, act_out, + in_time, in_channels, + mean, var, + affine_config, gamma, beta, + in_place, 0.00001); + // Depth CNN. + depth_out = (float*)malloc(out_time * in_channels * sizeof(float)); + conv1d(depth_out, out_time, 0, act_out, + in_time, in_channels, depth_cnn_padding, depth_cnn_kernel_size, + depth_cnn_params, depth_cnn_stride, depth_cnn_activation); + free(act_out); + } + else { + // Norm. + float* norm_out = (float*)malloc(in_time * in_channels * sizeof(float)); + batchnorm1d(norm_out, act_out, + in_time, in_channels, + mean, var, + affine_config, gamma, beta, + in_place, 0.00001); + free(act_out); + // Depth CNN. + depth_out = (float*)malloc(out_time * in_channels * sizeof(float)); + conv1d(depth_out, out_time, 0, norm_out, + in_time, in_channels, depth_cnn_padding, depth_cnn_kernel_size, + depth_cnn_params, depth_cnn_stride, depth_cnn_activation); + free(norm_out); + } + + // Point CNN. + in_time = out_time; + out_time = in_time - point_cnn_kernel_size + 2 * point_cnn_padding + 1; + float* point_out = (float*)malloc(out_time * point_cnn_hidden * sizeof(float)); + point_cnn(point_out, out_time, point_cnn_hidden, depth_out, + in_time, in_channels, point_cnn_padding, point_cnn_kernel_size, + point_cnn_params, point_cnn_stride, point_cnn_activation); + free(depth_out); + + // Pool. + in_time = out_time; + out_time = in_time - pool_kernel_size + 2 * pool_padding + 1; + avgpool1d(output_signal, out_time, point_out, + in_time, point_cnn_hidden, + pool_padding, pool_kernel_size, pool_stride, pool_activation); + free(point_out); + return 0; +} diff --git a/c_reference/src/rnn_bricked.c b/c_reference/src/rnn_bricked.c new file mode 100644 index 000000000..c09f7cf90 --- /dev/null +++ b/c_reference/src/rnn_bricked.c @@ -0,0 +1,303 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
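+
+// A brick is a window of `window` time steps starting every `hop` steps; all bricks are
+// advanced in lock-step, one time step per iteration. As a worked example (sizes are
+// hypothetical): in_time = 30, window = 9, hop = 3 gives num_bricks = (30 - 9) / 3 + 1 = 8,
+// with brick n reading input time steps [n * hop, n * hop + window).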
+
+#include <stdlib.h>
+#include <string.h>
+#include <math.h>
+#include "rnn_bricked.h"
+#include "utils.h"
+
+// Forward Pass.
+int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden,
+  float* input_signal, unsigned in_time, unsigned in_dims,
+  unsigned window, unsigned hop, const void* params,
+  unsigned bi_direction, unsigned sample_first_brick) {
+
+  // Buffers and params.
+  const BrickedFastGRNN_LR_Params* tparams = (const BrickedFastGRNN_LR_Params*)params;
+
+  unsigned rnn_assign_offset = rnn_hidden, out_index = 0;
+  unsigned num_bricks = (in_time - window) / hop + 1;
+  // If bi-directional is True (non-zero), the actual output hidden state (allocated space) is twice rnn_hidden.
+  // This function only processes the forward context.
+  if (bi_direction) {
+    rnn_assign_offset <<= 1;
+  }
+
+  // Compute W1 * W2 * X.
+  float* inputMulW = (float*)calloc(in_time * rnn_hidden, sizeof(float));
+  float* tempLR = (float*)calloc(in_time * tparams->wRank, sizeof(float));
+  float* hiddenState = (float*)calloc(num_bricks * rnn_hidden, sizeof(float));
+  float* preComp = (float*)calloc(num_bricks * rnn_hidden, sizeof(float));
+  transposed_tiledMatMul(input_signal, tparams->W1, in_time, in_dims,
+    tparams->wRank, in_dims, in_dims,
+    tempLR, tparams->block_size_w_to_lr);
+  transposed_tiledMatMul(tempLR, tparams->W2, in_time, tparams->wRank,
+    rnn_hidden, tparams->wRank, tparams->wRank,
+    inputMulW, tparams->block_size_w_from_lr);
+  free(tempLR);
+  // We can reuse the low-rank buffer from Wx for Uh, since Wx is computed in one stretch.
+  // memset is used inside the loop, so malloc suffices here for the MatMul result buffer.
+  tempLR = (float*)malloc(num_bricks * tparams->uRank * sizeof(float));
+  for (unsigned t = 0; t < window; t++) {
+    // From higher dims to lower dims.
+    memset(tempLR, 0, num_bricks * tparams->uRank * sizeof(float));
+    transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden,
+      tparams->uRank, rnn_hidden, rnn_hidden,
+      tempLR, tparams->block_size_u_to_lr);
+    // From lower dims to higher dims.
+    // Add Wx to Uh.
+    // The tiled MatMuls are coded such that they yield result += matA * matB.
+    // Hence we use calloc and memset to initialize the result to 0.
+    // But since we want Wx + Uh, we can store Wx first and let the MatMul add Uh on top of it.
+    float* preComp_offset = (float*)preComp;
+    for (unsigned n = 0; n < num_bricks; n++) {
+      float* inputMulW_offset = (float*)inputMulW + (n * hop + t) * rnn_hidden;
+      unsigned hidden = rnn_hidden;
+
+      #ifdef LOOP_UNROLL
+        unsigned len_unroll = hidden >> 2;
+        hidden %= 4;
+        while (len_unroll--) {
+          *preComp_offset++ = *inputMulW_offset++;
+          *preComp_offset++ = *inputMulW_offset++;
+          *preComp_offset++ = *inputMulW_offset++;
+          *preComp_offset++ = *inputMulW_offset++;
+        }
+      #endif
+
+      while (hidden--) {
+        *preComp_offset++ = *inputMulW_offset++;
+      }
+    }
+    transposed_tiledMatMul(tempLR, tparams->U2, num_bricks, tparams->uRank,
+      rnn_hidden, tparams->uRank, tparams->uRank,
+      preComp, tparams->block_size_u_from_lr);
+
+    // Apply the gating.
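+    // Per hidden unit, the FastGRNN update applied below is:
+    //   gate   = sigmoid(precomp + Bg)
+    //   update = tanh(precomp + Bh)
+    //   h      = gate * h + (sigmoid(zeta) * (1 - gate) + sigmoid(nu)) * update
+    // where precomp = Wx + Uh, and sigmoid(zeta), sigmoid(nu) are precomputed in the params.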
+    float* hiddenState_offset = (float*)hiddenState;
+    preComp_offset = (float*)preComp;
+    unsigned bricks = num_bricks;
+    while (bricks--) {
+      float* gateBias = (float*)tparams->Bg;
+      float* hiddenBias = (float*)tparams->Bh;
+      unsigned hidden = rnn_hidden;
+
+      #ifdef LOOP_UNROLL
+        unsigned len_unroll = hidden >> 2;
+        hidden = rnn_hidden % 4;
+        float gate, update;
+        while (len_unroll--) {
+          gate = sigmoid((*preComp_offset) + (*gateBias++));
+          update = tanh((*preComp_offset++) + (*hiddenBias++));
+          *hiddenState_offset = gate * (*hiddenState_offset) +
+                                (tparams->sigmoid_zeta * (1.0 - gate) +
+                                tparams->sigmoid_nu) * update;
+          hiddenState_offset++;
+          gate = sigmoid((*preComp_offset) + (*gateBias++));
+          update = tanh((*preComp_offset++) + (*hiddenBias++));
+          *hiddenState_offset = gate * (*hiddenState_offset) +
+                                (tparams->sigmoid_zeta * (1.0 - gate) +
+                                tparams->sigmoid_nu) * update;
+          hiddenState_offset++;
+          gate = sigmoid((*preComp_offset) + (*gateBias++));
+          update = tanh((*preComp_offset++) + (*hiddenBias++));
+          *hiddenState_offset = gate * (*hiddenState_offset) +
+                                (tparams->sigmoid_zeta * (1.0 - gate) +
+                                tparams->sigmoid_nu) * update;
+          hiddenState_offset++;
+          gate = sigmoid((*preComp_offset) + (*gateBias++));
+          update = tanh((*preComp_offset++) + (*hiddenBias++));
+          *hiddenState_offset = gate * (*hiddenState_offset) +
+                                (tparams->sigmoid_zeta * (1.0 - gate) +
+                                tparams->sigmoid_nu) * update;
+          hiddenState_offset++;
+        }
+      #endif
+
+      while (hidden--) {
+        float gate = sigmoid((*preComp_offset) + (*gateBias++));
+        float update = tanh((*preComp_offset++) + (*hiddenBias++));
+        *hiddenState_offset = gate * (*hiddenState_offset) +
+                              (tparams->sigmoid_zeta * (1.0 - gate) +
+                              tparams->sigmoid_nu) * update;
+        hiddenState_offset++;
+      }
+    }
+    // Sample the first brick if necessary.
+    if (sample_first_brick) {
+      if (t % hop == 0) {
+        memcpy(output_signal + (out_index++) * rnn_assign_offset,
+               hiddenState, rnn_hidden * sizeof(float));
+      }
+    }
+  }
+  if (bi_direction) {
+    // If bi-directional, a gap needs to be left for the backward outputs.
+    float* hiddenState_offset = hiddenState;
+    for (unsigned n = 0; n < num_bricks; n++) {
+      memcpy(output_signal + (out_index++) * rnn_assign_offset,
+             hiddenState_offset, rnn_hidden * sizeof(float));
+      hiddenState_offset += rnn_hidden;
+    }
+  }
+  else {
+    // If only the forward direction is needed, then the whole block of memory can be copied without the loop.
+    memcpy(output_signal + out_index * rnn_assign_offset,
+           hiddenState, num_bricks * rnn_hidden * sizeof(float));
+  }
+  free(hiddenState);
+  free(inputMulW);
+  free(preComp);
+  free(tempLR);
+  return 0;
+}
+
+// Backward Pass.
+int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden,
+  float* input_signal, unsigned in_time, unsigned in_dims,
+  unsigned window, unsigned hop, const void* params,
+  unsigned bi_direction, unsigned sample_last_brick) {
+
+  // Buffers and params.
+  const BrickedFastGRNN_LR_Params* tparams = (const BrickedFastGRNN_LR_Params*)params;
+
+  unsigned rnn_assign_offset = rnn_hidden;
+  unsigned num_bricks = (in_time - window) / hop + 1;
+  unsigned out_index = in_time / hop;
+  // If bi-directional is True (non-zero), the actual output hidden state (allocated space) is twice rnn_hidden.
+  // This function only processes the backward context.
+  if (bi_direction) {
+    rnn_assign_offset <<= 1;
+  }
+
+  // Compute W1 * W2 * X.
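+  // As in the forward pass, the low-rank factorization computes (X * W1) * W2 in two tiled
+  // MatMuls, costing in_time * wRank * (in_dims + rnn_hidden) multiply-adds instead of the
+  // in_time * in_dims * rnn_hidden needed for a full-rank W; a win whenever wRank is small
+  // compared to in_dims and rnn_hidden.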
+  float* inputMulW = (float*)calloc(in_time * rnn_hidden, sizeof(float));
+  float* tempLR = (float*)calloc(in_time * tparams->wRank, sizeof(float));
+  float* hiddenState = (float*)calloc(num_bricks * rnn_hidden, sizeof(float));
+  float* preComp = (float*)calloc(num_bricks * rnn_hidden, sizeof(float));
+  transposed_tiledMatMul(input_signal, tparams->W1, in_time, in_dims,
+    tparams->wRank, in_dims, in_dims,
+    tempLR, tparams->block_size_w_to_lr);
+  transposed_tiledMatMul(tempLR, tparams->W2, in_time, tparams->wRank,
+    rnn_hidden, tparams->wRank, tparams->wRank,
+    inputMulW, tparams->block_size_w_from_lr);
+  free(tempLR);
+  // We can reuse the low-rank buffer from Wx for Uh, since Wx is computed in one stretch.
+  tempLR = (float*)calloc(num_bricks * tparams->uRank, sizeof(float));
+  for (int t = window - 1; t >= 0; t--) {
+    // From higher dims to lower dims.
+    memset(tempLR, 0, num_bricks * tparams->uRank * sizeof(float));
+    transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden,
+      tparams->uRank, rnn_hidden, rnn_hidden,
+      tempLR, tparams->block_size_u_to_lr);
+    // From lower dims to higher dims.
+    // Add Wx to Uh.
+    // The tiled MatMuls are coded such that they yield result += matA * matB.
+    // Hence we use calloc and memset to initialize the result to 0.
+    // But since we want Wx + Uh, we can store Wx first and let the MatMul add Uh on top of it.
+    float* preComp_offset = (float*)preComp;
+    for (unsigned n = 0; n < num_bricks; n++) {
+      float* inputMulW_offset = (float*)inputMulW + (n * hop + t) * rnn_hidden;
+      unsigned hidden = rnn_hidden;
+
+      #ifdef LOOP_UNROLL
+        unsigned len_unroll = hidden >> 2;
+        hidden %= 4;
+        while (len_unroll--) {
+          *preComp_offset++ = *inputMulW_offset++;
+          *preComp_offset++ = *inputMulW_offset++;
+          *preComp_offset++ = *inputMulW_offset++;
+          *preComp_offset++ = *inputMulW_offset++;
+        }
+      #endif
+
+      while (hidden--) {
+        *preComp_offset++ = *inputMulW_offset++;
+      }
+    }
+    transposed_tiledMatMul(tempLR, tparams->U2, num_bricks, tparams->uRank,
+      rnn_hidden, tparams->uRank, tparams->uRank,
+      preComp, tparams->block_size_u_from_lr);
+
+    // Apply the gating.
+    float* hiddenState_offset = (float*)hiddenState;
+    preComp_offset = (float*)preComp;
+    unsigned bricks = num_bricks;
+    while (bricks--) {
+      float* gateBias = (float*)tparams->Bg;
+      float* hiddenBias = (float*)tparams->Bh;
+      unsigned hidden = rnn_hidden;
+
+      #ifdef LOOP_UNROLL
+        unsigned len_unroll = hidden >> 2;
+        hidden = rnn_hidden % 4;
+        float gate, update;
+        while (len_unroll--) {
+          gate = sigmoid((*preComp_offset) + (*gateBias++));
+          update = tanh((*preComp_offset++) + (*hiddenBias++));
+          *hiddenState_offset = gate * (*hiddenState_offset) +
+                                (tparams->sigmoid_zeta * (1.0 - gate) +
+                                tparams->sigmoid_nu) * update;
+          hiddenState_offset++;
+          gate = sigmoid((*preComp_offset) + (*gateBias++));
+          update = tanh((*preComp_offset++) + (*hiddenBias++));
+          *hiddenState_offset = gate * (*hiddenState_offset) +
+                                (tparams->sigmoid_zeta * (1.0 - gate) +
+                                tparams->sigmoid_nu) * update;
+          hiddenState_offset++;
+          gate = sigmoid((*preComp_offset) + (*gateBias++));
+          update = tanh((*preComp_offset++) + (*hiddenBias++));
+          *hiddenState_offset = gate * (*hiddenState_offset) +
+                                (tparams->sigmoid_zeta * (1.0 - gate) +
+                                tparams->sigmoid_nu) * update;
+          hiddenState_offset++;
+          gate = sigmoid((*preComp_offset) + (*gateBias++));
+          update = tanh((*preComp_offset++) + (*hiddenBias++));
+          *hiddenState_offset = gate * (*hiddenState_offset) +
+                                (tparams->sigmoid_zeta * (1.0 - gate) +
+                                tparams->sigmoid_nu) * update;
+          hiddenState_offset++;
+        }
+      #endif
+
+      while (hidden--) {
+        float gate = sigmoid((*preComp_offset) + (*gateBias++));
+        float update = tanh((*preComp_offset++) + (*hiddenBias++));
+        *hiddenState_offset = gate * (*hiddenState_offset) +
+                              (tparams->sigmoid_zeta * (1.0 - gate) +
+                              tparams->sigmoid_nu) * update;
+        hiddenState_offset++;
+      }
+    }
+    // Sample the last brick if necessary.
+    if (sample_last_brick) {
+      if ((window - 1 - t) % hop == 0) {
+        // Iterate over the output in reverse.
+        memcpy(output_signal + (out_index--) * rnn_assign_offset,
+               hiddenState + (num_bricks - 1) * rnn_hidden, rnn_hidden * sizeof(float));
+      }
+    }
+  }
+  // Since all the first (final in reverse) hidden states are now calculated, we assign the whole block.
+  out_index = 0;
+  if (bi_direction) {
+    // If bi-directional, a gap needs to be left for the forward outputs.
+    float* hiddenState_offset = hiddenState;
+    for (unsigned n = 0; n < num_bricks; n++) {
+      memcpy(output_signal + (out_index++) * rnn_assign_offset,
+             hiddenState_offset, rnn_hidden * sizeof(float));
+      hiddenState_offset += rnn_hidden;
+    }
+  }
+  else {
+    // If only the backward direction is needed, then the whole block of memory can be copied without the loop.
+    memcpy(output_signal + out_index * rnn_assign_offset,
+           hiddenState, num_bricks * rnn_hidden * sizeof(float));
+  }
+  free(hiddenState);
+  free(inputMulW);
+  free(preComp);
+  free(tempLR);
+  return 0;
+}
diff --git a/c_reference/src/utils.c b/c_reference/src/utils.c
index dc58a5fa0..259b1d85a 100644
--- a/c_reference/src/utils.c
+++ b/c_reference/src/utils.c
@@ -71,6 +71,128 @@ void matVec(const float* const mat, const float* const vec,
   }
 }
 
+void offset_matVec_conv1d(const float* mat, const float* vec,
+  unsigned nrows, unsigned ncols,
+  unsigned row_stride, unsigned vec_stride,
+  unsigned depthwise, float* ret) {
+
+  while (nrows--) {
+    // For depthwise, the vec (input) pointer is updated,
+    // since each row of the mat corresponds to a separate channel index.
+    float* vec_offset = depthwise ?
(float*)vec++ : (float*)vec; + float* mat_offset = (float*)mat; + float sum = 0.0f; + unsigned cols = ncols; + + #ifdef LOOP_UNROLL + unsigned len_unroll = cols >> 2; + cols %= 4; // ncols % 4. + while (len_unroll--) { + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + } + #endif + + while (cols--) { + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + } + *ret++ = sum; + mat += row_stride; + } +} + +void tiledMatMul_float(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_cols_B, + float* const ret, unsigned block_size) { + for (unsigned row = 0; row < nrows; row += block_size) { + unsigned row_block_size = (row + block_size < nrows) ? block_size : nrows - row; + for (unsigned col = 0; col < ncols; col += block_size) { + unsigned col_block_size = (col + block_size < ncols) ? block_size : ncols - col; + for (unsigned comm = 0; comm < ncommon; comm += block_size) { + unsigned comm_block_size = (comm + block_size < ncommon) ? block_size : ncommon - comm; + for (unsigned block_row = row; block_row < row + row_block_size; block_row++) { + float *ret_offset = (float *)ret + block_row * ncols + col; + for (unsigned block_col = col; block_col < col + col_block_size; block_col++) { + float sum = 0; + unsigned temp_block_size = comm_block_size; + const float *matA_offset = (const float*)matA + block_row * total_comm_A + comm; + const float *matB_offset = (const float*)matB + comm * total_cols_B + block_col; + + #ifdef LOOP_UNROLL + unsigned len_unroll = temp_block_size >> 2; + temp_block_size %= 4; // comm_block_size % 4. + while (len_unroll--) { + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + } + #endif + + while (temp_block_size--) { + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + } + *ret_offset++ += sum; + } + } + } + } + } +} + +void transposed_tiledMatMul(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_comm_B, + float* const ret, unsigned block_size) { + for (unsigned row = 0; row < nrows; row += block_size) { + unsigned row_block_size = (row + block_size < nrows) ? block_size : nrows - row; + for (unsigned col = 0; col < ncols; col += block_size) { + unsigned col_block_size = (col + block_size < ncols) ? block_size : ncols - col; + for (unsigned comm = 0; comm < ncommon; comm += block_size) { + unsigned comm_block_size = (comm + block_size < ncommon) ? 
block_size : ncommon - comm; + for (unsigned block_row = row; block_row < row + row_block_size; block_row++) { + float *ret_offset = (float *)ret + block_row * ncols + col; + for (unsigned block_col = col; block_col < col + col_block_size; block_col++) { + float sum = 0; + unsigned temp_block_size = comm_block_size; + const float *matA_offset = (const float*)matA + block_row * total_comm_A + comm; + const float *matB_offset = (const float*)matB + block_col * total_comm_B + comm; + + #ifdef LOOP_UNROLL + unsigned len_unroll = temp_block_size >> 2; + temp_block_size %= 4; // comm_block_size % 4. + while (len_unroll--) { + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + } + #endif + + while (temp_block_size--) { + sum += (*matA_offset++) * (*matB_offset++); + } + *ret_offset++ += sum; + } + } + } + } + } +} + void v_add(float scalar1, const float* const vec1, float scalar2, const float* const vec2, unsigned len, float* const ret) { @@ -120,3 +242,34 @@ void softmax(const float* const input, unsigned len, float* const ret) { for (unsigned i = 0; i < len; i++) ret[i] = expf(input[i] - offset); } + +void semi_sigmoid_tanh(float* output_signal, const float* const input_signal, + unsigned in_time, unsigned in_channels) { + unsigned time_step = 0; // Used to avoid index multiplication. + while (in_time--) { + unsigned pivot = in_channels >> 1; + float* input_sigmoid_offset = (float*)input_signal + time_step; + float* input_tanh_offset = (float*)input_signal + time_step + pivot; + + #ifdef LOOP_UNROLL + unsigned len_unroll = pivot >> 2; + pivot %= 4; + while (len_unroll--) { + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + } + #endif + + while (pivot--) { + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + } + time_step += in_channels; + } +} diff --git a/c_reference/tests/Makefile b/c_reference/tests/Makefile index 08f418286..4eb8c7d70 100644 --- a/c_reference/tests/Makefile +++ b/c_reference/tests/Makefile @@ -8,7 +8,11 @@ MODEL_DIR=../models SRC_DIR=../src IFLAGS = -I $(INCLUDE_DIR) -I $(MODEL_DIR) -all: test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_quantized_face_detection_sparse +all: test_fastgrnn_lr test_conv1d test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_quantized_face_detection_sparse test_rnn_bricked test_phoneme_det_cnn_rnn + +CONV1D_DIR=conv1d +test_conv1d: $(CONV1D_DIR)/test_conv1d.c $(SRC_DIR)/conv1d.o $(SRC_DIR)/utils.o + $(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm FASTGRNN_DIR=fastgrnn test_fastgrnn_lr: $(FASTGRNN_DIR)/test_fastgrnn_lr.c $(SRC_DIR)/utils.o $(SRC_DIR)/fastgrnn.o $(SRC_DIR)/classifier.o @@ -38,10 +42,18 @@ test_quantized_face_detection_fast: $(FACE_DETECTION_DIR)/test_quantized_face_de test_quantized_face_detection_sparse: $(FACE_DETECTION_DIR)/test_quantized_face_detection_sparse.c $(SRC_DIR)/quantized_utils.o $(SRC_DIR)/quantized_fastgrnn.o 
$(SRC_DIR)/quantized_rnnpool.o $(SRC_DIR)/quantized_mbconv.o $(MODEL_DIR)/quantized_face_detection_sparse.o
	$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -Wno-unused-result -lm
 
+RNNBRICKED_DIR=rnn_bricked
+test_rnn_bricked: $(RNNBRICKED_DIR)/test_rnn_bricked.c $(SRC_DIR)/utils.o $(SRC_DIR)/rnn_bricked.o
+	$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm
+
+KWS_DIR=kws
+test_phoneme_det_cnn_rnn: $(KWS_DIR)/test_phoneme_det_cnn_rnn.c $(SRC_DIR)/utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o $(SRC_DIR)/rnn_bricked.o
+	$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm
+
 .PHONY: clean cleanest
 clean:
-	rm -f *.o *.gch test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_quantized_face_detection_sparse
+	rm -f *.o *.gch test_fastgrnn_lr test_conv1d test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_quantized_face_detection_sparse test_rnn_bricked test_phoneme_det_cnn_rnn
 
 cleanest: clean
 	rm *~
diff --git a/c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h b/c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h
new file mode 100644
index 000000000..e9b3f68da
--- /dev/null
+++ b/c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d705c8b29a9eaf7255e15fb96314cc5b541d13e8a44494921fa0d00fbe46beee
+size 39066
diff --git a/c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h b/c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h
new file mode 100644
index 000000000..c936bb204
--- /dev/null
+++ b/c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64e7cbd963bfe285df54cac1484b62560ddc8e9a0384392f033b92d3f3b3df1b
+size 70398
diff --git a/c_reference/tests/conv1d/conv1d_regular/conv_param.h b/c_reference/tests/conv1d/conv1d_regular/conv_param.h
new file mode 100644
index 000000000..6f2ca1edc
--- /dev/null
+++ b/c_reference/tests/conv1d/conv1d_regular/conv_param.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8de43d4e289ee507ee629a7a15adc4d914ab4f03f37eaf15bb1ecb8ccb97671c
+size 108492
diff --git a/c_reference/tests/conv1d/test_conv1d.c b/c_reference/tests/conv1d/test_conv1d.c
new file mode 100644
index 000000000..d42aea786
--- /dev/null
+++ b/c_reference/tests/conv1d/test_conv1d.c
@@ -0,0 +1,132 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <stdio.h>
+#include <stdlib.h>
+#include "conv1d.h"
+#include "utils.h"
+
+#include "./conv1d_regular/conv_param.h"
+#include "./conv1d_depthwise/conv_param_depth.h"
+#include "./conv1d_lr/conv_param_lr.h"
+
+// Error Check.
+void errorCheck(float* pred, float* label, unsigned out_time, unsigned out_features) {
+  float error = 0, denom = 0;
+  for (unsigned t = 0; t < out_time; t++) {
+    for (unsigned d = 0; d < out_features; d++) {
+      error += ((pred[t * out_features + d] - label[t * out_features + d])
+                * (pred[t * out_features + d] - label[t * out_features + d]));
+      denom += label[t * out_features + d] * label[t * out_features + d];
+    }
+  }
+  // RMSE - Relative Mean Squared Error.
+  // The ratio of the Squared Error to the Squared Summation of the Signal.
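+  // That is, rmse = sum((pred - label)^2) / sum(label^2), while avg_error below is the plain MSE.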
+ float avg_error = error / (out_time * out_features), rmse = error / denom; + printf("Agg Squared Error: %f ; MSE: %f ; RMSE: %f\n", error, avg_error, rmse); +} + +void conv1d_check() { + ConvLayers_Params conv_params = { + .W = CONV1D_CONV_WEIGHT, + .B = CONV1D_CONV_BIAS, + .depthwise = 0, + }; + + float* pred = (float*)malloc(CONV1D_OUT_TIME * CONV1D_OUT_FEATURES * sizeof(float)); + conv1d(pred, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES, CONV1D_INPUT, + CONV1D_IN_TIME, CONV1D_IN_FEATURES, CONV1D_PAD, CONV1D_FILT, + &conv_params, CONV1D_STRIDE, CONV1D_ACT); + + printf("Testing Regular Convolution\n"); + errorCheck(pred, CONV1D_OUTPUT, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES); + free(pred); +} + +void conv1d_parallel_check() { + ConvLayers_Parallel_Params conv_params = { + .W = CONV1D_CONV_WEIGHT, + .B = CONV1D_CONV_BIAS, + .block_size = 100, + }; + + float* pred = (float*)malloc(CONV1D_OUT_TIME * CONV1D_OUT_FEATURES * sizeof(float)); + conv1d_parallel(pred, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES, CONV1D_INPUT, + CONV1D_IN_TIME, CONV1D_IN_FEATURES, CONV1D_PAD, CONV1D_FILT, + &conv_params, CONV1D_STRIDE, CONV1D_ACT); + + printf("Testing Parallel Convolution\n"); + errorCheck(pred, CONV1D_OUTPUT, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES); + free(pred); +} + +void conv1d_depth_check() { + ConvLayers_Params conv_params = { + .W = CONV1D_DEPTH_CONV_WEIGHT, + .B = CONV1D_DEPTH_CONV_BIAS, + .depthwise = 1, + }; + + float* pred = (float*)malloc(CONV1D_DEPTH_OUT_TIME * CONV1D_DEPTH_OUT_FEATURES + * sizeof(float)); + conv1d(pred, CONV1D_DEPTH_OUT_TIME, 0, CONV1D_DEPTH_INPUT, + CONV1D_DEPTH_IN_TIME, CONV1D_DEPTH_IN_FEATURES, CONV1D_DEPTH_PAD, CONV1D_DEPTH_FILT, + &conv_params, CONV1D_DEPTH_STRIDE, CONV1D_DEPTH_ACT); + + printf("Testing Depthwise Convolution\n"); + errorCheck(pred, CONV1D_DEPTH_OUTPUT, + CONV1D_DEPTH_OUT_TIME, CONV1D_DEPTH_OUT_FEATURES); + free(pred); +} + +void conv1d_lr_check() { + ConvLayers_LR_Params conv_params = { + .W1 = CONV1D_LR_CONV_W1, + .W2 = CONV1D_LR_CONV_W2, + .B = CONV1D_LR_CONV_BIAS, + .rank = CONV1D_LR_LOW_RANK + }; + + float* pred = (float*)malloc(CONV1D_LR_OUT_TIME + * CONV1D_LR_OUT_FEATURES * sizeof(float)); + conv1d_lr(pred, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES, CONV1D_LR_INPUT, + CONV1D_LR_IN_TIME, CONV1D_LR_IN_FEATURES, CONV1D_LR_PAD, CONV1D_LR_FILT, + &conv_params, CONV1D_LR_STRIDE, CONV1D_LR_ACT); + + printf("Testing Low-Rank Convolution\n"); + errorCheck(pred, CONV1D_LR_OUTPUT, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES); + free(pred); +} + +void conv1d_lr_parallel_check() { + ConvLayers_LR_Parallel_Params conv_params = { + .W1 = CONV1D_LR_CONV_W1, + .W2 = CONV1D_LR_CONV_W2, + .B = CONV1D_LR_CONV_BIAS, + .rank = CONV1D_LR_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + float* pred = (float*)malloc(CONV1D_LR_OUT_TIME + * CONV1D_LR_OUT_FEATURES * sizeof(float)); + conv1d_lr_parallel(pred, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES, CONV1D_LR_INPUT, + CONV1D_LR_IN_TIME, CONV1D_LR_IN_FEATURES, CONV1D_LR_PAD, CONV1D_LR_FILT, + &conv_params, CONV1D_LR_STRIDE, CONV1D_LR_ACT); + + printf("Testing Low-Rank Parallel Convolution\n"); + errorCheck(pred, CONV1D_LR_OUTPUT, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES); + free(pred); +} + +int main() { + #ifdef LOOP_UNROLL + printf("Loop Unrolling Active\n"); + #endif + conv1d_check(); + conv1d_parallel_check(); + conv1d_lr_check(); + conv1d_depth_check(); + conv1d_lr_parallel_check(); + return 0; +} diff --git a/c_reference/tests/kws/keyword_spotting_io_1.h b/c_reference/tests/kws/keyword_spotting_io_1.h 
new file mode 100644
index 000000000..18517f20e
--- /dev/null
+++ b/c_reference/tests/kws/keyword_spotting_io_1.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1faf461aaadd548a9c9c6a8b3b62552e1cc41d268bbb6a5f2b3abf5d9e0bc575
+size 326949
diff --git a/c_reference/tests/kws/keyword_spotting_io_2.h b/c_reference/tests/kws/keyword_spotting_io_2.h
new file mode 100644
index 000000000..293d4b379
--- /dev/null
+++ b/c_reference/tests/kws/keyword_spotting_io_2.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:804e9c5d4053f61b993486957a8e55c2e9577f52f6297a180b0430da73290709
+size 306578
diff --git a/c_reference/tests/kws/keyword_spotting_io_3.h b/c_reference/tests/kws/keyword_spotting_io_3.h
new file mode 100644
index 000000000..f6efbb000
--- /dev/null
+++ b/c_reference/tests/kws/keyword_spotting_io_3.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccf1ba4ea4c53597f0c60f9a60e0031ceb45446d0c314944a1a9b6f286f1a921
+size 308150
diff --git a/c_reference/tests/kws/postcnn_params.h b/c_reference/tests/kws/postcnn_params.h
new file mode 100644
index 000000000..9da921d22
--- /dev/null
+++ b/c_reference/tests/kws/postcnn_params.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0adca3c3658d5860193ae73f62468742dbc2c1f8b6508abdb2efa9311111165d
+size 1374545
diff --git a/c_reference/tests/kws/precnn_params.h b/c_reference/tests/kws/precnn_params.h
new file mode 100644
index 000000000..fb1539736
--- /dev/null
+++ b/c_reference/tests/kws/precnn_params.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79383c92269713907e1247ec39232f8b98e95850637ec5a1cce2a310ebf3e469
+size 520803
diff --git a/c_reference/tests/kws/rnn_params.h b/c_reference/tests/kws/rnn_params.h
new file mode 100644
index 000000000..72a581918
--- /dev/null
+++ b/c_reference/tests/kws/rnn_params.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77b63ebcde61cdec096e0be6d83d60c679c63a7608047317672c10d3e478e5fe
+size 1302881
diff --git a/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c
new file mode 100644
index 000000000..c9cbc6658
--- /dev/null
+++ b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c
@@ -0,0 +1,278 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+#include "conv1d.h"
+#include "dscnn.h"
+#include "utils.h"
+#include "rnn_bricked.h"
+
+#include "keyword_spotting_io_2.h"
+#include "precnn_params.h"
+#include "rnn_params.h"
+#include "postcnn_params.h"
+
+// Check the number of output time-steps against the number of label time-steps.
+int checkTime(unsigned out_time) {
+  if (out_time != KWS_OUT_TIME) {
+    printf("Error, estimated and actual output time-steps mismatch");
+    return 1;
+  }
+  return 0;
+}
+
+// Error Check.
+void checkError(float* pred, float* label) {
+  float error = 0, denom = 0;
+  for (unsigned t = 0; t < KWS_OUT_TIME; t++) {
+    for (unsigned d = 0; d < POST_CNN_OUT_FEATURES; d++) {
+      error += ((pred[t * POST_CNN_OUT_FEATURES + d]
+                 - label[t * POST_CNN_OUT_FEATURES + d])
+                * (pred[t * POST_CNN_OUT_FEATURES + d]
+                   - label[t * POST_CNN_OUT_FEATURES + d]));
+      denom += label[t * POST_CNN_OUT_FEATURES + d]
+               * label[t * POST_CNN_OUT_FEATURES + d];
+    }
+  }
+  printf("Full Network\n");
+  printf("Agg Squared Error : %f\n", error);
+  printf("MSE : %f\n", error / (KWS_OUT_TIME * POST_CNN_OUT_FEATURES));
+  // RMSE - Relative Mean Squared Error.
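+  // (Relative, not root: the value printed below is sum((pred - label)^2) / sum(label^2).)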
+ // The ratio of the Squared Error to the Squared Summation of the Signal. + printf("RMSE : %f\n", error / denom); +} + +/* CNN-RNN based Phoneme Detection Model. + + The phoneme detection model used consists of 6 blocks. + 1st block is a CNN, where kernel size is 5 and regular tanh activation. + 2nd block is an RNN, which has a specified forward and a backward context running at a stride/hop of 3. + Hence it reduces the sequence length by a factor of 3. + Rest of the blocks(3rd, 4th, 5th and 6th) are a combination of CNNs. + Each of the final 4 blocks consist of a depth-CNN (kernel size of 5) and a point-CNN (kernel size of 1). + + Input to the architecture is of the form (seq_len, feature_dim) where feature dim refers to n_mels (number of mel features/number of features from the featurizer). + Output is of the form (seq_len/3, 41) where 41 is the number of phonemes over which the classification is performed. + Phonemes are predicted for every 3rd time frame, operating under the assumption that they don't vary faster than that. + + NOTE: Before deployment for real-time streaming applications, we would need to make minor modification. + These changes are subject to the input specs i.e fixing input buffer time steps, number of features from the deployed featurizer, method of reading the input into a buffer. +*/ +void phoneme_prediction(float* mem_buf) { + ConvLayers_LR_Parallel_Params conv_params = { + .W1 = CNN1_W1, + .W2 = CNN1_W2, + .B = CNN1_BIAS, + .rank = PRE_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + ConvLayers_Params depth_param_2 = { + .W = CNN2_DEPTH_W, + .B = CNN2_DEPTH_BIAS, + .depthwise = 1, + }; + + ConvLayers_LR_Parallel_Params point_param_2 = { + .W1 = CNN2_POINT_W1, + .W2 = CNN2_POINT_W2, + .B = CNN2_POINT_BIAS, + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + ConvLayers_Params depth_param_3 = { + .W = CNN3_DEPTH_W, + .B = CNN3_DEPTH_BIAS, + .depthwise = 1, + }; + + ConvLayers_LR_Parallel_Params point_param_3 = { + .W1 = CNN3_POINT_W1, + .W2 = CNN3_POINT_W2, + .B = CNN3_POINT_BIAS, + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + ConvLayers_Params depth_param_4 = { + .W = CNN4_DEPTH_W, + .B = CNN4_DEPTH_BIAS, + .depthwise = 1, + }; + + ConvLayers_LR_Parallel_Params point_param_4 = { + .W1 = CNN4_POINT_W1, + .W2 = CNN4_POINT_W2, + .B = CNN4_POINT_BIAS, + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + ConvLayers_Params depth_param_5 = { + .W = CNN5_DEPTH_W, + .B = CNN5_DEPTH_BIAS, + .depthwise = 1, + }; + + ConvLayers_LR_Parallel_Params point_param_5 = { + .W1 = CNN5_POINT_W1, + .W2 = CNN5_POINT_W2, + .B = CNN5_POINT_BIAS, + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + BrickedFastGRNN_LR_Params bwd_RNN_params = { + .W1 = B_W1, + .W2 = B_W2, + .wRank = RNN_LOW_RANK, + .U1 = B_U1, + .U2 = B_U2, + .uRank = RNN_LOW_RANK, + .Bg = B_BIAS_GATE, + .Bh = B_BIAS_UPDATE, + .sigmoid_zeta = sigmoid(B_ZETA), + .sigmoid_nu = sigmoid(B_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, + }; + + BrickedFastGRNN_LR_Params fwd_RNN_params = { + .W1 = F_W1, + .W2 = F_W2, + .wRank = RNN_LOW_RANK, + .U1 = F_U1, + .U2 = F_U2, + .uRank = RNN_LOW_RANK, + .Bg = F_BIAS_GATE, + .Bh = F_BIAS_UPDATE, + .sigmoid_zeta = sigmoid(F_ZETA), + .sigmoid_nu = sigmoid(F_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 
100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, + }; + + unsigned in_time, out_time; + + /* Pre-CNN. */ + in_time = KWS_IN_TIME; + out_time = in_time - PRE_CNN_FILT + (PRE_CNN_FILT_PAD << 1) + 1; + float* cnn1_out = (float*)malloc(out_time * PRE_CNN_OUT_FEATURES * sizeof(float)); + // Since batchnorm1d is the first layer and in-place will alter the input. + // Use the in-place computation only if the input can be discarded/altered. Else avoid in-place computation for this layer. + phon_pred_lr_cnn(cnn1_out, mem_buf, + conv1d_lr_parallel, in_time, PRE_CNN_IN_FEATURES, + 0, 0, PRE_CNN_BNORM_AFFINE, CNN1_SCALE, CNN1_OFFSET, PRE_CNN_BNORM_INPLACE, + PRE_CNN_OUT_FEATURES, PRE_CNN_FILT_PAD, PRE_CNN_FILT, + &conv_params, PRE_CNN_STRIDE, PRE_CNN_FILT_ACT); // Regular tanh activation. + + batchnorm1d(0, cnn1_out, in_time, RNN_IN_FEATURES, + 0, 0, RNN_BNORM_AFFINE, RNN_SCALE, RNN_OFFSET, 1, 0.00001); + + /* Bricked Bi-FastGRNN Block. */ + out_time = in_time/RNN_HOP + 1; + float* rnn_out = (float*)malloc(out_time * RNN_OUT_FEATURES * sizeof(float)); + forward_bricked_fastgrnn_lr(rnn_out, RNN_OUT_FEATURES >> 1, cnn1_out, + in_time, RNN_IN_FEATURES, RNN_FWD_WINDOW, RNN_HOP, + &fwd_RNN_params, RNN_BI_DIR, RNN_SAMPLE_FIRST_BRICK); + + backward_bricked_fastgrnn_lr(rnn_out + (RNN_OUT_FEATURES >> 1), + RNN_OUT_FEATURES >> 1, cnn1_out, + in_time, RNN_IN_FEATURES, RNN_BWD_WINDOW, RNN_HOP, + &bwd_RNN_params, RNN_BI_DIR, RNN_SAMPLE_LAST_BRICK); + free(cnn1_out); + + /* Post-CNN. */ + // Since all inputs to the subsequent layers are temporary, in-place batchnorm1d can be used without any input(initial buffer)/output(final layer) data alteration/corruption. + // CNN2. + in_time = out_time; + out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; + out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; + float* cnn2_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); + phon_pred_depth_point_lr_cnn(cnn2_out, rnn_out, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, + 0, 0, POST_CNN_BNORM_AFFINE, CNN2_SCALE, CNN2_OFFSET, POST_CNN_BNORM_INPLACE, + POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, + &depth_param_2, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, + POST_CNN_INTER_FEATURES, POST_CNN_POINT_PAD, POST_CNN_POINT_FILT, + &point_param_2, POST_CNN_POINT_STRIDE, POST_CNN_POINT_ACT, + POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); + free(rnn_out); + + // CNN3. + in_time = out_time; + out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; + out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; + float* cnn3_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); + phon_pred_depth_point_lr_cnn(cnn3_out, cnn2_out, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, + 0, 0, POST_CNN_BNORM_AFFINE, CNN3_SCALE, CNN3_OFFSET, POST_CNN_BNORM_INPLACE, + POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, + &depth_param_3, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, + POST_CNN_INTER_FEATURES, POST_CNN_POINT_PAD, POST_CNN_POINT_FILT, + &point_param_3, POST_CNN_POINT_STRIDE, POST_CNN_POINT_ACT, + POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); + free(cnn2_out); + + // CNN4. 
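+  // Each depth + point + pool block shortens the sequence twice, once per sliding window:
+  // out_time = in_time - POST_CNN_DEPTH_FILT + 2 * POST_CNN_DEPTH_PAD + 1, and then the same
+  // formula again with the pool kernel. With hypothetical sizes filt = 5, pool = 2, pad = 0,
+  // this maps 50 -> 46 -> 45 time steps.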
+ in_time = out_time; + out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; + out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; + float* cnn4_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); + phon_pred_depth_point_lr_cnn(cnn4_out, cnn3_out, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, + 0, 0, POST_CNN_BNORM_AFFINE, CNN4_SCALE, CNN4_OFFSET, POST_CNN_BNORM_INPLACE, + POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, + &depth_param_4, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, + POST_CNN_INTER_FEATURES, POST_CNN_POINT_PAD, POST_CNN_POINT_FILT, + &point_param_4, POST_CNN_POINT_STRIDE, POST_CNN_POINT_ACT, + POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); + free(cnn3_out); + + // CNN5. + in_time = out_time; + out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; + out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; + float* pred = (float*)malloc(out_time * POST_CNN_OUT_FEATURES * sizeof(float)); + phon_pred_depth_point_lr_cnn(pred, cnn4_out, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, + 0, 0, POST_CNN_BNORM_AFFINE, CNN5_SCALE, CNN5_OFFSET, POST_CNN_BNORM_INPLACE, + POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, + &depth_param_5, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, + POST_CNN_OUT_FEATURES, POST_CNN_POINT_PAD, POST_CNN_POINT_FILT, + &point_param_5, POST_CNN_POINT_STRIDE, POST_CNN_POINT_ACT, + POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); + free(cnn4_out); + + /* Output Time and Prediction Check. Created for Debugging. */ + if (checkTime(out_time)) + return; + else + checkError(pred, OUTPUT); + free(pred); +} + +int main() { + #ifdef LOOP_UNROLL + printf("Loop Unrolling Active\n"); + #endif + clock_t begin = clock(); + phoneme_prediction(INPUT); + clock_t end = clock(); + double time_spent = (float)(end - begin) / CLOCKS_PER_SEC; + printf("Time elapsed is %f seconds\n", time_spent); + return 0; +} diff --git a/c_reference/tests/rnn_bricked/rnn_bricked_io.h b/c_reference/tests/rnn_bricked/rnn_bricked_io.h new file mode 100644 index 000000000..a6d90e301 --- /dev/null +++ b/c_reference/tests/rnn_bricked/rnn_bricked_io.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fc3181b35c0cfa858a5ff6415f4ad793915b2fce5c792e741a8f8e04b095349 +size 1908996 diff --git a/c_reference/tests/rnn_bricked/rnn_params.h b/c_reference/tests/rnn_bricked/rnn_params.h new file mode 100644 index 000000000..17060301a --- /dev/null +++ b/c_reference/tests/rnn_bricked/rnn_params.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f062722fe7b91fbb5f78af631d4c594d2b2af61b8dfdd266b799d39c133afa2 +size 1293672 diff --git a/c_reference/tests/rnn_bricked/test_rnn_bricked.c b/c_reference/tests/rnn_bricked/test_rnn_bricked.c new file mode 100644 index 000000000..b2f03696d --- /dev/null +++ b/c_reference/tests/rnn_bricked/test_rnn_bricked.c @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+
+#include <stdio.h>
+#include <stdlib.h>
+#include "rnn_bricked.h"
+#include "utils.h"
+
+#include "rnn_params.h"
+#include "rnn_bricked_io.h"
+
+int main() {
+
+  BrickedFastGRNN_LR_Params bwd_RNN_params = {
+    .W1 = B_W1,
+    .W2 = B_W2,
+    .wRank = RNN_LOW_RANK,
+    .U1 = B_U1,
+    .U2 = B_U2,
+    .uRank = RNN_LOW_RANK,
+    .Bg = B_BIAS_GATE,
+    .Bh = B_BIAS_UPDATE,
+    .sigmoid_zeta = sigmoid(B_ZETA),
+    .sigmoid_nu = sigmoid(B_NU),
+    .block_size_u_from_lr = 100,
+    .block_size_u_to_lr = 100,
+    .block_size_w_from_lr = 100,
+    .block_size_w_to_lr = 100,
+  };
+
+  BrickedFastGRNN_LR_Params fwd_RNN_params = {
+    .W1 = F_W1,
+    .W2 = F_W2,
+    .wRank = RNN_LOW_RANK,
+    .U1 = F_U1,
+    .U2 = F_U2,
+    .uRank = RNN_LOW_RANK,
+    .Bg = F_BIAS_GATE,
+    .Bh = F_BIAS_UPDATE,
+    .sigmoid_zeta = sigmoid(F_ZETA),
+    .sigmoid_nu = sigmoid(F_NU),
+    .block_size_u_from_lr = 100,
+    .block_size_u_to_lr = 100,
+    .block_size_w_from_lr = 100,
+    .block_size_w_to_lr = 100,
+  };
+
+  float* pred = (float*)malloc(RNN_OUT_TIME * RNN_OUT_FEATURES * sizeof(float));
+
+  forward_bricked_fastgrnn_lr(pred, RNN_OUT_FEATURES >> 1, INPUT,
+    RNN_IN_TIME, RNN_IN_FEATURES, FWD_WINDOW, HOP,
+    &fwd_RNN_params, 1, 1);
+
+  backward_bricked_fastgrnn_lr(pred + (RNN_OUT_FEATURES >> 1), RNN_OUT_FEATURES >> 1, INPUT,
+    RNN_IN_TIME, RNN_IN_FEATURES, BWD_WINDOW, HOP,
+    &bwd_RNN_params, 1, 1);
+
+  float error = 0;
+  float denom = 0;
+  for (int t = 0; t < RNN_OUT_TIME; t++) {
+    for (int d = 0; d < RNN_OUT_FEATURES; d++) {
+      error += ((pred[t * RNN_OUT_FEATURES + d] - OUTPUT[t * RNN_OUT_FEATURES + d])
+                * (pred[t * RNN_OUT_FEATURES + d] - OUTPUT[t * RNN_OUT_FEATURES + d]));
+      denom += OUTPUT[t * RNN_OUT_FEATURES + d] * OUTPUT[t * RNN_OUT_FEATURES + d];
+    }
+  }
+  // RMSE - Relative Mean Squared Error.
+  // The ratio of the Squared Error to the Squared Summation of the Signal.
+  float avg_error = error / (RNN_OUT_TIME * RNN_OUT_FEATURES);
+  float rmse = error / denom;
+
+  #ifdef LOOP_UNROLL
+    printf("Loop Unrolling Active\n");
+  #endif
+  printf("Testing Bricked RNNs Bi-Directional\n");
+  printf("Agg Squared Error: %f ; MSE: %f ; RMSE: %f\n", error, avg_error, rmse);
+  free(pred);
+  return 0;
+}