diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index a0f58faf9aee00..e1b84d8de84855 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -1441,107 +1441,8 @@ void qadaptive_avg_pool3d_ndhwc_kernel(
       istrideW);
 }
 
-void qavg_pool2d_nhwc_kernel(
-    const Tensor& qx,
-    Tensor& qy,
-    int64_t b,
-    int64_t nInputPlane,
-    int64_t inputWidth,
-    int64_t inputHeight,
-    int64_t outputWidth,
-    int64_t outputHeight,
-    int kW,
-    int kH,
-    int dW,
-    int dH,
-    int padW,
-    int padH,
-    bool count_include_pad,
-    c10::optional<int64_t> divisor_override) {
-  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool2d_nhwc", [&]() {
-    scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
-    scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
-    int64_t batch_size = nInputPlane * inputWidth * inputHeight;
-    auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(
-        idata + b * batch_size);
-
-    // lift these operations outside the loop to reduce access overheads
-    float input_scale = qx.q_scale();
-    float output_scale = qy.q_scale();
-    int input_zero_point = qx.q_zero_point();
-    int output_zero_point = qy.q_zero_point();
-    int64_t divisor_override_factor =
-        divisor_override.has_value() ? divisor_override.value() : 0;
-
-    for (int64_t oh = 0; oh < outputHeight; oh++) {
-      for (int64_t ow = 0; ow < outputWidth; ow++) {
-        auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
-            odata + b * nInputPlane * outputWidth * outputHeight +
-            (oh * outputWidth + ow) * nInputPlane);
-        int64_t hstart = oh * dH - padH;
-        int64_t wstart = ow * dW - padW;
-        int64_t hend = std::min(hstart + kH, inputHeight + padH);
-        int64_t wend = std::min(wstart + kW, inputWidth + padW);
-        int64_t pool_size = (hend - hstart) * (wend - wstart);
-        hstart = std::max(hstart, (int64_t)0);
-        wstart = std::max(wstart, (int64_t)0);
-        hend = std::min(hend, inputHeight);
-        wend = std::min(wend, inputWidth);
-
-        int size = (hend - hstart) * (wend - wstart);
-        int divide_size = count_include_pad ? pool_size : size;
-        int divide_factor =
-            divisor_override_factor ? divisor_override_factor : divide_size;
-        float multiplier = input_scale / output_scale / divide_factor;
-        int input_zero_point_m_size = -input_zero_point * size;
-
-        int64_t c = 0;
-        // For int8 quantization, we implicitly use int32 as accumulation
-        // Or else, it will go to the slow path
-        // TODO: support 16bit, 32bit, and etc.
-        do_avg_pool_on_AVX2<scalar_t>(
-            i_p,
-            o_p,
-            c,
-            nInputPlane,
-            nInputPlane,
-            input_zero_point_m_size,
-            output_zero_point,
-            multiplier,
-            0,
-            1,
-            hstart,
-            hend,
-            wstart,
-            wend,
-            1,
-            1,
-            inputWidth,
-            1);
-        // 1) The following loop handles the remaining channels
-        // 2) It also handles the Non-AVX2 path
-        for (; c < nInputPlane; ++c) {
-          int32_t acc_int32 = input_zero_point_m_size;
-          int64_t tcntr = 0;
-          for (int64_t ih = hstart; ih < hend; ih++) {
-            for (int64_t iw = wstart; iw < wend; iw++) {
-              tcntr = ih * inputWidth + iw;
-              auto val = *(i_p + tcntr * nInputPlane + c);
-              acc_int32 += val;
-            }
-          }
-          double acc_fp = acc_int32 * 1.0;
-          // clamp
-          o_p[c] = at::native::quantize_val<scalar_t>(
-                       1.0f / multiplier, output_zero_point, acc_fp)
-                       .val_;
-        } // c
-      } // ow
-    } // oh
-  });
-}
-
-void qavg_pool3d_nhwc_kernel(
+void _qavg_pool_nhwc_kernel(
+    const std::string& fn_name,
     const Tensor& qx,
     Tensor& qy,
     int64_t b,
@@ -1563,12 +1464,19 @@ void qavg_pool3d_nhwc_kernel(
     int padD,
     bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
-  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool3d_nhwc", [&]() {
+  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), fn_name, [&]() {
     scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
     scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
-    int batch_size = nInputPlane * inputWidth * inputHeight * inputDepth;
-    auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(
-        idata + b * batch_size);
+    int strideC = 1;
+    int strideW = strideC * nInputPlane;
+    int istrideH = strideW * inputWidth;
+    int istrideD = istrideH * inputHeight;
+    int istrideB = istrideD * inputDepth;
+    int ostrideH = strideW * outputWidth;
+    int ostrideD = ostrideH * outputHeight;
+    int ostrideB = ostrideD * outputDepth;
+    auto* i_p =
+        reinterpret_cast<typename scalar_t::underlying*>(idata + b * istrideB);
 
     // lift these operations outside the loop to reduce access overheads
     float input_scale = qx.q_scale();
@@ -1582,10 +1490,8 @@ void qavg_pool3d_nhwc_kernel(
     for (int oh = 0; oh < outputHeight; oh++) {
       for (int ow = 0; ow < outputWidth; ow++) {
         auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
-            odata +
-            b * nInputPlane * outputWidth * outputHeight * outputDepth +
-            (od * outputHeight * outputWidth + oh * outputWidth + ow) *
-                nInputPlane);
+            odata + b * ostrideB + od * ostrideD + oh * ostrideH +
+            ow * strideW);
         int dstart = od * dD - padD;
         int hstart = oh * dH - padH;
         int wstart = ow * dW - padW;
@@ -1636,12 +1542,12 @@ void qavg_pool3d_nhwc_kernel(
         // 2) It also handles the Non-AVX2 path
         for (int c = c_start; c < nInputPlane; ++c) {
           int32_t acc_int32 = input_zero_point_m_size;
-          int64_t tcntr = 0;
           for (int64_t id = dstart; id < dend; id++) {
             for (int64_t ih = hstart; ih < hend; ih++) {
               for (int64_t iw = wstart; iw < wend; iw++) {
-                tcntr = id * inputHeight * inputWidth + ih * inputWidth + iw;
-                auto val = *(i_p + tcntr * nInputPlane + c);
+                auto val =
+                    *(i_p + id * istrideD + ih * istrideH + iw * strideW +
+                      c * strideC);
                 acc_int32 += val;
               }
             }
@@ -1658,6 +1564,95 @@ void qavg_pool3d_nhwc_kernel(
   });
 }
 
+void qavg_pool2d_nhwc_kernel(
+    const Tensor& qx,
+    Tensor& qy,
+    int64_t b,
+    int64_t nInputPlane,
+    int64_t inputWidth,
+    int64_t inputHeight,
+    int64_t outputWidth,
+    int64_t outputHeight,
+    int kW,
+    int kH,
+    int dW,
+    int dH,
+    int padW,
+    int padH,
+    bool count_include_pad,
+    c10::optional<int64_t> divisor_override) {
+  _qavg_pool_nhwc_kernel(
+      "avg_pool2d_nhwc",
+      qx,
+      qy,
+      b,
+      nInputPlane,
+      inputWidth,
+      inputHeight,
+      1,
+      outputWidth,
+      outputHeight,
+      1,
+      kW,
+      kH,
+      1,
+      dW,
+      dH,
+      1,
+      padW,
+      padH,
+      0,
+      count_include_pad,
+      divisor_override);
+}
+
+void qavg_pool3d_nhwc_kernel(
+    const Tensor& qx,
+    Tensor& qy,
+    int64_t b,
+    int64_t nInputPlane,
+    int64_t inputWidth,
+    int64_t inputHeight,
+    int64_t inputDepth,
+    int64_t outputWidth,
+    int64_t outputHeight,
+    int64_t outputDepth,
+    int kW,
+    int kH,
+    int kD,
+    int dW,
+    int dH,
+    int dD,
+    int padW,
+    int padH,
+    int padD,
+    bool count_include_pad,
+    c10::optional<int64_t> divisor_override) {
+  _qavg_pool_nhwc_kernel(
+      "avg_pool3d_nhwc",
+      qx,
+      qy,
+      b,
+      nInputPlane,
+      inputWidth,
+      inputHeight,
+      inputDepth,
+      outputWidth,
+      outputHeight,
+      outputDepth,
+      kW,
+      kH,
+      kD,
+      dW,
+      dH,
+      dD,
+      padW,
+      padH,
+      padD,
+      count_include_pad,
+      divisor_override);
+}
+
 template <typename T>
 int64_t do_quantized_bilinear_on_AVX2(
     const typename T::underlying*& pos1,
diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py
index 0da178c01cf008..860589ec69c6ae 100644
--- a/test/quantization/test_quantized_op.py
+++ b/test/quantization/test_quantized_op.py
@@ -1055,7 +1055,7 @@ def test_avg_pool2d(self, X, kernel, stride, padding, ceil_mode, count_include_p
                                            dtype=torch_type)
         self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double),
                          atol=1.0, rtol=0,
-                         msg=error_message.format(name, qX_hat.int_repr(), qX_ref.int_repr()))
+                         msg=error_message.format(name, qX_ref.int_repr(), qX_hat.int_repr()))
         self.assertEqual(scale, qX_hat.q_scale(),
                          msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
         self.assertEqual(zero_point, qX_hat.q_zero_point(),
@@ -1117,7 +1117,7 @@ def test_avg_pool2d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_incl
                                            dtype=torch_type)
         self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double),
                          atol=1.0, rtol=0,
-                         msg=error_message.format(name, X_hat.int_repr(), qX_ref.int_repr()))
+                         msg=error_message.format(name, qX_ref.int_repr(), X_hat.int_repr()))
         self.assertEqual(scale, X_hat.q_scale(),
                          msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
         self.assertEqual(zero_point, X_hat.q_zero_point(),
@@ -1169,7 +1169,7 @@ def test_avg_pool3d(self, X, kernel, stride, padding, ceil_mode, count_include_p
         qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
                                            dtype=torch_type)
         self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
-                         msg=error_message.format(name, qX_hat.int_repr(), qX_ref.int_repr()))
+                         msg=error_message.format(name, qX_ref.int_repr(), qX_hat.int_repr()))
         self.assertEqual(scale, qX_hat.q_scale(),
                          msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
         self.assertEqual(zero_point, qX_hat.q_zero_point(),
@@ -1233,7 +1233,7 @@ def test_avg_pool3d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_incl
                                            dtype=torch_type)
         self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double),
                          atol=1.0, rtol=0,
-                         msg=error_message.format(name, X_hat.int_repr(), qX_ref.int_repr()))
+                         msg=error_message.format(name, qX_ref.int_repr(), X_hat.int_repr()))
         self.assertEqual(scale, X_hat.q_scale(),
                          msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
         self.assertEqual(zero_point, X_hat.q_zero_point(),
@@ -1290,7 +1290,7 @@ def test_adaptive_avg_pool2d_nhwc(self, X, output_size_h, output_size_w):
         self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
         self.assertEqualIgnoreType(X_ref, X_hat.int_repr(), atol=1.0, rtol=0,
-                                   msg="{} results are off".format(name))
+                                   msg=error_message.format(name, X_ref, X_hat.int_repr()))
         self.assertEqual(scale, X_hat.q_scale(),
                          msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
         self.assertEqual(zero_point, X_hat.q_zero_point(),
@@ -1416,7 +1416,7 @@ def test_adaptive_avg_pool3d_ndhwc(self, X, output_size_d, output_size_h,
         self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
         self.assertEqualIgnoreType(X_ref, X_hat.int_repr(), atol=1.0, rtol=0,
-                                   msg="{} results are off".format(name))
+                                   msg=error_message.format(name, X_ref, X_hat.int_repr()))
         self.assertEqual(scale, X_hat.q_scale(),
                          msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
         self.assertEqual(zero_point, X_hat.q_zero_point(),