Fix rebase
vshampor committed Dec 27, 2023
1 parent 44390c9 commit 8c4ce45
Showing 4 changed files with 48 additions and 52 deletions.
@@ -38,7 +38,6 @@
from nncf.common.quantization.structs import UnifiedScaleType
from tests.common.quantization.metatypes import WEIGHT_LAYER_METATYPES
from tests.common.quantization.metatypes import CatTestMetatype
from tests.common.quantization.metatypes import Conv2dTestMetatype
from tests.common.quantization.mock_graphs import get_ip_graph_for_test
from tests.common.quantization.mock_graphs import get_mock_nncf_node_attrs
from tests.common.quantization.mock_graphs import get_nncf_graph_from_mock_nx_graph
81 changes: 40 additions & 41 deletions tests/common/quantization/test_quantizer_propagation_solver.py
@@ -22,9 +22,11 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph import NNCFNodeName
from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME, NNCFGraphNodeType
from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME
from nncf.common.graph.definitions import NNCFGraphNodeType
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import OutputNoopMetatype, OperatorMetatype
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.operator_metatypes import OutputNoopMetatype
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.hardware.config import HWConfig
@@ -43,11 +45,13 @@
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from tests.common.quantization.metatypes import DEFAULT_TEST_QUANT_TRAIT_MAP, CatTestMetatype, GenericBinaryOpMetatype
from tests.common.quantization.metatypes import DEFAULT_TEST_QUANT_TRAIT_MAP
from tests.common.quantization.metatypes import BatchNormTestMetatype
from tests.common.quantization.metatypes import CatTestMetatype
from tests.common.quantization.metatypes import Conv2dTestMetatype
from tests.common.quantization.metatypes import DropoutTestMetatype
from tests.common.quantization.metatypes import GeluTestMetatype
from tests.common.quantization.metatypes import GenericBinaryOpMetatype
from tests.common.quantization.metatypes import MatMulTestMetatype
from tests.common.quantization.metatypes import MaxPool2dTestMetatype
from tests.common.quantization.metatypes import MinTestMetatype
@@ -112,26 +116,21 @@ def get_graph():
# (binary_op)

graph = nx.DiGraph()
input_1_attrs = {
NNCFNode.NODE_NAME_ATTR: "I1",
NNCFNode.NODE_TYPE_ATTR: NNCFGraphNodeType.INPUT_NODE
}
input_1_attrs = {NNCFNode.NODE_NAME_ATTR: "I1", NNCFNode.NODE_TYPE_ATTR: NNCFGraphNodeType.INPUT_NODE}

input_2_attrs = {
NNCFNode.NODE_NAME_ATTR: "I2",
NNCFNode.NODE_TYPE_ATTR: NNCFGraphNodeType.INPUT_NODE
}
input_2_attrs = {NNCFNode.NODE_NAME_ATTR: "I2", NNCFNode.NODE_TYPE_ATTR: NNCFGraphNodeType.INPUT_NODE}

input_3_attrs = {
NNCFNode.NODE_NAME_ATTR: "I3",
NNCFNode.NODE_TYPE_ATTR: NNCFGraphNodeType.INPUT_NODE
}
input_3_attrs = {NNCFNode.NODE_NAME_ATTR: "I3", NNCFNode.NODE_TYPE_ATTR: NNCFGraphNodeType.INPUT_NODE}

concat_node_attrs = {NNCFNode.NODE_NAME_ATTR: ConcatBeforeBinaryOp.CONCAT_NODE_NAME,
NNCFNode.NODE_TYPE_ATTR: CatTestMetatype.name}
concat_node_attrs = {
NNCFNode.NODE_NAME_ATTR: ConcatBeforeBinaryOp.CONCAT_NODE_NAME,
NNCFNode.NODE_TYPE_ATTR: CatTestMetatype.name,
}

binary_op_attrs = {NNCFNode.NODE_NAME_ATTR: ConcatBeforeBinaryOp.BINARY_OP_NAME,
NNCFNode.NODE_TYPE_ATTR: GenericBinaryOpMetatype.name}
binary_op_attrs = {
NNCFNode.NODE_NAME_ATTR: ConcatBeforeBinaryOp.BINARY_OP_NAME,
NNCFNode.NODE_TYPE_ATTR: GenericBinaryOpMetatype.name,
}

graph.add_node("I1", **input_1_attrs)
graph.add_node("I2", **input_2_attrs)
@@ -216,14 +215,13 @@ def __init__(
self.directly_quantized_op_node_names = directly_quantized_op_node_names




@dataclass
class TraitConfigTargetNodeStruct:
trait: QuantizationTrait
configs: List[QuantizerConfig]
target_node_name: Optional[NNCFNodeName]


@dataclass
class PathTransitionTestStruct:
init_node_to_trait_configs_and_target_node_dict: Dict[str, TraitConfigTargetNodeStruct]
@@ -232,12 +230,14 @@ class PathTransitionTestStruct:
target_node_for_primary_quantizer: str
expected_status: TransitionStatus


@dataclass
class InitNodeTestStruct:
quantization_trait: QuantizationTrait
qconfigs: List[QuantizerConfig]
op_meta: Type[OperatorMetatype] = UnknownMetatype


class RunOnIpGraphTestStruct:
def __init__(
self,
@@ -904,7 +904,6 @@ def test_merged_qconfig_list_is_independent_of_branch_qconfig_list_order(
),
)


BRANCH_TRANSITION_TEST_CASES = [
# Downward branches are quantization-agnostic
BranchTransitionTestStruct(
@@ -1225,11 +1224,10 @@ def test_check_branching_transition(self, branch_transition_test_struct: BranchT
status = solver.check_branching_transition(quant_prop_graph, primary_prop_quant, target_node)
assert status == expected_status


@staticmethod
def prepare_propagation_graph_state(
ip_graph: InsertionPointGraph, init_node_to_trait_configs_and_target_node_dict: Dict[str,
TraitConfigTargetNodeStruct]
ip_graph: InsertionPointGraph,
init_node_to_trait_configs_and_target_node_dict: Dict[str, TraitConfigTargetNodeStruct],
) -> Tuple[List[PropagatingQuantizer], QPSG]:
quant_prop_graph = QPSG(ip_graph)
prop_quantizers = []
@@ -1489,7 +1487,7 @@ def test_check_transition_via_path(self, path_transition_test_struct: PathTransi
PropagationStepTestStruct(
init_node_to_trait_configs_and_target_node_dict=
{
'6 /F_0': (QuantizationTrait.INPUTS_QUANTIZABLE,
'6 /F_0': TraitConfigTargetNodeStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
InsertionPointGraph.get_post_hook_node_key('0 /O_0'))
},
@@ -1500,7 +1498,7 @@ def test_check_transition_via_path(self, path_transition_test_struct: PathTransi
PropagationStepTestStruct(
init_node_to_trait_configs_and_target_node_dict=
{
'6 /F_0': (QuantizationTrait.INPUTS_QUANTIZABLE,
'6 /F_0': TraitConfigTargetNodeStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
InsertionPointGraph.get_pre_hook_node_key('3 /C_0'))
},
@@ -1510,13 +1508,13 @@
),
PropagationStepTestStruct(
init_node_to_trait_configs_and_target_node_dict={
'6 /F_0': (QuantizationTrait.INPUTS_QUANTIZABLE,
'6 /F_0': TraitConfigTargetNodeStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
InsertionPointGraph.get_pre_hook_node_key('1 /A_0')),
'7 /G_0': (QuantizationTrait.INPUTS_QUANTIZABLE,
'7 /G_0': TraitConfigTargetNodeStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
InsertionPointGraph.get_pre_hook_node_key('5 /E_0')),
'10 /J_0': (QuantizationTrait.INPUTS_QUANTIZABLE,
'10 /J_0': TraitConfigTargetNodeStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
InsertionPointGraph.get_pre_hook_node_key('9 /I_0'))
},
@@ -1529,13 +1527,13 @@ def test_check_transition_via_path(self, path_transition_test_struct: PathTransi
# (i.e. when passing through an upward branching node)
PropagationStepTestStruct(
init_node_to_trait_configs_and_target_node_dict={
'10 /J_0': (QuantizationTrait.INPUTS_QUANTIZABLE,
'10 /J_0': TraitConfigTargetNodeStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
InsertionPointGraph.get_post_hook_node_key('9 /I_0')),
'6 /F_0': (QuantizationTrait.INPUTS_QUANTIZABLE,
'6 /F_0': TraitConfigTargetNodeStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
InsertionPointGraph.get_pre_hook_node_key('1 /A_0')),
'7 /G_0': (QuantizationTrait.INPUTS_QUANTIZABLE,
'7 /G_0': TraitConfigTargetNodeStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
InsertionPointGraph.get_pre_hook_node_key('5 /E_0')),
},
@@ -1554,6 +1552,7 @@ def test_propagation_step(self, propagation_step_test_struct: PropagationStepTes
init_node_to_trait_configs_and_target_node_dict = (
propagation_step_test_struct.init_node_to_trait_configs_and_target_node_dict
)

expected_finished_status = propagation_step_test_struct.expected_finished_status
current_location_node_key_for_propagated_quant = (
propagation_step_test_struct.current_location_node_key_for_propagated_quant
@@ -1867,7 +1866,7 @@ def test_quantizers_are_not_propagated_through_integer_paths(
):
quant_prop_solver = QuantizerPropagationSolver()
prep_data_dict = {
int_prop_test_struct.initial_node_name: (
int_prop_test_struct.initial_node_name: TraitConfigTargetNodeStruct(
QuantizationTrait.INPUTS_QUANTIZABLE,
[QuantizerConfig()],
int_prop_test_struct.target_node_name,
Expand Down Expand Up @@ -1943,15 +1942,15 @@ def get_operations_with_unified_scales(self) -> Set[Type[OperatorMetatype]]:
return {CatTestMetatype, GenericBinaryOpMetatype}

def get_metatype_vs_quantizer_configs_map(
self, for_weights=False
self, for_weights=False
) -> Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]]:
return {GenericBinaryOpMetatype: [QuantizerConfig(per_channel=False)],
CatTestMetatype: []}
return {GenericBinaryOpMetatype: [QuantizerConfig(per_channel=False)], CatTestMetatype: []}

quant_prop_solver = QuantizerPropagationSolver(
run_consistency_checks=True, default_trait_to_metatype_map=DEFAULT_TEST_QUANT_TRAIT_MAP,
hw_config=TestHWConfig())
run_consistency_checks=True,
default_trait_to_metatype_map=DEFAULT_TEST_QUANT_TRAIT_MAP,
hw_config=TestHWConfig(),
)
nncf_graph = get_nncf_graph_from_mock_nx_graph(graph_cls.get_graph())
ip_graph = get_ip_graph_for_test(nncf_graph)
retval = quant_prop_solver.run_on_ip_graph(ip_graph)
pass
_ = quant_prop_solver.run_on_ip_graph(ip_graph)
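
Side note (not part of the diff): the bulk of the changes in test_quantizer_propagation_solver.py replace the bare (trait, configs, target_node) tuples in the test dictionaries with the TraitConfigTargetNodeStruct dataclass introduced above, so each entry carries named fields instead of positional ones. A minimal standalone sketch of that pattern follows; the string/dict stand-ins are simplified placeholders, not NNCF's actual QuantizationTrait and QuantizerConfig classes.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TraitConfigTargetNodeStruct:
    trait: str                       # in the tests: a QuantizationTrait member
    configs: List[dict]              # in the tests: List[QuantizerConfig]
    target_node_name: Optional[str]  # in the tests: an insertion point graph node key

# Old style: a positional tuple, easy to misorder and opaque at the use site.
old_entry = ("INPUTS_QUANTIZABLE", [{}], "0 /O_0")

# New style: named fields, so helpers such as prepare_propagation_graph_state can
# read entry.trait, entry.configs and entry.target_node_name explicitly.
new_entry = TraitConfigTargetNodeStruct("INPUTS_QUANTIZABLE", [{}], "0 /O_0")
assert new_entry.target_node_name == "0 /O_0"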
2 changes: 1 addition & 1 deletion tests/common/quantization/test_quantizer_setup.py
@@ -11,4 +11,4 @@


def test_unified_scale_assignment_based_on_qconfig_selection():
pass
pass
16 changes: 7 additions & 9 deletions tests/torch/quantization/test_unified_scales.py
@@ -12,8 +12,7 @@
import itertools
from collections import Counter
from functools import partial
from typing import Callable
from typing import Dict, List
from typing import Callable, Dict, List

import onnx
import pytest
@@ -454,13 +453,10 @@ class UnifiedScaleTestStruct:


CAT_UNIFIED_SCALE_TEST_STRUCTS = [
UnifiedScaleTestStruct(
model_builder=SingleCatModel,
ref_aq_module_count=3,
ref_quantizations_count=4)
,
UnifiedScaleTestStruct(model_builder=SingleCatModel, ref_aq_module_count=3, ref_quantizations_count=4),
UnifiedScaleTestStruct(model_builder=DoubleCatModel, ref_aq_module_count=3, ref_quantizations_count=4),
UnifiedScaleTestStruct(model_builder=UNetLikeModel, ref_aq_module_count=4, ref_quantizations_count=6)]
UnifiedScaleTestStruct(model_builder=UNetLikeModel, ref_aq_module_count=4, ref_quantizations_count=6),
]


@pytest.mark.parametrize(
@@ -484,7 +480,9 @@ def test_unified_scales_with_concat(target_device, unified_scale_test_case: Unif
nncf_config["target_device"] = target_device
register_bn_adaptation_init_args(nncf_config)

_, compression_ctrl = create_compressed_model_and_algo_for_test(unified_scale_test_case.model_builder(), nncf_config)
_, compression_ctrl = create_compressed_model_and_algo_for_test(
unified_scale_test_case.model_builder(), nncf_config
)

assert len(compression_ctrl.non_weight_quantizers) == unified_scale_test_case.ref_aq_module_count
