Skip to content

Commit

Permalink
[quant] Quantized Average Pool Refactoring (pytorch#42009)
Browse files Browse the repository at this point in the history
Summary:
**cc** z-a-f. Refactor `qavg_pool(2,3)d_nhwc_kernel` as mentioned in pytorch#40316.

# Benchmarks
## Python
Before | After
![before_after](https://user-images.githubusercontent.com/37529096/88401550-fea7ba80-ce1d-11ea-81c5-3ae912e81e8f.png)
## C++
![before_after_cpp](https://user-images.githubusercontent.com/37529096/88401845-5ba37080-ce1e-11ea-9bf2-3c95ac2b4b49.png)
## Notes
- For `qint8` and `quint8`, there appears to be a noticeable 2x speedup in the benchmarks, at least when `channels > 64`.
## Reproduce
### Python
```
import time
import numpy as np
import torch
from termcolor import colored
def time_avg_pool2d(X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override, iterations):
    """Time quantized ``avg_pool2d`` on an NCHW input and its NHWC copy.

    Args:
        X: Pair ``(array, (scale, zero_point, torch_type))``. ``array`` is a
            NumPy array handed to ``torch.from_numpy``; the triple holds the
            arguments for ``torch.quantize_per_tensor``.
        kernel: Forwarded as ``kernel_size``.
        stride, padding, ceil_mode, count_include_pad, divisor_override:
            Forwarded unchanged to
            ``torch.nn.quantized.functional.avg_pool2d``.
        iterations: Number of timed calls per memory layout.

    Returns:
        Tuple ``(nchw_ms, nhwc_ms)`` with the mean wall-clock time per call,
        in milliseconds, for the contiguous and channels-last inputs.
    """
    data, (scale, zero_point, torch_type) = X
    qX_nchw = torch.quantize_per_tensor(torch.from_numpy(data), scale=scale,
                                        zero_point=zero_point, dtype=torch_type)
    qX_nhwc = qX_nchw.contiguous(memory_format=torch.channels_last)
    # Layout sanity checks. The former ``stride() != sorted(stride())`` check
    # compared a tuple with a list, which is always unequal, so it verified
    # nothing and has been dropped in favour of the explicit checks below.
    assert qX_nchw.is_contiguous(memory_format=torch.contiguous_format)
    assert qX_nhwc.is_contiguous(memory_format=torch.channels_last)

    def mean_ms(q_input):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted mid-benchmark.
        start = time.perf_counter()
        for _ in range(iterations):
            torch.nn.quantized.functional.avg_pool2d(
                q_input, kernel_size=kernel, stride=stride, padding=padding,
                ceil_mode=ceil_mode, count_include_pad=count_include_pad,
                divisor_override=divisor_override)
        return (time.perf_counter() - start) * 1000 / iterations

    return mean_ms(qX_nchw), mean_ms(qX_nhwc)

def time_avg_pool3d(X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override,  iterations):
    """Time quantized ``avg_pool3d`` on an NCDHW input and its NDHWC copy.

    Args:
        X: Pair ``(array, (scale, zero_point, torch_type))``. ``array`` is a
            NumPy array handed to ``torch.from_numpy``; the triple holds the
            arguments for ``torch.quantize_per_tensor``.
        kernel: Forwarded as ``kernel_size``.
        stride, padding, ceil_mode, count_include_pad, divisor_override:
            Forwarded unchanged to
            ``torch.nn.quantized.functional.avg_pool3d``.
        iterations: Number of timed calls per memory layout.

    Returns:
        Tuple ``(ncdhw_ms, ndhwc_ms)`` with the mean wall-clock time per call,
        in milliseconds, for the contiguous and channels-last-3d inputs.
    """
    data, (scale, zero_point, torch_type) = X
    qX_ncdhw = torch.quantize_per_tensor(torch.from_numpy(data), scale=scale,
                                         zero_point=zero_point, dtype=torch_type)
    qX_ndhwc = qX_ncdhw.contiguous(memory_format=torch.channels_last_3d)
    # Layout sanity checks. The former ``stride() != sorted(stride())`` check
    # compared a tuple with a list, which is always unequal, so it verified
    # nothing and has been dropped in favour of the explicit checks below.
    assert qX_ncdhw.is_contiguous(memory_format=torch.contiguous_format)
    assert qX_ndhwc.is_contiguous(memory_format=torch.channels_last_3d)

    def mean_ms(q_input):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted mid-benchmark.
        start = time.perf_counter()
        for _ in range(iterations):
            torch.nn.quantized.functional.avg_pool3d(
                q_input, kernel_size=kernel, stride=stride, padding=padding,
                ceil_mode=ceil_mode, count_include_pad=count_include_pad,
                divisor_override=divisor_override)
        return (time.perf_counter() - start) * 1000 / iterations

    return mean_ms(qX_ncdhw), mean_ms(qX_ndhwc)

iterations = 10000
print("iterations = {}".format(iterations))
print("Benchmark", "Time(ms)", sep="\t\t\t\t\t")

# One row per operator: (label template, timing function, trailing input
# dims after the channel axis, layout names in the order the timer returns).
benchmark_table = (
    ("avg_pool2d({}, {}, {})", time_avg_pool2d, (56, 56), ("nchw", "nhwc")),
    ("avg_pool3d({}, {}, {})", time_avg_pool3d, (56, 56, 4), ("ncdhw", "ndhwc")),
)
for label_template, timer, trailing_dims, layouts in benchmark_table:
    for torch_type in (torch.qint8, torch.quint8, torch.qint32):
        for channel in (4, 8, 64, 256):
            data = np.random.rand(1, channel, *trailing_dims).astype(np.float32)
            ts = timer((data, (0.5, 1, torch_type)), 4, None, 0, True, True, None, iterations)
            for layout, elapsed_ms in zip(layouts, ts):
                label = label_template.format(str(torch_type), channel, layout)
                print(colored(label, 'green'), colored(elapsed_ms, 'yellow'), sep="\t")
```
### C++
1. `git clone https://github.com/google/benchmark.git`
2. `git clone https://github.com/google/googletest.git benchmark/googletest`

```
# CMakeLists.txt
# Builds the time_average_pool benchmark executable against libtorch and
# Google Benchmark.
cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(time_avg_pool VERSION 0.1.0)

# Locate the Torch package; its location is passed in through
# CMAKE_PREFIX_PATH on the cmake command line.
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
# Build Google Benchmark from the checkout cloned into ./benchmark.
add_subdirectory(benchmark)

add_executable(time_average_pool time_average_pool.cpp)
target_link_libraries(time_average_pool ${TORCH_LIBRARIES})
set_property(TARGET time_average_pool PROPERTY CXX_STANDARD 14)
target_link_libraries(time_average_pool benchmark::benchmark)
```

```
// time_average_pool.cpp
#include <benchmark/benchmark.h>
#include <torch/torch.h>

// Device shared by every benchmark below; all input tensors are created on the CPU.
torch::Device device(torch::kCPU);

// Benchmark: quint8 avg_pool2d (4x4 window, ceil_mode, count_include_pad) on
// a {1, state.range(0), 56, 56} input in its default layout, with the thread
// count set to 1 before the input is created.
static void BM_TORCH_QAVG_POOL2D_NCHW_SINGLE_THREADED(benchmark::State& state) {
  torch::init_num_threads();
  torch::set_num_threads(1);
  const auto float_input = torch::rand({1, state.range(0), 56, 56}, device);
  const auto quantized_input = torch::quantize_per_tensor(
      float_input, /*scale=*/0.5, /*zero_point=*/1, torch::kQUInt8);
  torch::Tensor pooled;
  for (auto _ : state) {
    pooled = torch::nn::functional::avg_pool2d(
        quantized_input,
        torch::nn::AvgPool2dOptions({4, 4})
            .ceil_mode(true)
            .count_include_pad(true));
  }
}

// Benchmark: quint8 avg_pool2d (4x4 window, ceil_mode, count_include_pad) on
// a {1, state.range(0), 56, 56} input converted to the ChannelsLast memory
// format, with the thread count set to 1 before the input is created.
static void BM_TORCH_QAVG_POOL2D_NHWC_SINGLE_THREADED(benchmark::State& state) {
  torch::init_num_threads();
  torch::set_num_threads(1);
  const auto float_input = torch::rand({1, state.range(0), 56, 56}, device);
  const auto quantized_nchw = torch::quantize_per_tensor(
      float_input, /*scale=*/0.5, /*zero_point=*/1, torch::kQUInt8);
  const auto quantized_nhwc =
      quantized_nchw.contiguous(torch::MemoryFormat::ChannelsLast);
  torch::Tensor pooled;
  for (auto _ : state) {
    pooled = torch::nn::functional::avg_pool2d(
        quantized_nhwc,
        torch::nn::AvgPool2dOptions({4, 4})
            .ceil_mode(true)
            .count_include_pad(true));
  }
}

// Benchmark: quint8 avg_pool2d (4x4 window, ceil_mode, count_include_pad) on
// a {1, state.range(0), 56, 56} input in its default layout; the thread
// count is left untouched.
static void BM_TORCH_QAVG_POOL2D_NCHW(benchmark::State& state) {
  const auto float_input = torch::rand({1, state.range(0), 56, 56}, device);
  const auto quantized_input = torch::quantize_per_tensor(
      float_input, /*scale=*/0.5, /*zero_point=*/1, torch::kQUInt8);
  torch::Tensor pooled;
  for (auto _ : state) {
    pooled = torch::nn::functional::avg_pool2d(
        quantized_input,
        torch::nn::AvgPool2dOptions({4, 4})
            .ceil_mode(true)
            .count_include_pad(true));
  }
}

// Benchmark: quint8 avg_pool2d (4x4 window, ceil_mode, count_include_pad) on
// a {1, state.range(0), 56, 56} input converted to the ChannelsLast memory
// format; the thread count is left untouched.
static void BM_TORCH_QAVG_POOL2D_NHWC(benchmark::State& state) {
  const auto float_input = torch::rand({1, state.range(0), 56, 56}, device);
  const auto quantized_nchw = torch::quantize_per_tensor(
      float_input, /*scale=*/0.5, /*zero_point=*/1, torch::kQUInt8);
  const auto quantized_nhwc =
      quantized_nchw.contiguous(torch::MemoryFormat::ChannelsLast);
  torch::Tensor pooled;
  for (auto _ : state) {
    pooled = torch::nn::functional::avg_pool2d(
        quantized_nhwc,
        torch::nn::AvgPool2dOptions({4, 4})
            .ceil_mode(true)
            .count_include_pad(true));
  }
}

// Benchmark: quint8 avg_pool3d (5x5x5 window, ceil_mode, count_include_pad)
// on a {1, state.range(0), 56, 56, 4} input in its default layout, with the
// thread count set to 1 before the input is created.
static void BM_TORCH_QAVG_POOL3D_NCDHW_SINGLE_THREADED(
    benchmark::State& state) {
  torch::init_num_threads();
  torch::set_num_threads(1);
  const auto float_input =
      torch::rand({1, state.range(0), 56, 56, 4}, device);
  const auto quantized_input = torch::quantize_per_tensor(
      float_input, /*scale=*/0.5, /*zero_point=*/1, torch::kQUInt8);
  torch::Tensor pooled;
  for (auto _ : state) {
    pooled = torch::nn::functional::avg_pool3d(
        quantized_input,
        torch::nn::AvgPool3dOptions({5, 5, 5})
            .ceil_mode(true)
            .count_include_pad(true));
  }
}

// Benchmark: quint8 avg_pool3d (5x5x5 window, ceil_mode, count_include_pad)
// on a {1, state.range(0), 56, 56, 4} input converted to the ChannelsLast3d
// memory format, with the thread count set to 1 before the input is created.
static void BM_TORCH_QAVG_POOL3D_NDHWC_SINGLE_THREADED(
    benchmark::State& state) {
  torch::init_num_threads();
  torch::set_num_threads(1);
  const auto float_input =
      torch::rand({1, state.range(0), 56, 56, 4}, device);
  const auto quantized_ncdhw = torch::quantize_per_tensor(
      float_input, /*scale=*/0.5, /*zero_point=*/1, torch::kQUInt8);
  const auto quantized_ndhwc =
      quantized_ncdhw.contiguous(torch::MemoryFormat::ChannelsLast3d);
  torch::Tensor pooled;
  for (auto _ : state) {
    pooled = torch::nn::functional::avg_pool3d(
        quantized_ndhwc,
        torch::nn::AvgPool3dOptions({5, 5, 5})
            .ceil_mode(true)
            .count_include_pad(true));
  }
}

// Benchmark: quint8 avg_pool3d (5x5x5 window, ceil_mode, count_include_pad)
// on a {1, state.range(0), 56, 56, 4} input in its default layout; the
// thread count is left untouched.
static void BM_TORCH_QAVG_POOL3D_NCDHW(benchmark::State& state) {
  const auto float_input =
      torch::rand({1, state.range(0), 56, 56, 4}, device);
  const auto quantized_input = torch::quantize_per_tensor(
      float_input, /*scale=*/0.5, /*zero_point=*/1, torch::kQUInt8);
  torch::Tensor pooled;
  for (auto _ : state) {
    pooled = torch::nn::functional::avg_pool3d(
        quantized_input,
        torch::nn::AvgPool3dOptions({5, 5, 5})
            .ceil_mode(true)
            .count_include_pad(true));
  }
}

// Benchmark: quint8 avg_pool3d (5x5x5 window, ceil_mode, count_include_pad)
// on a {1, state.range(0), 56, 56, 4} input converted to the ChannelsLast3d
// memory format; the thread count is left untouched.
static void BM_TORCH_QAVG_POOL3D_NDHWC(benchmark::State& state) {
  const auto float_input =
      torch::rand({1, state.range(0), 56, 56, 4}, device);
  const auto quantized_ncdhw = torch::quantize_per_tensor(
      float_input, /*scale=*/0.5, /*zero_point=*/1, torch::kQUInt8);
  const auto quantized_ndhwc =
      quantized_ncdhw.contiguous(torch::MemoryFormat::ChannelsLast3d);
  torch::Tensor pooled;
  for (auto _ : state) {
    pooled = torch::nn::functional::avg_pool3d(
        quantized_ndhwc,
        torch::nn::AvgPool3dOptions({5, 5, 5})
            .ceil_mode(true)
            .count_include_pad(true));
  }
}

// Register every benchmark with the same argument sweep; state.range(0) is
// used by each benchmark as the channel count of its input.
// NOTE(review): RangeMultiplier(8)->Range(4, 256) is expected to yield the
// channel counts 4, 8, 64 and 256 — confirm against the Google Benchmark
// documentation for Range.
// The *_SINGLE_THREADED variants call torch::set_num_threads(1) themselves
// and are registered after the variants that leave the thread count alone.
BENCHMARK(BM_TORCH_QAVG_POOL2D_NCHW)->RangeMultiplier(8)->Range(4, 256);
BENCHMARK(BM_TORCH_QAVG_POOL2D_NHWC)->RangeMultiplier(8)->Range(4, 256);
BENCHMARK(BM_TORCH_QAVG_POOL3D_NCDHW)->RangeMultiplier(8)->Range(4, 256);
BENCHMARK(BM_TORCH_QAVG_POOL3D_NDHWC)->RangeMultiplier(8)->Range(4, 256);
BENCHMARK(BM_TORCH_QAVG_POOL2D_NCHW_SINGLE_THREADED)
    ->RangeMultiplier(8)
    ->Range(4, 256);
BENCHMARK(BM_TORCH_QAVG_POOL2D_NHWC_SINGLE_THREADED)
    ->RangeMultiplier(8)
    ->Range(4, 256);
BENCHMARK(BM_TORCH_QAVG_POOL3D_NCDHW_SINGLE_THREADED)
    ->RangeMultiplier(8)
    ->Range(4, 256);
BENCHMARK(BM_TORCH_QAVG_POOL3D_NDHWC_SINGLE_THREADED)
    ->RangeMultiplier(8)
    ->Range(4, 256);
BENCHMARK_MAIN();
```

3. `mkdir build && cd build`
4. ```cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` .. ```
5. `cmake --build . --config Release`
6. `./time_average_pool`

# Further notes
- I've used `istrideB, istrideD, istrideH, strideW, strideC` to match `_qadaptive_avg_pool_kernel` since there's some code duplication there as mentioned in pytorch#40316.

Pull Request resolved: pytorch#42009

Reviewed By: pbelevich

Differential Revision: D22794441

Pulled By: z-a-f

fbshipit-source-id: 16710202811a1fbe1c99ea4d9b45876d6d28a8da
  • Loading branch information
thinking-tower authored and facebook-github-bot committed Aug 6, 2020
1 parent 9add11f commit 04d7e16
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 118 deletions.
219 changes: 107 additions & 112 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1441,107 +1441,8 @@ void qadaptive_avg_pool3d_ndhwc_kernel(
istrideW);
}

// Quantized 2-D average pooling for one batch element, indexing both buffers
// with the channel as the innermost (fastest-varying) dimension.
// Reads batch `b` of `qx` and writes the pooled, requantized values into the
// matching region of `qy`.
//   nInputPlane             - channel count
//   inputWidth/inputHeight  - spatial extent of the input
//   outputWidth/outputHeight- spatial extent of the output
//   kW/kH, dW/dH, padW/padH - window size, stride and padding per axis
//   count_include_pad       - divide by the padded window size instead of the
//                             number of in-bounds elements
//   divisor_override        - when set to a non-zero value, replaces the
//                             divisor (0 is treated the same as "not set")
void qavg_pool2d_nhwc_kernel(
const Tensor& qx,
Tensor& qy,
int64_t b,
int64_t nInputPlane,
int64_t inputWidth,
int64_t inputHeight,
int64_t outputWidth,
int64_t outputHeight,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool2d_nhwc", [&]() {
scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
// Number of elements in one batch entry of the input.
int64_t batch_size = nInputPlane * inputWidth * inputHeight;
// View the selected batch entry through the underlying integer type.
auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(
idata + b * batch_size);

// lift these operations outside the loop to reduce access overheads
float input_scale = qx.q_scale();
float output_scale = qy.q_scale();
int input_zero_point = qx.q_zero_point();
int output_zero_point = qy.q_zero_point();
// 0 doubles as the "no override" sentinel further down.
int64_t divisor_override_factor =
divisor_override.has_value() ? divisor_override.value() : 0;

for (int64_t oh = 0; oh < outputHeight; oh++) {
for (int64_t ow = 0; ow < outputWidth; ow++) {
// Start of the channel vector for output position (oh, ow).
auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
odata + b * nInputPlane * outputWidth * outputHeight +
(oh * outputWidth + ow) * nInputPlane);
int64_t hstart = oh * dH - padH;
int64_t wstart = ow * dW - padW;
int64_t hend = std::min(hstart + kH, inputHeight + padH);
int64_t wend = std::min(wstart + kW, inputWidth + padW);
// Window size including padded positions (computed before clamping to
// the real input bounds).
int64_t pool_size = (hend - hstart) * (wend - wstart);
hstart = std::max(hstart, (int64_t)0);
wstart = std::max(wstart, (int64_t)0);
hend = std::min(hend, inputHeight);
wend = std::min(wend, inputWidth);

// Number of in-bounds input elements covered by the window.
int size = (hend - hstart) * (wend - wstart);
int divide_size = count_include_pad ? pool_size : size;
int divide_factor =
divisor_override_factor ? divisor_override_factor : divide_size;
float multiplier = input_scale / output_scale / divide_factor;
// Seeding the accumulator with -zero_point * size makes the sum below
// equal to the sum of (val - input_zero_point) over the window.
int input_zero_point_m_size = -input_zero_point * size;

int64_t c = 0;
// For int8 quantization, we implicitly use int32 as accumulation
// Or else, it will go to the slow path
// TODO: support 16bit, 32bit, and etc.
// NOTE(review): the helper presumably advances `c` past the channels it
// processed, since `c` is reused as the start of the scalar loop below —
// confirm against do_avg_pool_on_AVX2.
do_avg_pool_on_AVX2<scalar_t>(
i_p,
o_p,
c,
nInputPlane,
nInputPlane,
input_zero_point_m_size,
output_zero_point,
multiplier,
0,
1,
hstart,
hend,
wstart,
wend,
1,
1,
inputWidth,
1);
// 1) The following loop handles the remaining channels
// 2) It also handles the Non-AVX2 path
for (; c < nInputPlane; ++c) {
int32_t acc_int32 = input_zero_point_m_size;
int64_t tcntr = 0;
for (int64_t ih = hstart; ih < hend; ih++) {
for (int64_t iw = wstart; iw < wend; iw++) {
// Flat spatial index; multiplied by the channel count to reach the
// channel vector of input position (ih, iw).
tcntr = ih * inputWidth + iw;
auto val = *(i_p + tcntr * nInputPlane + c);
acc_int32 += val;
}
}
double acc_fp = acc_int32 * 1.0;
// clamp
// NOTE(review): passing 1 / multiplier as the scale presumably makes
// quantize_val compute acc * multiplier + output_zero_point — confirm.
o_p[c] = at::native::quantize_val<scalar_t>(
1.0f / multiplier, output_zero_point, acc_fp)
.val_;
} // c
} // ow
} // oh
});
}

void qavg_pool3d_nhwc_kernel(
void _qavg_pool_nhwc_kernel(
const std::string& fn_name,
const Tensor& qx,
Tensor& qy,
int64_t b,
Expand All @@ -1563,12 +1464,19 @@ void qavg_pool3d_nhwc_kernel(
int padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool3d_nhwc", [&]() {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), fn_name, [&]() {
scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
int batch_size = nInputPlane * inputWidth * inputHeight * inputDepth;
auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(
idata + b * batch_size);
int strideC = 1;
int strideW = strideC * nInputPlane;
int istrideH = strideW * inputWidth;
int istrideD = istrideH * inputHeight;
int istrideB = istrideD * inputDepth;
int ostrideH = strideW * outputWidth;
int ostrideD = ostrideH * outputHeight;
int ostrideB = ostrideD * outputDepth;
auto* i_p =
reinterpret_cast<typename scalar_t::underlying*>(idata + b * istrideB);

// lift these operations outside the loop to reduce access overheads
float input_scale = qx.q_scale();
Expand All @@ -1582,10 +1490,8 @@ void qavg_pool3d_nhwc_kernel(
for (int oh = 0; oh < outputHeight; oh++) {
for (int ow = 0; ow < outputWidth; ow++) {
auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
odata +
b * nInputPlane * outputWidth * outputHeight * outputDepth +
(od * outputHeight * outputWidth + oh * outputWidth + ow) *
nInputPlane);
odata + b * ostrideB + od * ostrideD + oh * ostrideH +
ow * strideW);
int dstart = od * dD - padD;
int hstart = oh * dH - padH;
int wstart = ow * dW - padW;
Expand Down Expand Up @@ -1636,12 +1542,12 @@ void qavg_pool3d_nhwc_kernel(
// 2) It also handles the Non-AVX2 path
for (int c = c_start; c < nInputPlane; ++c) {
int32_t acc_int32 = input_zero_point_m_size;
int64_t tcntr = 0;
for (int64_t id = dstart; id < dend; id++) {
for (int64_t ih = hstart; ih < hend; ih++) {
for (int64_t iw = wstart; iw < wend; iw++) {
tcntr = id * inputHeight * inputWidth + ih * inputWidth + iw;
auto val = *(i_p + tcntr * nInputPlane + c);
auto val =
*(i_p + id * istrideD + ih * istrideH + iw * strideW +
c * strideC);
acc_int32 += val;
}
}
Expand All @@ -1658,6 +1564,95 @@ void qavg_pool3d_nhwc_kernel(
});
}

// 2-D entry point: delegates to the shared _qavg_pool_nhwc_kernel under the
// dispatch name "avg_pool2d_nhwc". The third slot of each size / kernel /
// stride / padding group is filled with a constant (1, or 0 for padding).
// NOTE(review): those slots line up with the depth arguments that
// qavg_pool3d_nhwc_kernel passes in the same positions — confirm against the
// helper's full parameter list.
void qavg_pool2d_nhwc_kernel(
    const Tensor& qx,
    Tensor& qy,
    int64_t b,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t outputWidth,
    int64_t outputHeight,
    int kW,
    int kH,
    int dW,
    int dH,
    int padW,
    int padH,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override) {
  _qavg_pool_nhwc_kernel(
      "avg_pool2d_nhwc",
      qx, qy, b, nInputPlane,
      inputWidth, inputHeight, 1,    // input extents
      outputWidth, outputHeight, 1,  // output extents
      kW, kH, 1,                     // kernel window
      dW, dH, 1,                     // strides
      padW, padH, 0,                 // padding
      count_include_pad,
      divisor_override);
}

// 3-D entry point: delegates to the shared _qavg_pool_nhwc_kernel under the
// dispatch name "avg_pool3d_nhwc", forwarding every argument unchanged and in
// the same order as this function's own parameter list.
void qavg_pool3d_nhwc_kernel(
    const Tensor& qx,
    Tensor& qy,
    int64_t b,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t inputDepth,
    int64_t outputWidth,
    int64_t outputHeight,
    int64_t outputDepth,
    int kW,
    int kH,
    int kD,
    int dW,
    int dH,
    int dD,
    int padW,
    int padH,
    int padD,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override) {
  _qavg_pool_nhwc_kernel(
      "avg_pool3d_nhwc",
      qx, qy, b, nInputPlane,
      inputWidth, inputHeight, inputDepth,     // input extents
      outputWidth, outputHeight, outputDepth,  // output extents
      kW, kH, kD,                              // kernel window
      dW, dH, dD,                              // strides
      padW, padH, padD,                        // padding
      count_include_pad,
      divisor_override);
}

template <typename T>
int64_t do_quantized_bilinear_on_AVX2(
const typename T::underlying*& pos1,
Expand Down
12 changes: 6 additions & 6 deletions test/quantization/test_quantized_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@ def test_avg_pool2d(self, X, kernel, stride, padding, ceil_mode, count_include_p
dtype=torch_type)

self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
msg=error_message.format(name, qX_hat.int_repr(), qX_ref.int_repr()))
msg=error_message.format(name, qX_ref.int_repr(), qX_hat.int_repr()))
self.assertEqual(scale, qX_hat.q_scale(),
msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
self.assertEqual(zero_point, qX_hat.q_zero_point(),
Expand Down Expand Up @@ -1117,7 +1117,7 @@ def test_avg_pool2d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_incl
dtype=torch_type)

self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
msg=error_message.format(name, X_hat.int_repr(), qX_ref.int_repr()))
msg=error_message.format(name, qX_ref.int_repr(), X_hat.int_repr()))
self.assertEqual(scale, X_hat.q_scale(),
msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
self.assertEqual(zero_point, X_hat.q_zero_point(),
Expand Down Expand Up @@ -1169,7 +1169,7 @@ def test_avg_pool3d(self, X, kernel, stride, padding, ceil_mode, count_include_p
qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
dtype=torch_type)
self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
msg=error_message.format(name, qX_hat.int_repr(), qX_ref.int_repr()))
msg=error_message.format(name, qX_ref.int_repr(), qX_hat.int_repr()))
self.assertEqual(scale, qX_hat.q_scale(),
msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
self.assertEqual(zero_point, qX_hat.q_zero_point(),
Expand Down Expand Up @@ -1233,7 +1233,7 @@ def test_avg_pool3d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_incl
dtype=torch_type)

self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
msg=error_message.format(name, X_hat.int_repr(), qX_ref.int_repr()))
msg=error_message.format(name, qX_ref.int_repr(), X_hat.int_repr()))
self.assertEqual(scale, X_hat.q_scale(),
msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
self.assertEqual(zero_point, X_hat.q_zero_point(),
Expand Down Expand Up @@ -1290,7 +1290,7 @@ def test_adaptive_avg_pool2d_nhwc(self, X, output_size_h, output_size_w):
self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
self.assertEqualIgnoreType(X_ref, X_hat.int_repr(), atol=1.0, rtol=0,
msg="{} results are off".format(name))
msg=error_message.format(name, X_ref, X_hat.int_repr()))
self.assertEqual(scale, X_hat.q_scale(),
msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
self.assertEqual(zero_point, X_hat.q_zero_point(),
Expand Down Expand Up @@ -1416,7 +1416,7 @@ def test_adaptive_avg_pool3d_ndhwc(self, X, output_size_d, output_size_h,
self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
self.assertEqualIgnoreType(X_ref, X_hat.int_repr(), atol=1.0, rtol=0,
msg="{} results are off".format(name))
msg=error_message.format(name, X_ref, X_hat.int_repr()))
self.assertEqual(scale, X_hat.q_scale(),
msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
self.assertEqual(zero_point, X_hat.q_zero_point(),
Expand Down

0 comments on commit 04d7e16

Please sign in to comment.