Fix sim.onnx.export bug in per-channel quantization (#3825)
Signed-off-by: Kyunggeun Lee <[email protected]>
quic-kyunggeu authored Feb 20, 2025
1 parent 78a4f7f commit 2b8f52a
Showing 2 changed files with 18 additions and 17 deletions.
File 1 of 2

@@ -225,21 +225,12 @@ def remove_quantization_nodes_from_onnx_graph(model: onnx.ModelProto):
     """
     tensor_to_encoding_map = {}
     name_to_producer, name_to_consumer = _get_producer_consumer_info_from_onnx_graph(model)
-    node_list = list(model.graph.node)
-
-    for node in node_list:
-        if node.op_type not in ONNX_QUANTIZER_OP_TYPES:
-            continue
+    qtzr_nodes = list(node for node in model.graph.node if node.op_type in ONNX_QUANTIZER_OP_TYPES)
 
+    for node in qtzr_nodes:
         # Get quantizer name in torch model
         encoding = _get_encoding_from_onnx_node(model, node)
 
-        # Remove qdq node from graph
-        model.graph.node.remove(node)
-
-        # Remove scale and offset from onnx graph
-        _remove_constants(model, node.input[1:])
-
         # Connect next node to the prev node of quantizer node
         if node.output[0] in name_to_consumer:
             tensor_to_encoding_map[node.input[0]] = encoding
@@ -265,6 +256,13 @@ def remove_quantization_nodes_from_onnx_graph(model: onnx.ModelProto):
         else:
             raise ValueError(f"Cannot find prev node and next node for quantization node {node.name}")
 
+    for node in qtzr_nodes:
+        # Remove qdq node from graph
+        model.graph.node.remove(node)
+
+        # Remove scale and offset from onnx graph
+        _remove_constants(model, node.input[1:])
+
     return tensor_to_encoding_map
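The substance of the fix is an ordering change: the function now collects encodings and rewires consumers for every quantizer node first, and only afterwards deletes the QDQ nodes and their scale/offset constants, so the graph is never inspected in a half-mutated state. Below is a minimal, self-contained sketch of that two-pass pattern; the toy Identity graph and every name in it are invented for illustration and are not AIMET code.

import onnx
from onnx import helper, TensorProto

# Toy graph: x -> Identity (stand-in for a QDQ node) -> Relu -> y
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
out = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])
qdq = helper.make_node("Identity", inputs=["x"], outputs=["x_q"], name="qdq0")
relu = helper.make_node("Relu", inputs=["x_q"], outputs=["y"], name="relu0")
model = helper.make_model(helper.make_graph([qdq, relu], "toy", [inp], [out]))

# Pass 1: decide which nodes to drop and rewire their consumers
# while the graph is still fully intact
to_remove = [n for n in model.graph.node if n.name == "qdq0"]
for node in to_remove:
    for consumer in model.graph.node:
        for i, tensor in enumerate(consumer.input):
            if tensor == node.output[0]:
                consumer.input[i] = node.input[0]

# Pass 2: only now mutate the node list
for node in to_remove:
    model.graph.node.remove(node)

onnx.checker.check_model(model)
print([n.op_type for n in model.graph.node])  # ['Relu']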
File 2 of 2

@@ -43,10 +43,11 @@
 import torch
 import onnx
 import tempfile
+from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
 from aimet_common import quantsim as quantsim_common
 import aimet_torch.v2 as aimet
 import aimet_torch.v2.quantization as Q
-from aimet_torch.v2.quantsim import quantsim, QuantizationSimModel
+from aimet_torch.v2.quantsim import QuantizationSimModel
 from torchvision.models import resnet18, mobilenet_v3_small
 from aimet_torch.v2.experimental.onnx._export import export as _export
 from aimet_torch.utils import get_all_quantizers
@@ -208,17 +209,19 @@ def forward(self, x: Q.QuantizedTensor):
 
 
 @torch.no_grad()
-@pytest.mark.parametrize("model_factory, input_shape", [(resnet18, (1, 3, 224, 224)),
-                                                        (mobilenet_v3_small, (1, 3, 224, 224)),
-                                                        ])
+@pytest.mark.parametrize(
+    "model_factory, input_shape", [
+        (resnet18, (1, 3, 224, 224)),
+        (mobilenet_v3_small, (1, 3, 224, 224)),
+    ])
 def test_export_torchvision_models(model_factory, input_shape):
     """
     When: Export quantized torchvision model
     """
     x = torch.randn(input_shape)
     model = model_factory().eval()
     model = prepare_model(model)
-    model = QuantizationSimModel(model, x).model
+    model = QuantizationSimModel(model, x, config_file=get_path_for_per_channel_config()).model
 
     with aimet.nn.compute_encodings(model):
         model(x)
@@ -276,7 +279,7 @@ def test_quantsim_export_torchvision_models(model_factory, input_shape, encoding
     x = torch.randn(input_shape)
     model = model_factory().eval()
     model = prepare_model(model)
-    sim = QuantizationSimModel(model, x)
+    sim = QuantizationSimModel(model, x, config_file=get_path_for_per_channel_config())
 
     sim.compute_encodings(lambda m, _: m(x), None)
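On the test side, both torchvision tests now build QuantizationSimModel with AIMET's packaged per-channel config file, so the export path is exercised with per-channel quantizers (a scale per output channel rather than a single scalar), the case the commit title names. A hedged usage sketch along the same lines, assuming an environment where the imports in the diff above resolve; the tests additionally pass the model through prepare_model first, which is omitted here:

import torch
from torchvision.models import resnet18
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.v2.quantsim import QuantizationSimModel

x = torch.randn(1, 3, 224, 224)
model = resnet18().eval()

# The per-channel config file switches weight quantizers to per-channel
# granularity, the code path whose ONNX export this commit fixes.
sim = QuantizationSimModel(model, x, config_file=get_path_for_per_channel_config())
sim.compute_encodings(lambda m, _: m(x), None)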
