[quant][fx] Support some default ops in the native backend config (pytorch#74600)

Summary:
Pull Request resolved: pytorch#74600

Following pytorch#74210, this PR adds support for some ops handled by the DefaultNodeQuantizeHandler
to the backend_config_dict definition for the PyTorch native backend.

TODO: There are still a few ops that we did not handle through the backend_config_dict path: gelu and softmax. We need to discuss whether we still need them; if so, we can change the test
to use backend_config_dict and remove the DefaultNodeQuantizeHandler after that.
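
As background, the merging between the legacy pattern registry and the new backend_config_dict path can be sketched as below. This is a minimal illustration mirroring the get_all_quant_patterns helper added in this PR (and the equivalent logic in torch/ao/ns/fx/pattern_utils.py); it is not part of the commit itself:

    # Sketch: merge the patterns that moved into the native backend_config_dict
    # back into the remaining decorator-registered default quantization patterns.
    from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns
    from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
    from torch.ao.quantization.fx.backend_config.utils import get_pattern_to_quantize_handlers

    all_quant_patterns = get_default_quant_patterns()
    # patterns that now live in the native backend_config_dict (e.g. nn.ELU, F.hardswish)
    native_patterns = get_pattern_to_quantize_handlers(get_native_backend_config_dict())
    all_quant_patterns.update(native_patterns)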

Test Plan:
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: andrewor14

Differential Revision: D35071437

fbshipit-source-id: 70351d2810ca1ac7dc09d4a9c239f6757ccb51ca
(cherry picked from commit 5e68f75)
jerryzh168 authored and pytorchmergebot committed Mar 25, 2022
1 parent 797fa26 commit b347b8c
Showing 6 changed files with 154 additions and 94 deletions.
46 changes: 35 additions & 11 deletions test/quantization/fx/test_numeric_suite_fx.py
@@ -71,6 +71,8 @@
extract_shadow_logger_info,
extend_logger_results_with_comparison,
)
from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
from torch.ao.quantization.fx.backend_config.utils import get_pattern_to_quantize_handlers


# Note: these models are not for use outside of this file. While it's good
@@ -274,7 +276,19 @@ def _wrapped_sigmoid(x):
def _wrapped_linear(x, w, b):
return F.linear(x, w, b)


def get_all_quant_patterns():
""" we are in the process to migrate the frontend of fx graph mode quant
to use backend_config_dict, so some of the patterns are moved to backend_config_dict
this function will include these patterns so that we can still have all the patterns
"""
# TODO: we can remove this call, and get all patterns from backend_config_dict in
# the future when the frontend refactor is done in fx graph mode quantization
all_quant_patterns = get_default_quant_patterns()
# some of the patterns are moved to (native) backend_config_dict so we need to
# add them back here
for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config_dict()).items():
all_quant_patterns[pattern] = quantize_handler
return all_quant_patterns

class TestFXGraphMatcher(QuantizationTestCase):

@@ -542,7 +556,6 @@ def forward(self, x):
self.assert_types_for_matched_subgraph_pairs(
results, expected_types, m1p, m2p)


def test_op_relationship_mapping(self):
"""
Tests that the mapping of op relationships is complete.
@@ -620,7 +633,7 @@ def _op_is_unmatchable(op):
op in METHS_UNMATCHABLE
)

default_quant_patterns = get_default_quant_patterns()
default_quant_patterns = get_all_quant_patterns()
for pattern, qhandler_cls in default_quant_patterns.items():
base_op = None
if isinstance(pattern, tuple):
@@ -664,9 +677,6 @@ def _op_is_unmatchable(op):
# RNNDynamicQuantizeHandler
pass
elif qhandler_cls == qp.DefaultNodeQuantizeHandler:
# torch.sum does not have quantized equivalents
if base_op == torch.sum:
continue
self.assertTrue(
_op_in_base_sets_of_related_ops(base_op),
f"{base_op} not in sets of related ops")
@@ -682,8 +692,14 @@ def _op_is_unmatchable(op):
_op_in_base_sets_of_related_ops(base_op),
f"{base_op} not in sets of related ops")
else:
raise AssertionError(
f"handing for {qhandler_cls} not implemented")
# torch.sum does not have quantized equivalents
if base_op == torch.sum:
continue
# didn't match an explicit quantize handler class, so check whether the
# operator is in the related op set directly
if not _op_in_base_sets_of_related_ops(base_op):
raise AssertionError(
f"handling for {qhandler_cls} for op {base_op} not implemented")

@skipIfNoFBGEMM
def test_user_defined_function(self):
@@ -1534,7 +1550,7 @@ def test_op_io_dtype_coverage(self):

# 4. go through the ops mapped to each QuantizeHandler type, and verify
# correctness.
default_quant_patterns = get_default_quant_patterns()
default_quant_patterns = get_all_quant_patterns()
for pattern, qhandler_cls in default_quant_patterns.items():
base_op = None
if isinstance(pattern, tuple):
@@ -1591,8 +1607,16 @@ def test_op_io_dtype_coverage(self):
# embedding shadowing is not implemented, for now
continue
else:
raise AssertionError(
f"handing for {qhandler_cls} not implemented")
if qhandler_cls(None, {}).is_general_tensor_value_op():
self.assertTrue(
(base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
(base_op in MODS_IO_TYPE_FP32_OR_INT8) or
(base_op in METHS_IO_TYPE_FP32_OR_INT8),
f"missing IO type handling for {base_op} using {qhandler_cls}")
else:
self.assertTrue(
(base_op in FUNS_IO_TYPE_FP32) or (base_op in MODS_IO_TYPE_FP32),
f"missing IO type handling for {base_op} using {qhandler_cls}")

@skipIfNoFBGEMM
def test_user_defined_function(self):
12 changes: 8 additions & 4 deletions test/quantization/fx/test_quantize_fx.py
@@ -4073,10 +4073,13 @@ class DummyQuant3():
default_affine_fixed_qparams_fake_quant)
self._assertFixedQParamsFakeQuantizeEqual(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"],
default_symmetric_fixed_qparams_fake_quant)
self.assertTrue(get_default_output_activation_post_process_map(is_training=True) is
DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
self.assertTrue(get_default_output_activation_post_process_map(is_training=False) is
DEFAULT_OUTPUT_OBSERVER_MAP)
output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True)
output_observer_map = get_default_output_activation_post_process_map(is_training=False)
self.assertEqual(output_observer_map.get("dummy_quant3"), default_symmetric_fixed_qparams_observer)
self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"),
default_symmetric_fixed_qparams_fake_quant)



def test_reuse_input_qconfig(self):
class M1(torch.nn.Module):
@@ -5412,6 +5415,7 @@ def test_gelu_reference(self):
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize')
]
# TODO: change these to use backend_config_dict
additional_patterns = {torch.nn.GELU: DefaultNodeQuantizeHandler,
torch.nn.functional.gelu: DefaultNodeQuantizeHandler}
self._test_default_node_quant_handler_ops(
8 changes: 8 additions & 0 deletions torch/ao/ns/fx/pattern_utils.py
@@ -9,6 +9,8 @@
from torch.ao.quantization.utils import getattr_from_fqn
from .ns_types import NSNodeTargetType
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns
from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
from torch.ao.quantization.fx.backend_config.utils import get_pattern_to_quantize_handlers
from torch.ao.quantization import (
ObserverBase,
FakeQuantizeBase,
@@ -66,7 +68,13 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
# * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
# For fusions, we only care about patterns composed of multiple ops.
# TODO(future PR): allow customizations from default patterns.
# TODO: we can remove this call, and get all patterns from backend_config_dict in
# the future when the frontend refactor is done in fx graph mode quantization
all_quant_patterns = get_default_quant_patterns()
# some of the patterns are moved to (native) backend_config_dict so we need to
# add them back here
for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config_dict()).items():
all_quant_patterns[pattern] = quantize_handler
default_base_op_idx = 0
for quant_pattern, _quant_handler in all_quant_patterns.items():
# Only patterns of multiple ops are fusions, ignore
125 changes: 94 additions & 31 deletions torch/ao/quantization/fx/backend_config/native.py
Expand Up @@ -2,43 +2,106 @@
from .observation_type import ObservationType
import torch.nn.qat as nnqat

def get_native_backend_config_dict():
""" Get backend for PyTorch Native backend_config_dict (fbgemm/qnnpack)
"""
# dtype configs

# weighted op int8 config
# activation: quint8, weight: qint8, bias: float
weighted_op_int8_dtype_config = {
# optional, input activation dtype
"input_dtype": torch.quint8,
# optional, weight dtype
"weight_dtype": torch.qint8,
# optional, bias dtype
"bias_dtype": torch.float,
# optional, output activation dtype
"output_dtype": torch.quint8
}
# operator (module/functional/torch ops) configs
linear_module_config = {
# Please see README under this folder for pattern format
"pattern": torch.nn.Linear,

def _get_default_op_backend_config(op, dtype_configs):
return {
"pattern": op,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_int8_dtype_config,
],
# the root module for the pattern, used to query the reference quantized module
# e.g. for a (torch.nn.ReLU, torch.nn.Linear) pattern, the root will be torch.nn.Linear
"root_module": torch.nn.Linear,
# the corresponding reference quantized module for the root module
"reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
"qat_module": nnqat.Linear,
"dtype_configs": dtype_configs,
}

# START dtype configs

# weighted op int8 dtype config
# this is the config for ops that have quantized weights, like linear and conv
weighted_op_int8_dtype_config = {
# optional, input activation dtype
"input_dtype": torch.quint8,
# optional, weight dtype
"weight_dtype": torch.qint8,
# optional, bias dtype
"bias_dtype": torch.float,
# optional, output activation dtype
"output_dtype": torch.quint8
}

default_op_quint8_dtype_config = {
# optional, input activation dtype
"input_dtype": torch.quint8,
# optional, output activation dtype
"output_dtype": torch.quint8,
}

default_op_fp16_dtype_config = {
# optional, input activation dtype
"input_dtype": torch.float16,
# optional, weight dtype
"weight_dtype": torch.float16,
# optional, output activation dtype
"output_dtype": torch.float16,
}
# END dtype configs

# operator (module/functional/torch ops) configs
_LINEAR_MODULE_CONFIG = {
# Please see README under this folder for pattern format
"pattern": torch.nn.Linear,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_int8_dtype_config,
],
# the root module for the pattern, used to query the reference quantized module
# e.g. for a (torch.nn.ReLU, torch.nn.Linear) pattern, the root will be torch.nn.Linear
"root_module": torch.nn.Linear,
# the corresponding reference quantized module for the root module
"reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
"qat_module": nnqat.Linear,
}

_DEFAULT_OP_INT8_CONFIGS = [
_get_default_op_backend_config(op, [default_op_quint8_dtype_config]) for op in [
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ELU,
torch.nn.LeakyReLU,
torch.nn.Hardswish,
torch.nn.InstanceNorm1d,
torch.nn.InstanceNorm2d,
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.Dropout,
torch.nn.functional.elu,
torch.nn.functional.hardswish,
torch.nn.functional.instance_norm,
torch.nn.functional.leaky_relu,
torch.nn.functional.dropout,
]]

_DEFAULT_OP_FP16_CONFIGS = [
_get_default_op_backend_config(op, [default_op_fp16_dtype_config]) for op in [
torch.nn.SiLU,
torch.nn.Mish,
torch.nn.functional.silu,
torch.nn.functional.mish,
torch.sum,
]]

_DEFAULT_OP_INT8_OR_FP16_CONFIGS = [
_get_default_op_backend_config(op, [default_op_quint8_dtype_config, default_op_fp16_dtype_config]) for op in [
torch.nn.LayerNorm,
torch.nn.functional.layer_norm,
]]

def get_native_backend_config_dict():
""" Get backend for PyTorch Native backend_config_dict (fbgemm/qnnpack)
"""
return {
# optional
"name": "native",
"configs": [
linear_module_config,
_LINEAR_MODULE_CONFIG,
*_DEFAULT_OP_INT8_CONFIGS,
*_DEFAULT_OP_FP16_CONFIGS,
*_DEFAULT_OP_INT8_OR_FP16_CONFIGS,
],
}
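
As an illustrative aside (not part of the diff), the resulting configuration can be inspected with a small sketch like the following, using only the keys shown above ("configs", "pattern", "dtype_configs", and the optional "input_dtype"):

    # Sketch: print which ops the native backend_config_dict now covers,
    # together with the input activation dtypes configured for each op.
    from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict

    backend_config_dict = get_native_backend_config_dict()
    for config in backend_config_dict["configs"]:
        input_dtypes = [d.get("input_dtype") for d in config["dtype_configs"]]
        print(config["pattern"], input_dtypes)
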
10 changes: 5 additions & 5 deletions torch/ao/quantization/fx/pattern_utils.py
@@ -8,7 +8,7 @@
from ..fake_quantize import FixedQParamsFakeQuantize
# from .quantization_patterns import BinaryOpQuantizeHandler
from ..observer import ObserverBase

import copy

# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any
@@ -25,7 +25,7 @@ def insert(fn):
return insert

def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
return DEFAULT_FUSION_PATTERNS
return copy.copy(DEFAULT_FUSION_PATTERNS)

DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()

@@ -47,15 +47,15 @@ def insert(fn):

# Get patterns for both static quantization and qat
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
return DEFAULT_QUANTIZATION_PATTERNS
return copy.copy(DEFAULT_QUANTIZATION_PATTERNS)

# a map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant
def get_default_output_activation_post_process_map(is_training) -> Dict[Pattern, ObserverBase]:
if is_training:
return DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP
return copy.copy(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
else:
return DEFAULT_OUTPUT_OBSERVER_MAP
return copy.copy(DEFAULT_OUTPUT_OBSERVER_MAP)

# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))