[pytorch][PR] Optimize qavg_pool3d_nhwc (pytorch#35740)
Summary:
Pull Request resolved: pytorch#35740

For one of the quantized CV models, the avg_pool3d operation is more than 6x slower than the C2 implementation. The slowdown comes from the following:
- function calls inside the inner loop (such as ```q_scale()``` and ```q_zero_point()```)
- additional data copies in ```Vec256::store``` and ```at::quantize_vec```

This diff resolves the above issues with the following measures (a sketch follows the list):
- lift the function calls out of the loops
- add an 8-lane path in ```QuantizeAvx2``` to replace ```at::quantize_vec```
- in addition, interchange the loops so the channel loop is innermost, for better memory locality
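
A minimal sketch of the hoisting and loop-interchange measures; the ```QTensor``` stand-in and the loop bounds are hypothetical, not the actual ATen kernel code:

```
#include <cstdint>

// Hypothetical stand-in for a quantized tensor; the real kernel reads the
// same parameters through the tensor's q_scale() / q_zero_point() accessors.
struct QTensor {
  float q_scale() const { return scale; }
  int64_t q_zero_point() const { return zero_point; }
  float scale = 0.1f;
  int64_t zero_point = 0;
};

void pool_sketch(const QTensor& qx, int64_t N, int64_t OD, int64_t OH,
                 int64_t OW, int64_t C) {
  // Hoisted: before this diff, accessors like these were effectively
  // evaluated per output element inside the loop nest.
  const float inverse_scale = 1.0f / qx.q_scale();
  const int64_t zero_point = qx.q_zero_point();
  (void)inverse_scale;
  (void)zero_point;

  for (int64_t n = 0; n < N; ++n) {
    for (int64_t od = 0; od < OD; ++od) {
      for (int64_t oh = 0; oh < OH; ++oh) {
        for (int64_t ow = 0; ow < OW; ++ow) {
          // Channel loop innermost: in NHWC layout, c is the
          // fastest-moving dimension, so this loop walks memory
          // contiguously and is the natural target for the
          // vectorized quantization path.
          for (int64_t c = 0; c < C; ++c) {
            // ... accumulate the pooling window, then quantize ...
          }
        }
      }
    }
  }
}
```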

Test Plan:
buck test //caffe2/test:quantized

Performance Before (n x h x w x c = 4 x 56 x 56 x ch):
```
type            c=2             c=4             c=15            c=24            c=48            c=128           c=256
torch.qint8     903.08 us       1373.39 us      2297.97 us      636.72 us       864.98 us       1618.72 us      2908.47 us
torch.quint8    911.93 us       1429.39 us      2315.59 us      623.08 us       844.17 us       1522.28 us      2711.08 us
torch.qint32    897.77 us       1346.97 us      3846.41 us      6211.92 us      11977.74 us     34348.23 us     62927.48 us
```
Performance After:
```
type            c=2             c=4             c=15            c=24            c=48            c=128           c=256
torch.qint8     123.29 us       176.00 us       348.90 us       99.02 us        132.73 us       267.17 us       513.43 us
torch.quint8    123.76 us       171.90 us       338.17 us       97.92 us        131.06 us       260.09 us       521.16 us
torch.qint32    102.97 us       172.57 us       559.31 us       814.03 us       1606.11 us      4164.89 us      10041.52 us
```

Reviewed By: lly-zero-one

Differential Revision: D20711888

fbshipit-source-id: a71dd55639500f4a036eee96c357737cff9d33db
Di Wu authored and facebook-github-bot committed Apr 2, 2020
1 parent 0f99b28 commit c4f56e9
Showing 2 changed files with 213 additions and 73 deletions.
63 changes: 63 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -59,6 +59,16 @@ __m256i pack_saturate_and_clamp(
T min_val,
T max_val);

template <>
__m256i pack_saturate_and_clamp<int32_t>(
__m256i first,
__m256i second,
int32_t min_val,
int32_t max_val) {
// This specialization exists only for linkage; it is never called
AT_ERROR("pack_saturate_and_clamp<int32_t> is not supported");
}

template <>
__m256i pack_saturate_and_clamp<int8_t>(
__m256i first,
@@ -95,10 +105,47 @@ inline void __attribute__((always_inline)) QuantizeAvx2(
constexpr int VLEN = 8;
constexpr auto min_val = std::numeric_limits<typename T::underlying>::min();
constexpr auto max_val = std::numeric_limits<typename T::underlying>::max();
const __m256i min_v = _mm256_set1_epi32(min_val);
const __m256i max_v = _mm256_set1_epi32(max_val);
int i = 0;
__m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
static const __m256i shuffle_mask_v = _mm256_set_epi8(
    // Within each 128-bit lane, pick byte 0 of each 32-bit element
    // (offsets 0x0c, 0x08, 0x04, 0x00) into the lane's low 4 bytes;
    // 0xff zeroes the corresponding output byte.
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
__m256i permute_mask_v =
_mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
__m256i permute_mask_l8_v =
_mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
int len_aligned = len / (VLEN * 4) * (VLEN * 4);
for (; i < len_aligned; i += 4 * VLEN) {
// x
@@ -133,6 +180,22 @@ inline void __attribute__((always_inline)) QuantizeAvx2(
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v);
}

// Additional 8-lane AVX2 path for when fewer than 32 elements remain;
// based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM)
for (; i < len / VLEN * VLEN; i += VLEN) {
  // Unaligned load: src + i is not guaranteed to be 32-byte aligned here
  __m256 x_vals = _mm256_loadu_ps(src + i);
  __m256 x_transformed_v =
      _mm256_fmadd_ps(x_vals, inverse_scale_v, _mm256_set1_ps(zero_point));
  __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
  __m256i x_clipped_v =
      _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v));

  // Pack the low byte of each 32-bit element within each lane, gather the
  // two lanes' 4-byte results into the bottom 64 bits, and store the 8
  // quantized values
  x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v);
  x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v);
  _mm_storel_epi64(
      reinterpret_cast<__m128i*>(dst + i),
      _mm256_castsi256_si128(x_clipped_v));
}
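
// Illustrative note, not part of the original diff: with this extra path,
// e.g. len = 30 is handled as 0 iterations of the 32-lane loop above,
// 3 iterations of the 8-lane loop (24 elements), and 6 iterations of the
// scalar loop below; previously all 30 elements fell through to the
// scalar loop.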

for (; i < len; ++i) {
float transformed = zero_point + src[i] * inverse_scale;
float clipped =
// ... (remainder of this hunk truncated in the diff view)
[diff for the second changed file did not load]
