diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/_export.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/_export.py
index 551b56b4768..a1ed887f7ac 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/_export.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/_export.py
@@ -225,21 +225,12 @@ def remove_quantization_nodes_from_onnx_graph(model: onnx.ModelProto):
     """
     tensor_to_encoding_map = {}
     name_to_producer, name_to_consumer = _get_producer_consumer_info_from_onnx_graph(model)
 
-    node_list = list(model.graph.node)
-
-    for node in node_list:
-        if node.op_type not in ONNX_QUANTIZER_OP_TYPES:
-            continue
+    qtzr_nodes = list(node for node in model.graph.node if node.op_type in ONNX_QUANTIZER_OP_TYPES)
+    for node in qtzr_nodes:
         # Get quantizer name in torch model
         encoding = _get_encoding_from_onnx_node(model, node)
 
-        # Remove qdq node from graph
-        model.graph.node.remove(node)
-
-        # Remove scale and offset from onnx graph
-        _remove_constants(model, node.input[1:])
-
         # Connect next node to the prev node of quantizer node
         if node.output[0] in name_to_consumer:
             tensor_to_encoding_map[node.input[0]] = encoding
@@ -265,6 +256,13 @@
         else:
             raise ValueError(f"Cannot find prev node and next node for quantization node {node.name}")
 
+    for node in qtzr_nodes:
+        # Remove qdq node from graph
+        model.graph.node.remove(node)
+
+        # Remove scale and offset from onnx graph
+        _remove_constants(model, node.input[1:])
+
     return tensor_to_encoding_map
 
 
diff --git a/TrainingExtensions/torch/test/python/v2/experimental/test_onnx.py b/TrainingExtensions/torch/test/python/v2/experimental/test_onnx.py
index af20b95abe7..8ca82ae98fe 100644
--- a/TrainingExtensions/torch/test/python/v2/experimental/test_onnx.py
+++ b/TrainingExtensions/torch/test/python/v2/experimental/test_onnx.py
@@ -43,10 +43,11 @@
 import torch
 import onnx
 import tempfile
+from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
 from aimet_common import quantsim as quantsim_common
 import aimet_torch.v2 as aimet
 import aimet_torch.v2.quantization as Q
-from aimet_torch.v2.quantsim import quantsim, QuantizationSimModel
+from aimet_torch.v2.quantsim import QuantizationSimModel
 from torchvision.models import resnet18, mobilenet_v3_small
 from aimet_torch.v2.experimental.onnx._export import export as _export
 from aimet_torch.utils import get_all_quantizers
@@ -208,9 +209,11 @@ def forward(self, x: Q.QuantizedTensor):
 
 
 @torch.no_grad()
-@pytest.mark.parametrize("model_factory, input_shape", [(resnet18, (1, 3, 224, 224)),
-                                                        (mobilenet_v3_small, (1, 3, 224, 224)),
-                                                        ])
+@pytest.mark.parametrize(
+    "model_factory, input_shape", [
+    (resnet18, (1, 3, 224, 224)),
+    (mobilenet_v3_small, (1, 3, 224, 224)),
+])
 def test_export_torchvision_models(model_factory, input_shape):
     """
     When: Export quantized torchvision model
@@ -218,7 +221,7 @@ def test_export_torchvision_models(model_factory, input_shape):
     """
     x = torch.randn(input_shape)
     model = model_factory().eval()
     model = prepare_model(model)
-    model = QuantizationSimModel(model, x).model
+    model = QuantizationSimModel(model, x, config_file=get_path_for_per_channel_config()).model
     with aimet.nn.compute_encodings(model):
         model(x)
@@ -276,7 +279,7 @@ def test_quantsim_export_torchvision_models(model_factory, input_shape, encoding
     x = torch.randn(input_shape)
     model = model_factory().eval()
     model = prepare_model(model)
-    sim = QuantizationSimModel(model, x)
+    sim = QuantizationSimModel(model, x, config_file=get_path_for_per_channel_config())
     sim.compute_encodings(lambda m, _: m(x), None)
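
Note (not part of the patch): the functional change in _export.py is that QDQ node removal now happens in a second pass, after every quantizer has been rewired, rather than mutating model.graph.node inside the same loop that still relies on the pre-computed producer/consumer maps. Below is a minimal, self-contained sketch of that two-pass pattern; the Identity stand-in op, the stand_in_qtzr_types set, and the inline rewiring loop are illustrative assumptions, not AIMET's actual implementation.

import onnx
from onnx import TensorProto, helper

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
graph = helper.make_graph(
    [
        helper.make_node("Identity", ["x"], ["q0"], name="qdq0"),   # stand-in "quantizer"
        helper.make_node("Identity", ["q0"], ["q1"], name="qdq1"),  # quantizer fed by another quantizer
        helper.make_node("Relu", ["q1"], ["y"], name="relu"),
    ],
    "two_pass_demo", [x], [y],
)
model = helper.make_model(graph)

stand_in_qtzr_types = {"Identity"}  # placeholder for ONNX_QUANTIZER_OP_TYPES
qtzr_nodes = [n for n in model.graph.node if n.op_type in stand_in_qtzr_types]

# Pass 1: rewire every consumer to bypass the quantizer, deleting nothing yet.
for node in qtzr_nodes:
    for consumer in model.graph.node:
        for i, name in enumerate(consumer.input):
            if name == node.output[0]:
                consumer.input[i] = node.input[0]

# Pass 2: only now physically remove the bypassed nodes from the graph.
for node in qtzr_nodes:
    model.graph.node.remove(node)

assert [n.op_type for n in model.graph.node] == ["Relu"]
assert list(model.graph.node[0].input) == ["x"]
onnx.checker.check_model(model)

Deferring the destructive edits means a quantizer whose input is produced by another quantizer is bypassed correctly: both nodes are still present while the graph is being rewired, so no lookup hits a node that has already been deleted.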