Skip to content

Commit

Permalink
[TorchFX] PTQ MinMax algorithm test (#3000)
Browse files Browse the repository at this point in the history
### Changes

1. Added a new test file in tests/torch/fx. Implemented
`TemplateTestMinMaxAlgorithm` for Torch Fx backend.

2. Fixed a docstring mistake in
`tests/cross_fw/test_templates/test_min_max.py`. Removed the line of
code in the test `test_get_channel_axes_matmul_torch` that assigns the
matmul node's layer attributes.

3. Made changes to `tests/torch/ptq/test_min_max.py` to pass the
functions that handle backend-specific layer attributes.

### Reason for changes

For changes 2 and 3, only the node metatypes are required to test target
point shape and weight quantization channel axes.

### Related tickets

#2872
  • Loading branch information
rk119 authored Oct 7, 2024
1 parent 174bd03 commit 63669db
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 48 deletions.
5 changes: 2 additions & 3 deletions tests/cross_fw/test_templates/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class TemplateTestMinMaxAlgorithm:
@abstractmethod
def backend(self) -> MinMaxAlgoBackend:
"""
Get backend specific BiasCorrectionAlgoBackend
Get backend specific MinMaxAlgoBackend
:return BiasCorrectionAlgoBackend: Backend specific BiasCorrectionAlgoBackend
:return MinMaxAlgoBackend: Backend specific MinMaxAlgoBackend
"""

@property
Expand Down Expand Up @@ -164,7 +164,6 @@ def test_get_channel_axes_matmul_torch(self, weight_shape, ref_axes):
Checks MatMul quantization axes in MinMax for Torch.
"""
matmul_node = NNCFNode({"metatype": self.matmul_metatype})
matmul_node.layer_attributes = self.get_matmul_node_attrs(None, None, weight_shape)

class DummyTargetPoint:
input_port_id = 0
Expand Down
85 changes: 85 additions & 0 deletions tests/torch/fx/test_min_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import pytest

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.transformations.commands import TargetType
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import PTConstNoopMetatype
from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTLinearMetatype
from nncf.torch.graph.transformations.commands import PTTargetPoint
from tests.cross_fw.test_templates.models import NNCFGraphToTest
from tests.cross_fw.test_templates.test_min_max import TemplateTestGetChannelAxes
from tests.cross_fw.test_templates.test_min_max import TemplateTestGetTargetPointShape
from tests.cross_fw.test_templates.test_min_max import TemplateTestMinMaxAlgorithm


class TestTorchFXMinMaxAlgorithm(TemplateTestMinMaxAlgorithm):
    """MinMax algorithm template test specialized for the Torch FX backend."""

    @property
    def backend(self) -> MinMaxAlgoBackend:
        # Torch FX flavor of the MinMax algorithm backend.
        return FXMinMaxAlgoBackend

    @property
    def conv_metatype(self):
        # Regular 2D convolution metatype used by the template tests.
        return PTConv2dMetatype

    def create_target_point(self, target_point_type: TargetType, name: str, port_id: int) -> PTTargetPoint:
        """Build a PTTargetPoint; post-layer operations carry no input port."""
        effective_port = None if target_point_type == TargetType.POST_LAYER_OPERATION else port_id
        return PTTargetPoint(target_point_type, name, input_port_id=effective_port)


class TestTorchFXGetTargetPointShape(TemplateTestGetTargetPointShape, TestTorchFXMinMaxAlgorithm):
    """Target point shape template tests for the Torch FX backend."""

    def get_nncf_graph(self, weight_port_id: int, weight_shape: Tuple[int]) -> NNCFGraph:
        # Only the metatypes matter here; weight_port_id and weight_shape are not
        # needed to build the test graph for this backend.
        graph_builder = NNCFGraphToTest(
            conv_metatype=PTConv2dMetatype,
            nncf_graph_cls=PTNNCFGraph,
            const_metatype=PTConstNoopMetatype,
        )
        return graph_builder.nncf_graph


class TestTorchFXGetChannelAxes(TemplateTestGetChannelAxes, TestTorchFXMinMaxAlgorithm):
    """Weight quantization channel-axes template tests for the Torch FX backend.

    The Torch FX backend derives channel axes from node metatypes alone, so all
    of the ``get_*_node_attrs`` hooks intentionally return ``None``.
    """

    @property
    def depthwiseconv_metatype(self):
        return PTDepthwiseConv2dSubtype

    @property
    def matmul_metatype(self):
        return PTLinearMetatype

    @staticmethod
    def get_conv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> BaseLayerAttributes:
        # Layer attributes are not required by the Torch FX backend.
        return None

    @staticmethod
    def get_depthwiseconv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> BaseLayerAttributes:
        # Layer attributes are not required by the Torch FX backend.
        return None

    @staticmethod
    def get_matmul_node_attrs(
        weight_port_id: int, transpose_weight: Tuple[int], weight_shape: Tuple[int]
    ) -> BaseLayerAttributes:
        # Layer attributes are not required by the Torch FX backend.
        return None

    def test_get_channel_axes_matmul_node_ov_onnx(self):
        # OV/ONNX-specific matmul axes case does not apply to this backend.
        pytest.skip("Test is not applied for Torch FX backend.")

    def test_get_channel_axes_deptwiseconv_node_ov(self):
        # OV-specific depthwise-conv axes case does not apply to this backend.
        pytest.skip("Test is not applied for Torch FX backend.")
58 changes: 13 additions & 45 deletions tests/torch/ptq/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
import pytest

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.layer_attributes import ConstantLayerAttributes
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.layer_attributes import LinearLayerAttributes
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.transformations.commands import TargetType
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend
Expand Down Expand Up @@ -48,23 +46,8 @@ def create_target_point(self, target_point_type: TargetType, name: str, port_id:

class TestTorchGetTargetPointShape(TemplateTestGetTargetPointShape, TestTorchMinMaxAlgorithm):
def get_nncf_graph(self, weight_port_id: int, weight_shape: Tuple[int]) -> NNCFGraph:
conv_layer_attrs = ConvolutionLayerAttributes(
weight_requires_grad=True,
in_channels=weight_shape[1],
out_channels=weight_shape[0],
kernel_size=weight_shape[2:],
stride=1,
dilations=1,
groups=1,
transpose=False,
padding_values=[],
)
return NNCFGraphToTest(
PTConv2dMetatype,
conv_layer_attrs,
PTNNCFGraph,
const_metatype=PTConstNoopMetatype,
const_layer_attrs=ConstantLayerAttributes("w", shape=weight_shape),
conv_metatype=PTConv2dMetatype, nncf_graph_cls=PTNNCFGraph, const_metatype=PTConstNoopMetatype
).nncf_graph


Expand All @@ -78,36 +61,21 @@ def matmul_metatype(self):
return PTLinearMetatype

@staticmethod
def get_conv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> ConvolutionLayerAttributes:
return ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=weight_shape[0],
out_channels=weight_shape[1],
kernel_size=weight_shape[2:],
stride=1,
dilations=1,
groups=1,
transpose=False,
padding_values=[],
)
def get_conv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> BaseLayerAttributes:
# This method isn't needed for Torch backend
return None

@staticmethod
def get_depthwiseconv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> ConvolutionLayerAttributes:
return ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=weight_shape[1],
out_channels=weight_shape[2],
kernel_size=weight_shape[3:],
stride=1,
dilations=1,
groups=weight_shape[0],
transpose=False,
padding_values=[],
)
def get_depthwiseconv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> BaseLayerAttributes:
# This method isn't needed for Torch backend
return None

@staticmethod
def get_matmul_node_attrs(weight_port_id: int, transpose_weight: Tuple[int], weight_shape: Tuple[int]):
return LinearLayerAttributes(False, in_features=weight_shape[0], out_features=weight_shape[1])
def get_matmul_node_attrs(
weight_port_id: int, transpose_weight: Tuple[int], weight_shape: Tuple[int]
) -> BaseLayerAttributes:
# This method isn't needed for Torch backend
return None

def test_get_channel_axes_matmul_node_ov_onnx(self):
pytest.skip("Test is not applied for Torch backend.")
Expand Down

0 comments on commit 63669db

Please sign in to comment.