diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index 019454642..16934a3ad 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -12,23 +12,72 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -import copy - +from dataclasses import replace, dataclass, asdict, field from enum import Enum - -import pprint - from typing import Dict, Any, Union, Tuple, List, Optional - from mct_quantizers import QuantizationMethod from model_compression_toolkit.constants import FLOAT_BITWIDTH - from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST -from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \ - get_current_tp_model, _current_tp_model -from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import clone_and_edit_object_params + _current_tp_model +
+class OperatorSetNames(Enum): + OPSET_NO_QUANTIZATION = "NoQuantization" + OPSET_QUANTIZATION_PRESERVING = "QuantizationPreserving" + OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS = "DimensionManipulationOpsWithWeights" + OPSET_DIMENSION_MANIPULATION_OPS = "DimensionManipulationOps" + OPSET_MERGE_OPS = "MergeOps" + OPSET_CONV = "Conv" + OPSET_DEPTHWISE_CONV = "DepthwiseConv2D" + OPSET_CONV_TRANSPOSE = "ConvTranspose" + OPSET_FULLY_CONNECTED = "FullyConnected" + OPSET_CONCATENATE = "Concatenate" + OPSET_STACK = "Stack" + OPSET_UNSTACK = "Unstack" + OPSET_GATHER = "Gather" + OPSET_EXPAND = "Expand" + OPSET_BATCH_NORM = "BatchNorm" + OPSET_ANY_RELU = "AnyReLU" + OPSET_ADD = "Add" + OPSET_SUB = "Sub" + OPSET_MUL = "Mul" + OPSET_DIV = "Div" + OPSET_MIN_MAX = "MinMax" + OPSET_PRELU = "PReLU" + OPSET_SWISH = "Swish" + OPSET_SIGMOID = "Sigmoid" + OPSET_TANH = "Tanh" + OPSET_GELU = "Gelu" + OPSET_HARDSIGMOID = "HardSigmoid" + OPSET_HARDSWISH = "HardSwish" + OPSET_FLATTEN = "Flatten" + OPSET_GET_ITEM = "GetItem" + OPSET_RESHAPE = "Reshape" + OPSET_UNSQUEEZE = "Unsqueeze" + OPSET_SQUEEZE = "Squeeze" + OPSET_PERMUTE = "Permute" + OPSET_TRANSPOSE = "Transpose" + OPSET_DROPOUT = "Dropout" + OPSET_SPLIT = "Split" + OPSET_CHUNK = "Chunk" + OPSET_UNBIND = "Unbind" + OPSET_MAXPOOL = "MaxPool" + OPSET_SIZE = "Size" + OPSET_SHAPE = "Shape" + OPSET_EQUAL = "Equal" + OPSET_ARGMAX = "ArgMax" + OPSET_TOPK = "TopK" + OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS = "FakeQuantWithMinMaxVars" + OPSET_COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression" + OPSET_CROPPING2D = "Cropping2D" + OPSET_ZERO_PADDING2D = "ZeroPadding2D" + OPSET_CAST = "Cast" + OPSET_STRIDED_SLICE = "StridedSlice" + + @classmethod + def get_values(cls): + return [v.value for v in cls] class Signedness(Enum): @@ -44,451 +93,431 @@ class Signedness(Enum): UNSIGNED = 2 +@dataclass(frozen=True) class AttributeQuantizationConfig: """ - Hold the quantization configuration of a weight attribute of a layer.
- """ - def __init__(self, - weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO, - weights_n_bits: int = FLOAT_BITWIDTH, - weights_per_channel_threshold: bool = False, - enable_weights_quantization: bool = False, - lut_values_bitwidth: Union[int, None] = None, # If None - set 8 in hptq, o.w use it - ): - """ - Initializes an attribute quantization config. + Holds the quantization configuration of a weight attribute of a layer. - Args: - weights_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for weights quantization. - weights_n_bits (int): Number of bits to quantize the coefficients. - weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor). - enable_weights_quantization (bool): Whether to quantize the model weights or not. - lut_values_bitwidth (int): Number of bits to use when quantizing in look-up-table. - - """ - - self.weights_quantization_method = weights_quantization_method - self.weights_n_bits = weights_n_bits - self.weights_per_channel_threshold = weights_per_channel_threshold - self.enable_weights_quantization = enable_weights_quantization - self.lut_values_bitwidth = lut_values_bitwidth + Attributes: + weights_quantization_method (QuantizationMethod): The method to use from QuantizationMethod for weights quantization. + weights_n_bits (int): Number of bits to quantize the coefficients. + weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor. + enable_weights_quantization (bool): Indicates whether to quantize the model weights or not. + lut_values_bitwidth (Union[int, None]): Number of bits to use when quantizing in a look-up table. + If None, defaults to 8 in hptq; otherwise, it uses the provided value. + """ + weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO + weights_n_bits: int = FLOAT_BITWIDTH + weights_per_channel_threshold: bool = False + enable_weights_quantization: bool = False + lut_values_bitwidth: Union[int, None] = None - def clone_and_edit(self, **kwargs): + def __post_init__(self): """ - Clone the quantization config and edit some of its attributes. + Post-initialization processing for input validation. - Args: - **kwargs: Keyword arguments to edit the configuration to clone. - - Returns: - Edited quantization configuration. + Raises: + Logger critical if attributes are of incorrect type or have invalid values. """ + if not isinstance(self.weights_n_bits, int) or self.weights_n_bits < 1: + Logger.critical("weights_n_bits must be a positive integer.") + if not isinstance(self.enable_weights_quantization, bool): + Logger.critical("enable_weights_quantization must be a boolean.") + if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int): + Logger.critical("lut_values_bitwidth must be an integer or None.") - return clone_and_edit_object_params(self, **kwargs) - - def __eq__(self, other): + def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig': """ - Is this configuration equal to another object. + Clone the current AttributeQuantizationConfig and edit some of its attributes. Args: - other: Object to compare. + **kwargs: Keyword arguments representing the attributes to edit in the cloned instance. Returns: - - Whether this configuration is equal to another object or not. + AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes. 
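+
+        Example (an illustrative sketch using only the defaults defined above):
+            >>> cfg = AttributeQuantizationConfig()
+            >>> cfg.clone_and_edit(weights_n_bits=4).weights_n_bits
+            4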
""" - if not isinstance(other, AttributeQuantizationConfig): - return False # pragma: no cover - return self.weights_quantization_method == other.weights_quantization_method and \ - self.weights_n_bits == other.weights_n_bits and \ - self.weights_per_channel_threshold == other.weights_per_channel_threshold and \ - self.enable_weights_quantization == other.enable_weights_quantization and \ - self.lut_values_bitwidth == other.lut_values_bitwidth + return replace(self, **kwargs) +@dataclass(frozen=True) class OpQuantizationConfig: """ OpQuantizationConfig is a class to configure the quantization parameters of an operator. - """ - - def __init__(self, - default_weight_attr_config: AttributeQuantizationConfig, - attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig], - activation_quantization_method: QuantizationMethod, - activation_n_bits: int, - supported_input_activation_n_bits: Union[int, Tuple[int]], - enable_activation_quantization: bool, - quantization_preserving: bool, - fixed_scale: float, - fixed_zero_point: int, - simd_size: int, - signedness: Signedness - ): - """ - - Args: - default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation. - attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration. - activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization. - activation_n_bits (int): Number of bits to quantize the activations. - supported_input_activation_n_bits (int or Tuple[int]): Number of bits that operator accepts as input. - enable_activation_quantization (bool): Whether to quantize the model activations or not. - quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output. - fixed_scale (float): Scale to use for an operator quantization parameters. - fixed_zero_point (int): Zero-point to use for an operator quantization parameters. - simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction. - signedness (bool): Set activation quantization signedness. - - """ - self.default_weight_attr_config = default_weight_attr_config - self.attr_weights_configs_mapping = attr_weights_configs_mapping + Args: + default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation. + attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration. + activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization. + activation_n_bits (int): Number of bits to quantize the activations. + supported_input_activation_n_bits (int or Tuple[int]): Number of bits that operator accepts as input. + enable_activation_quantization (bool): Whether to quantize the model activations or not. + quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output. + fixed_scale (float): Scale to use for an operator quantization parameters. + fixed_zero_point (int): Zero-point to use for an operator quantization parameters. 
+        simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction. +        signedness (Signedness): Set activation quantization signedness. - self.activation_quantization_method = activation_quantization_method - self.activation_n_bits = activation_n_bits - if isinstance(supported_input_activation_n_bits, tuple): - self.supported_input_activation_n_bits = supported_input_activation_n_bits - elif isinstance(supported_input_activation_n_bits, int): - self.supported_input_activation_n_bits = (supported_input_activation_n_bits,) - else: - Logger.critical(f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(supported_input_activation_n_bits)}") # pragma: no cover - self.enable_activation_quantization = enable_activation_quantization - self.quantization_preserving = quantization_preserving - self.fixed_scale = fixed_scale - self.fixed_zero_point = fixed_zero_point - self.signedness = signedness - self.simd_size = simd_size + """ + default_weight_attr_config: AttributeQuantizationConfig + attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig] + activation_quantization_method: QuantizationMethod + activation_n_bits: int + supported_input_activation_n_bits: Union[int, Tuple[int]] + enable_activation_quantization: bool + quantization_preserving: bool + fixed_scale: float + fixed_zero_point: int + simd_size: int + signedness: Signedness + + def __post_init__(self) -> None: + """ + Post-initialization processing for input validation. + + Raises: + Logger critical if supported_input_activation_n_bits is not an int or a tuple of ints. + """ + if isinstance(self.supported_input_activation_n_bits, int): + object.__setattr__(self, 'supported_input_activation_n_bits', (self.supported_input_activation_n_bits,)) + elif not isinstance(self.supported_input_activation_n_bits, tuple): + Logger.critical( + f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(self.supported_input_activation_n_bits)}") # pragma: no cover - def get_info(self): + def get_info(self) -> Dict[str, Any]: """ + Get information about the quantization configuration. - Returns: Info about the quantization configuration as a dictionary. - + Returns: + dict: Information about the quantization configuration as a dictionary. """ - return self.__dict__ # pragma: no cover + return asdict(self) # pragma: no cover - def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs): + def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs) -> 'OpQuantizationConfig': """ Clone the quantization config and edit some of its attributes. + Args: - attr_to_edit: A mapping between attributes names to edit and their parameters that - should be edited to a new value. + attr_to_edit (Dict[str, Dict[str, Any]]): A mapping between attribute names to edit and their parameters that + should be edited to a new value. **kwargs: Keyword arguments to edit the configuration to clone. Returns: - Edited quantization configuration. + OpQuantizationConfig: Edited quantization configuration.
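+
+        Example (illustrative; ``op_cfg`` stands for any existing OpQuantizationConfig
+        whose ``attr_weights_configs_mapping`` contains a 'kernel' entry):
+            >>> edited = op_cfg.clone_and_edit(activation_n_bits=4,
+            ...                                attr_to_edit={'kernel': {'weights_n_bits': 4}})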
""" - qc = clone_and_edit_object_params(self, **kwargs) - - # optionally: editing specific parameters in the config of specified attributes - edited_attrs = copy.deepcopy(qc.attr_weights_configs_mapping) - for attr_name, attr_cfg in qc.attr_weights_configs_mapping.items(): - if attr_name in attr_to_edit: - edited_attrs[attr_name] = attr_cfg.clone_and_edit(**attr_to_edit[attr_name]) - - qc.attr_weights_configs_mapping = edited_attrs + # Clone and update top-level attributes + updated_config = replace(self, **kwargs) - return qc + # Clone and update nested immutable dataclasses in `attr_weights_configs_mapping` + updated_attr_mapping = { + attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name]) + if attr_name in attr_to_edit else attr_cfg) + for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items() + } - def __eq__(self, other): - """ - Is this configuration equal to another object. - Args: - other: Object to compare. - - Returns: - Whether this configuration is equal to another object or not. - """ - if not isinstance(other, OpQuantizationConfig): - return False # pragma: no cover - return self.default_weight_attr_config == other.default_weight_attr_config and \ - self.attr_weights_configs_mapping == other.attr_weights_configs_mapping and \ - self.activation_quantization_method == other.activation_quantization_method and \ - self.activation_n_bits == other.activation_n_bits and \ - self.supported_input_activation_n_bits == other.supported_input_activation_n_bits and \ - self.enable_activation_quantization == other.enable_activation_quantization and \ - self.signedness == other.signedness and \ - self.simd_size == other.simd_size + # Return a new instance with the updated attribute mapping + return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping) @property def max_input_activation_n_bits(self) -> int: """ - Get maximum supported input bit-width. - - Returns: Maximum supported input bit-width. + Get the maximum supported input bit-width. + Returns: + int: Maximum supported input bit-width. """ return max(self.supported_input_activation_n_bits) +@dataclass(frozen=True) class QuantizationConfigOptions: """ + QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator. - Wrap a set of quantization configurations to consider during the quantization - of an operator. - + Attributes: + quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather. + base_config (Union[OpQuantizationConfig, None]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner. """ - def __init__(self, - quantization_config_list: List[OpQuantizationConfig], - base_config: OpQuantizationConfig = None): - """ + quantization_config_list: List[OpQuantizationConfig] + base_config: Union[OpQuantizationConfig, None] = None - Args: - quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather. - base_config (OpQuantizationConfig): Fallback OpQuantizationConfig to use when optimizing the model in a non mixed-precision manner. - """ - - assert isinstance(quantization_config_list, - list), f"'QuantizationConfigOptions' options list must be a list, but received: {type(quantization_config_list)}." - for cfg in quantization_config_list: - assert isinstance(cfg, OpQuantizationConfig),\ - f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}." 
- self.quantization_config_list = quantization_config_list - if len(quantization_config_list) > 1: - assert base_config is not None, \ - f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization." - assert any([base_config is cfg for cfg in quantization_config_list]), \ - f"'base_config' must be included in the quantization config options list." - # Enforce base_config to be a reference to an instance in quantization_config_list. - self.base_config = base_config - elif len(quantization_config_list) == 1: - assert base_config is None or base_config == quantization_config_list[0], "'base_config' should be included in 'quantization_config_list'" - # Set base_config to be a reference to the first instance in quantization_config_list. - self.base_config = quantization_config_list[0] - else: - raise AssertionError("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") - - def __eq__(self, other): - """ - Is this QCOptions equal to another object. - Args: - other: Object to compare. + def __post_init__(self) -> None: + """ + Post-initialization processing for input validation. - Returns: - Whether this QCOptions equal to another object or not. + Raises: + Logger critical if quantization_config_list is not a list, contains invalid elements, or if base_config is not set correctly. """ + # Validate `quantization_config_list` + if not isinstance(self.quantization_config_list, list): + Logger.critical( + f"'quantization_config_list' must be a list, but received: {type(self.quantization_config_list)}.") + for cfg in self.quantization_config_list: + if not isinstance(cfg, OpQuantizationConfig): + Logger.critical( + f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") - if not isinstance(other, QuantizationConfigOptions): - return False - if len(self.quantization_config_list) != len(other.quantization_config_list): - return False - for qc, other_qc in zip(self.quantization_config_list, other.quantization_config_list): - if qc != other_qc: - return False - return True + # Handle base_config + if len(self.quantization_config_list) > 1: + if self.base_config is None: + Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") + if not any(self.base_config == cfg for cfg in self.quantization_config_list): + Logger.critical(f"'base_config' must be included in the quantization config options list.") + elif len(self.quantization_config_list) == 1: + if self.base_config is None: + object.__setattr__(self, 'base_config', self.quantization_config_list[0]) + elif self.base_config != self.quantization_config_list[0]: + Logger.critical( + "'base_config' should be the same as the sole item in 'quantization_config_list'.") - def clone_and_edit(self, **kwargs): - qc_options = copy.deepcopy(self) - for qc in qc_options.quantization_config_list: - self.__edit_quantization_configuration(qc, kwargs) - return qc_options + elif len(self.quantization_config_list) == 0: + Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") - def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs): + def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions': """ - Clones the quantization configurations and edits some of their attributes' parameters. + Clone the quantization configuration options and edit attributes in each configuration. 
Args: - attrs: attributes names to clone their configurations. If None is provided, updating the configurations - of all attributes in the operation attributes config mapping. - **kwargs: Keyword arguments to edit in the attributes configuration. + **kwargs: Keyword arguments to edit in each configuration. Returns: - QuantizationConfigOptions with edited attributes configurations. + A new instance of QuantizationConfigOptions with updated configurations. + """ + updated_base_config = replace(self.base_config, **kwargs) + updated_configs_list = [ + replace(cfg, **kwargs) for cfg in self.quantization_config_list + ] + return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs_list) + def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> 'QuantizationConfigOptions': """ + Clones the quantization configurations and edits some of their attributes' parameters. - qc_options = copy.deepcopy(self) + Args: + attrs (List[str]): Attributes names to clone and edit their configurations. If None, updates all attributes. + **kwargs: Keyword arguments to edit in the attributes configuration. - for qc in qc_options.quantization_config_list: + Returns: + QuantizationConfigOptions: A new instance of QuantizationConfigOptions with edited attributes configurations. + """ + updated_base_config = self.base_config + updated_configs = [] + for qc in self.quantization_config_list: if attrs is None: attrs_to_update = list(qc.attr_weights_configs_mapping.keys()) else: - if not isinstance(attrs, List): # pragma: no cover - Logger.critical(f"Expected a list of attributes but received {type(attrs)}.") attrs_to_update = attrs - + # Ensure all attributes exist in the config for attr in attrs_to_update: - if qc.attr_weights_configs_mapping.get(attr) is None: # pragma: no cover - Logger.critical(f'Editing attributes is only possible for existing attributes in the configuration\'s ' - f'weights config mapping; {attr} does not exist in {qc}.') - self.__edit_quantization_configuration(qc.attr_weights_configs_mapping[attr], kwargs) - return qc_options + if attr not in qc.attr_weights_configs_mapping: + Logger.critical(f"{attr} does not exist in {qc}.") + updated_attr_mapping = { + attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs) + for attr in attrs_to_update + } + if qc == updated_base_config: + updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping) + updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping)) + return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs) - def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Union[Dict[str, str], None]): + def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Union[Dict[str, str], None]) -> 'QuantizationConfigOptions': """ - Clones the quantization configuration options and edits the keys in each configuration attributes config mapping, - based on the given attributes names mapping. + Clones the quantization configurations and updates keys in attribute config mappings. Args: - layer_attrs_mapping: A mapping between attributes names. + layer_attrs_mapping (Union[Dict[str, str], None]): A mapping between attribute names. Returns: - QuantizationConfigOptions with edited attributes names. - + QuantizationConfigOptions: A new instance of QuantizationConfigOptions with updated attribute keys. 
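+
+        Example (illustrative; assumes ``qc_options`` maps a 'kernel' weight attribute):
+            >>> mapped = qc_options.clone_and_map_weights_attr_keys({'kernel': 'weight'})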
""" - qc_options = copy.deepcopy(self) - - # Extract the list of existing quantization configurations from qc_options - - # Check if the base_config is already included in the quantization configuration list - # If not, add base_config to the list of configurations to update - cfgs_to_update = [cfg for cfg in qc_options.quantization_config_list] - if not any(qc_options.base_config is cfg for cfg in cfgs_to_update): - # TODO: add test for this case - cfgs_to_update.append(qc_options.base_config) - - for qc in cfgs_to_update: + updated_configs = [] + new_base_config = self.base_config + for qc in self.quantization_config_list: if layer_attrs_mapping is None: - qc.attr_weights_configs_mapping = {} - else: new_attr_mapping = {} - for attr in list(qc.attr_weights_configs_mapping.keys()): - new_key = layer_attrs_mapping.get(attr) - if new_key is None: # pragma: no cover - Logger.critical(f"Attribute \'{attr}\' does not exist in the provided attribute mapping.") - - new_attr_mapping[new_key] = qc.attr_weights_configs_mapping.pop(attr) - - qc.attr_weights_configs_mapping.update(new_attr_mapping) - - return qc_options + else: + new_attr_mapping = { + layer_attrs_mapping.get(attr, attr): cfg + for attr, cfg in qc.attr_weights_configs_mapping.items() + } + if qc == self.base_config: + new_base_config = replace(qc, attr_weights_configs_mapping=new_attr_mapping) + updated_configs.append(replace(qc, attr_weights_configs_mapping=new_attr_mapping)) + return replace(self, base_config=new_base_config, quantization_config_list=updated_configs) - def __edit_quantization_configuration(self, qc, kwargs): - for k, v in kwargs.items(): - assert hasattr(qc, - k), (f'Editing is only possible for existing attributes in the configuration; ' - f'{k} is not an attribute of {qc}.') - setattr(qc, k, v) + def get_info(self) -> Dict[str, Any]: + """ + Get detailed information about each quantization configuration option. - def get_info(self): + Returns: + dict: Information about the quantization configuration options as a dictionary. + """ return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)} +@dataclass(frozen=True) class TargetPlatformModelComponent: """ - Component of TargetPlatformModel (Fusing, OperatorsSet, etc.) + Component of TargetPlatformModel (Fusing, OperatorsSet, etc.). """ - def __init__(self, name: str): - """ - Args: - name: Name of component. + def __post_init__(self) -> None: + """ + Post-initialization to register the component with the current TargetPlatformModel. """ - self.name = name _current_tp_model.get().append_component(self) def get_info(self) -> Dict[str, Any]: """ + Get information about the component to display. - Returns: Get information about the component to display (return an empty dictionary. - the actual component should fill it with info). - + Returns: + Dict[str, Any]: Returns an empty dictionary. The actual component should override + this method to provide relevant information. """ return {} +@dataclass(frozen=True) class OperatorsSetBase(TargetPlatformModelComponent): """ - Base class to represent a set of operators. + Base class to represent a set of a target platform model component of operator set types. + Inherits from TargetPlatformModelComponent. """ - def __init__(self, name: str): + def __post_init__(self) -> None: """ - - Args: - name: Name of OperatorsSet. + Post-initialization to ensure the component is registered with the TargetPlatformModel. 
+ Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel. """ - super().__init__(name=name) + super().__post_init__() +@dataclass(frozen=True) class OperatorsSet(OperatorsSetBase): - def __init__(self, - name: str, - qc_options: QuantizationConfigOptions = None): - """ - Set of operators that are represented by a unique label. + """ + Set of operators that are represented by a unique label. - Args: - name (str): Set's label (must be unique in a TargetPlatformModel). - qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations. - """ + Attributes: + name (str): The set's label (must be unique within a TargetPlatformModel). + qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations. + If None, it represents a fusing set. + is_default (bool): Indicates whether this set is the default quantization configuration + for the TargetPlatformModel or a fusing set. + """ + name: str + qc_options: QuantizationConfigOptions = None - super().__init__(name) - self.qc_options = qc_options - is_fusing_set = qc_options is None - self.is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set + def __post_init__(self) -> None: + """ + Post-initialization processing to mark the operator set as default if applicable. + Calls the parent class's __post_init__ method and sets `is_default` to True + if this set corresponds to the default quantization configuration for the + TargetPlatformModel or if it is a fusing set. - def get_info(self) -> Dict[str,Any]: """ + super().__post_init__() + is_fusing_set = self.qc_options is None + is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set + object.__setattr__(self, 'is_default', is_default) - Returns: Info about the set as a dictionary. + def get_info(self) -> Dict[str, Any]: + """ + Get information about the set as a dictionary. + Returns: + Dict[str, Any]: A dictionary containing the set name and + whether it is the default quantization configuration. """ return {"name": self.name, "is_default_qc": self.is_default} +@dataclass(frozen=True) class OperatorSetConcat(OperatorsSetBase): """ Concatenate a list of operator sets to treat them similarly in different places (like fusing). + + Attributes: + op_set_list (List[OperatorsSet]): List of operator sets to group. + qc_options (None): Configuration options for the set, always None for concatenated sets. + name (str): Concatenated name generated from the names of the operator sets in the list. """ - def __init__(self, *opsets: OperatorsSet): - """ - Group a list of operation sets. + op_set_list: List[OperatorsSet] = field(default_factory=list) + qc_options: None = field(default=None, init=False) + name: str = None - Args: - *opsets (OperatorsSet): List of operator sets to group. + def __post_init__(self) -> None: """ - name = "_".join([a.name for a in opsets]) - super().__init__(name=name) - self.op_set_list = opsets - self.qc_options = None # Concat have no qc options + Post-initialization processing to generate the concatenated name and set it as the `name` attribute. - def get_info(self) -> Dict[str,Any]: + Calls the parent class's __post_init__ method and creates a concatenated name + by joining the names of all operator sets in `op_set_list`. 
""" + super().__post_init__() + # Generate the concatenated name from the operator sets + concatenated_name = "_".join([op.name for op in self.op_set_list]) + # Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen + object.__setattr__(self, "name", concatenated_name) - Returns: Info about the sets group as a dictionary. + def get_info(self) -> Dict[str, Any]: + """ + Get information about the concatenated set as a dictionary. + Returns: + Dict[str, Any]: A dictionary containing the concatenated name and + the list of names of the operator sets in `op_set_list`. """ return {"name": self.name, OPS_SET_LIST: [s.name for s in self.op_set_list]} +@dataclass(frozen=True) class Fusing(TargetPlatformModelComponent): """ - Fusing defines a list of operators that should be combined and treated as a single operator, - hence no quantization is applied between them. + Fusing defines a list of operators that should be combined and treated as a single operator, + hence no quantization is applied between them. + + Attributes: + operator_groups_list (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, + each being either an OperatorSetConcat or an OperatorsSet. + name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names. """ + operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]] + name: str = None - def __init__(self, - operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]], - name: str = None): - """ - Args: - operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet. - name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names. + def __post_init__(self) -> None: """ - assert isinstance(operator_groups_list, - list), f'List of operator groups should be of type list but is {type(operator_groups_list)}' - assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group' + Post-initialization processing for input validation and name generation. - # Generate a name from the operator groups if no name is provided - if name is None: - name = '_'.join([x.name for x in operator_groups_list]) + Calls the parent class's __post_init__ method, validates the operator_groups_list, + and generates the name if not explicitly provided. - super().__init__(name) - self.operator_groups_list = operator_groups_list + Raises: + Logger critical if operator_groups_list is not a list or if it contains fewer than two operators. + """ + super().__post_init__() + # Validate the operator_groups_list + if not isinstance(self.operator_groups_list, list): + Logger.critical( + f"List of operator groups should be of type list but is {type(self.operator_groups_list)}.") + if len(self.operator_groups_list) < 2: + Logger.critical("Fusing cannot be created for a single operator.") + + # if self.name is None: + # Generate the name from the operator groups if not provided + generated_name = '_'.join([x.name for x in self.operator_groups_list]) + object.__setattr__(self, 'name', generated_name) def contains(self, other: Any) -> bool: """ Determines if the current Fusing instance contains another Fusing instance. Args: - other: The other Fusing instance to check against. + other (Any): The other Fusing instance to check against. Returns: - A boolean indicating whether the other instance is contained within this one. 
+ bool: True if the other Fusing instance is contained within this one, False otherwise. """ if not isinstance(other, Fusing): return False @@ -506,81 +535,72 @@ def contains(self, other: Any) -> bool: # Other Fusing instance is not contained return False - def get_info(self): + def get_info(self) -> Union[Dict[str, str], str]: """ Retrieves information about the Fusing instance, including its name and the sequence of operator groups. Returns: - A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value, - or just the sequence of operator groups if no name is set. + Union[Dict[str, str], str]: A dictionary with the Fusing instance's name as the key + and the sequence of operator groups as the value, + or just the sequence of operator groups if no name is set. """ if self.name is not None: return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])} return ' -> '.join([x.name for x in self.operator_groups_list]) -class TargetPlatformModel(ImmutableClass): +@dataclass(frozen=True) +class TargetPlatformModel: """ Represents the hardware configuration used for quantized model inference. - This model defines: - - The operators and their associated quantization configurations. - - Fusing patterns, enabling multiple operators to be combined into a single operator - for optimization during inference. - - Versioning support through minor and patch versions for backward compatibility. - Attributes: - SCHEMA_VERSION (int): The schema version of the target platform model. + default_qco (QuantizationConfigOptions): Default quantization configuration options for the model. + tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration. + tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration. + tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration. + add_metadata (bool): Flag to determine if metadata should be added. + name (str): Name of the Target Platform Model. + operator_set (List[OperatorsSetBase]): List of operator sets within the model. + fusing_patterns (List[Fusing]): List of fusing patterns for the model. + is_simd_padding (bool): Indicates if SIMD padding is applied. + SCHEMA_VERSION (int): Version of the schema for the Target Platform Model. """ - SCHEMA_VERSION = 1 - def __init__(self, - default_qco: QuantizationConfigOptions, - tpc_minor_version: Optional[int], - tpc_patch_version: Optional[int], - tpc_platform_type: Optional[str], - add_metadata: bool = True, - name="default_tp_model"): + default_qco: QuantizationConfigOptions + tpc_minor_version: Optional[int] + tpc_patch_version: Optional[int] + tpc_platform_type: Optional[str] + add_metadata: bool = True + name: str = "default_tp_model" + operator_set: List[OperatorsSetBase] = field(default_factory=list) + fusing_patterns: List[Fusing] = field(default_factory=list) + is_simd_padding: bool = False + + SCHEMA_VERSION: int = 1 + + def __post_init__(self) -> None: """ + Post-initialization processing for input validation. - Args: - default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators that their QuantizationConfigOptions are not defined in the model. - tpc_minor_version (Optional[int]): The minor version of the target platform capabilities. - tpc_patch_version (Optional[int]): The patch version of the target platform capabilities. - tpc_platform_type (Optional[str]): The platform type of the target platform capabilities. 
- add_metadata (bool): Whether to add metadata to the model or not. - name (str): Name of the model. - - Raises: - AssertionError: If the provided `default_qco` does not contain exactly one quantization configuration. - """ - - super().__init__() - self.tpc_minor_version = tpc_minor_version - self.tpc_patch_version = tpc_patch_version - self.tpc_platform_type = tpc_platform_type - self.add_metadata = add_metadata - self.name = name - self.operator_set = [] - assert isinstance(default_qco, QuantizationConfigOptions), \ - "default_qco must be an instance of QuantizationConfigOptions" - assert len(default_qco.quantization_config_list) == 1, \ - "Default QuantizationConfigOptions must contain exactly one option." - - self.default_qco = default_qco - self.fusing_patterns = [] - self.is_simd_padding = False - - def get_config_options_by_operators_set(self, - operators_set_name: str) -> QuantizationConfigOptions: - """ - Get the QuantizationConfigOptions of a OperatorsSet by the OperatorsSet name. - If the name is not in the model, the default QuantizationConfigOptions is returned. + Raises: + Logger critical if the default_qco is not an instance of QuantizationConfigOptions + or if it contains more than one quantization configuration. + """ + # Validate `default_qco` + if not isinstance(self.default_qco, QuantizationConfigOptions): + Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") + if len(self.default_qco.quantization_config_list) != 1: + Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") + + def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions: + """ + Get the QuantizationConfigOptions of an OperatorsSet by its name. Args: - operators_set_name: Name of OperatorsSet to get. + operators_set_name (str): Name of the OperatorsSet to get. Returns: - QuantizationConfigOptions to use for ops in OperatorsSet named operators_set_name. + QuantizationConfigOptions: Quantization configuration options for the given OperatorsSet name. """ for op_set in self.operator_set: if operators_set_name == op_set.name: @@ -589,143 +609,114 @@ def get_config_options_by_operators_set(self, def get_default_op_quantization_config(self) -> OpQuantizationConfig: """ + Get the default OpQuantizationConfig of the TargetPlatformModel. - Returns: The default OpQuantizationConfig of the TargetPlatformModel. - + Returns: + OpQuantizationConfig: The default quantization configuration. """ assert len(self.default_qco.quantization_config_list) == 1, \ - f'Default quantization configuration options must contain only one option,' \ - f' but found {len(get_current_tp_model().default_qco.quantization_config_list)} configurations.' + f"Default quantization configuration options must contain only one option, " \ + f"but found {len(self.default_qco.quantization_config_list)} configurations." return self.default_qco.quantization_config_list[0] - def is_opset_in_model(self, - opset_name: str) -> bool: + def is_opset_in_model(self, opset_name: str) -> bool: """ - Check whether an operators set is defined in the model or not. + Check whether an OperatorsSet is defined in the model. Args: - opset_name: Operators set name to check. + opset_name (str): Name of the OperatorsSet to check. Returns: - Whether an operators set is defined in the model or not. + bool: True if the OperatorsSet exists, False otherwise. 
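+
+        Example (illustrative; assumes the model defines an "Add" operator set):
+            >>> tp_model.is_opset_in_model("Add")
+            True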
""" return opset_name in [x.name for x in self.operator_set] - def get_opset_by_name(self, - opset_name: str) -> OperatorsSetBase: + def get_opset_by_name(self, opset_name: str) -> Optional[OperatorsSetBase]: """ Get an OperatorsSet object from the model by its name. - If name is not in the model - None is returned. Args: - opset_name: OperatorsSet name to retrieve. + opset_name (str): Name of the OperatorsSet to retrieve. Returns: - OperatorsSet object with the name opset_name, or None if opset_name is not in the model. + Optional[OperatorsSetBase]: The OperatorsSet object with the given name, + or None if not found in the model. """ - opset_list = [x for x in self.operator_set if x.name == opset_name] - assert len(opset_list) <= 1, f'Found more than one OperatorsSet in' \ - f' TargetPlatformModel with the name {opset_name}. ' \ - f'OperatorsSet name must be unique.' - if len(opset_list) == 0: # opset_name is not in the model. - return None + if len(opset_list) > 1: + Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") + return opset_list[0] if opset_list else None - return opset_list[0] # There's one opset with that name - - def append_component(self, - tp_model_component: TargetPlatformModelComponent): + def append_component(self, tp_model_component: TargetPlatformModelComponent) -> None: """ - Attach a TargetPlatformModel component to the model. Components can be for example: - Fusing, OperatorsSet, etc. + Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet). Args: - tp_model_component: Component to attach to the model. + tp_model_component (TargetPlatformModelComponent): Component to attach to the model. + Raises: + Logger critical if the component is not an instance of Fusing or OperatorsSetBase. """ if isinstance(tp_model_component, Fusing): self.fusing_patterns.append(tp_model_component) elif isinstance(tp_model_component, OperatorsSetBase): self.operator_set.append(tp_model_component) else: # pragma: no cover - Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.') - - def __enter__(self): - """ - Start defining the TargetPlatformModel using 'with'. - - Returns: Initialized TargetPlatformModel object. + Logger.critical( + f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") + def get_info(self) -> Dict[str, Any]: """ - _current_tp_model.set(self) - return self + Get a dictionary summarizing the TargetPlatformModel properties. - def __exit__(self, exc_type, exc_value, tb): - """ - Finish defining the TargetPlatformModel at the end of the 'with' clause. - Returns the final and immutable TargetPlatformModel instance. + Returns: + Dict[str, Any]: Summary of the TargetPlatformModel properties. """ + return { + "Model name": self.name, + "Default quantization config": self.get_default_op_quantization_config().get_info(), + "Operators sets": [o.get_info() for o in self.operator_set], + "Fusing patterns": [f.get_info() for f in self.fusing_patterns], + } - if exc_value is not None: - print(exc_value, exc_value.args) - raise exc_value - self.__validate_model() # Assert that model is valid. - _current_tp_model.reset() - self.initialized_done() # Make model immutable. - return self - - def __validate_model(self): + def __validate_model(self) -> None: """ + Validate the model's configuration to ensure its integrity. - Assert model is valid. 
- Model is invalid if, for example, it contains multiple operator sets with the same name, - as their names should be unique. - + Raises: + Logger critical if the model contains multiple operator sets with the same name. """ opsets_names = [op.name for op in self.operator_set] if len(set(opsets_names)) != len(opsets_names): - Logger.critical(f'Operator Sets must have unique names.') + Logger.critical("Operator Sets must have unique names.") - def get_default_config(self) -> OpQuantizationConfig: + def __enter__(self) -> 'TargetPlatformModel': """ + Start defining the TargetPlatformModel using a 'with' statement. Returns: - - """ - assert len(self.default_qco.quantization_config_list) == 1, \ - f'Default quantization configuration options must contain only one option,' \ - f' but found {len(self.default_qco.quantization_config_list)} configurations.' - return self.default_qco.quantization_config_list[0] - - def get_info(self) -> Dict[str, Any]: + TargetPlatformModel: The initialized TargetPlatformModel object. """ + _current_tp_model.set(self) + return self - Returns: Dictionary that summarizes the TargetPlatformModel properties (for display purposes). - - """ - return {"Model name": self.name, - "Default quantization config": self.get_default_config().get_info(), - "Operators sets": [o.get_info() for o in self.operator_set], - "Fusing patterns": [f.get_info() for f in self.fusing_patterns] - } - - def show(self): - """ - - Display the TargetPlatformModel. - - """ - pprint.pprint(self.get_info(), sort_dicts=False) - - def set_simd_padding(self, - is_simd_padding: bool): + def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel': """ - Set flag is_simd_padding to indicate whether this TP model defines - that padding due to SIMD constrains occurs. + Finalize and validate the TargetPlatformModel at the end of the 'with' clause. Args: - is_simd_padding: Whether this TP model defines that padding due to SIMD constrains occurs. + exc_type: Exception type, if any occurred. + exc_value: Exception value, if any occurred. + tb: Traceback object, if an exception occurred. - """ - self.is_simd_padding = is_simd_padding + Raises: + The exception raised in the 'with' block, if any. + Returns: + TargetPlatformModel: The validated TargetPlatformModel object. + """ + if exc_value is not None: + raise exc_value + self.__validate_model() + _current_tp_model.reset() + return self diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py index 27d032c29..f9e94f81d 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py @@ -164,7 +164,8 @@ def generate_tp_model(default_config: OpQuantizationConfig, tpc_patch_version=0, tpc_platform_type=IMX500_TP_MODEL, name=name, - add_metadata=False) + add_metadata=False, + is_simd_padding=True) # To start defining the model's components (such as operator sets, and fusing patterns), # use 'with' the TargetPlatformModel instance, and create them as below: @@ -175,8 +176,6 @@ def generate_tp_model(default_config: OpQuantizationConfig, # be used for operations that will be attached to this set's label. # Otherwise, it will be a configure-less set (used in fusing): - generated_tpc.set_simd_padding(is_simd_padding=True) - # May suit for operations like: Dropout, Reshape, etc. 
default_qco = tp.get_default_quantization_config_options() schema.OperatorsSet("NoQuantization", @@ -206,9 +205,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. # To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py index 9da497022..707fa76e1 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py @@ -201,9 +201,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. # To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py index 24f3e6eae..032a42c6a 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py @@ -197,9 +197,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. 
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py index 947c1608f..ae7056b99 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py @@ -166,7 +166,8 @@ def generate_tp_model(default_config: OpQuantizationConfig, tpc_patch_version=0, tpc_platform_type=IMX500_TP_MODEL, add_metadata=True, - name=name) + name=name, + is_simd_padding=True) # To start defining the model's components (such as operator sets, and fusing patterns), # use 'with' the TargetPlatformModel instance, and create them as below: @@ -177,8 +178,6 @@ def generate_tp_model(default_config: OpQuantizationConfig, # be used for operations that will be attached to this set's label. # Otherwise, it will be a configure-less set (used in fusing): - generated_tpm.set_simd_padding(is_simd_padding=True) - # May suit for operations like: Dropout, Reshape, etc. default_qco = tp.get_default_quantization_config_options() schema.OperatorsSet("NoQuantization", @@ -208,9 +207,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. # To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py index 31ba2d9ab..187ef1100 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py @@ -203,9 +203,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. 
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py index b053ea9eb..5e07cb7d9 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py @@ -187,7 +187,8 @@ def generate_tp_model(default_config: OpQuantizationConfig, tpc_patch_version=0, tpc_platform_type=IMX500_TP_MODEL, add_metadata=True, - name=name) + name=name, + is_simd_padding=True) # To start defining the model's components (such as operator sets, and fusing patterns), # use 'with' the TargetPlatformModel instance, and create them as below: @@ -198,8 +199,6 @@ def generate_tp_model(default_config: OpQuantizationConfig, # be used for operations that will be attached to this set's label. # Otherwise, it will be a configure-less set (used in fusing): - generated_tpm.set_simd_padding(is_simd_padding=True) - # May suit for operations like: Dropout, Reshape, etc. default_qco = tp.get_default_quantization_config_options() schema.OperatorsSet("NoQuantization", @@ -231,9 +230,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. # To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py index 9102fcc02..8b25c33c2 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py @@ -214,9 +214,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. 
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py
index 9102fcc02..8b25c33c2 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py
@@ -214,9 +214,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     # Combine multiple operators into a single operator to avoid quantization between
     # them. To do this we define fusing patterns using the OperatorsSets that were created.
     # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
-    activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
-    activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
-    any_binary = schema.OperatorSetConcat(add, sub, mul, div)
+    activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
+    activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
+    any_binary = schema.OperatorSetConcat([add, sub, mul, div])
 
     # ------------------- #
     # Fusions
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
index 7c056778e..2f658d2f8 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
@@ -15,7 +15,7 @@
 from typing import List, Tuple
 
 import model_compression_toolkit as mct
-import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
+import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
 from model_compression_toolkit.constants import FLOAT_BITWIDTH
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS, \
     IMX500_TP_MODEL
@@ -235,7 +235,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
                                                tpc_minor_version=4,
                                                tpc_patch_version=0,
                                                tpc_platform_type=IMX500_TP_MODEL,
-                                               add_metadata=True, name=name)
+                                               add_metadata=True,
+                                               name=name,
+                                               is_simd_padding=True)
 
     # To start defining the model's components (such as operator sets, and fusing patterns),
     # use 'with' the TargetPlatformModel instance, and create them as below:
@@ -246,8 +248,6 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     # be used for operations that will be attached to this set's label.
     # Otherwise, it will be a configure-less set (used in fusing):
 
-    generated_tpm.set_simd_padding(is_simd_padding=True)
-
     # May suit for operations like: Dropout, Reshape, etc.
     default_qco = tp.get_default_quantization_config_options()
     schema.OperatorsSet(OPSET_NO_QUANTIZATION,
@@ -294,11 +294,11 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     # Combine multiple operators into a single operator to avoid quantization between
     # them. To do this we define fusing patterns using the OperatorsSets that were created.
     # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
-    activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid,
-                                                              tanh, gelu, hardswish, hardsigmoid)
-    activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid, tanh, gelu,
-                                                            hardswish, hardsigmoid)
-    any_binary = schema.OperatorSetConcat(add, sub, mul, div)
+    activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid,
+                                                               tanh, gelu, hardswish, hardsigmoid])
+    activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid, tanh, gelu,
+                                                             hardswish, hardsigmoid])
+    any_binary = schema.OperatorSetConcat([add, sub, mul, div])
 
     # ------------------- #
     # Fusions
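
Unlike the earlier versions, v4 also stops importing the schema through the `mct_current_schema` alias and pins it to `schema.v1` directly, so this TPC keeps resolving against the v1 schema even if the current-schema alias is later re-pointed. The two import forms side by side (the replaced form shown as a comment):

    # Before: resolves to whichever schema version the alias currently points at.
    # import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
    # After: pinned explicitly to the v1 schema module.
    import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
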
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py
index b0a69c6e7..d269d7f4e 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py
@@ -180,11 +180,11 @@ def generate_tp_model(default_config: OpQuantizationConfig,
                                           fixed_zero_point=-128, fixed_scale=1 / 256))
 
     conv2d = schema.OperatorsSet("Conv2d")
-    kernel = schema.OperatorSetConcat(conv2d, fc)
+    kernel = schema.OperatorSetConcat([conv2d, fc])
 
     relu = schema.OperatorsSet("Relu")
     elu = schema.OperatorsSet("Elu")
-    activations_to_fuse = schema.OperatorSetConcat(relu, elu)
+    activations_to_fuse = schema.OperatorSetConcat([relu, elu])
 
     batch_norm = schema.OperatorsSet("BatchNorm")
     bias_add = schema.OperatorsSet("BiasAdd")
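
The test changes that follow widen `assertRaises(AssertionError)` to `assertRaises(Exception)`: validation no longer runs as bare `assert` statements but in dataclass `__post_init__` hooks that report through `Logger.critical`, which raises. A sketch of that validation pattern, under the assumption that `Logger.critical` logs and then raises (the stand-in below only raises):

    from dataclasses import dataclass
    from typing import List

    class Logger:
        @staticmethod
        def critical(msg: str):
            # Assumed behavior, mirroring MCT's Logger.critical: report and raise.
            raise Exception(msg)

    @dataclass(frozen=True)
    class QuantizationConfigOptions:
        quantization_config_list: List[int]

        def __post_init__(self):
            if len(self.quantization_config_list) == 0:
                Logger.critical("'QuantizationConfigOptions' requires at least one "
                                "'OpQuantizationConfig'. The provided list is empty.")

    try:
        QuantizationConfigOptions([])
    except Exception as e:
        print(e)  # the message the test below asserts
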
diff --git a/tests/common_tests/test_tp_model.py b/tests/common_tests/test_tp_model.py
index ed8a52b59..5fea3155b 100644
--- a/tests/common_tests/test_tp_model.py
+++ b/tests/common_tests/test_tp_model.py
@@ -55,7 +55,7 @@ def test_immutable_tp(self):
         with model:
             schema.OperatorsSet("opset")
             model.operator_set = []
-        self.assertEqual('Immutable class. Can\'t edit attributes.', str(e.exception))
+        self.assertEqual("cannot assign to field 'operator_set'", str(e.exception))
 
     def test_default_options_more_than_single_qc(self):
         test_qco = schema.QuantizationConfigOptions([TEST_QC, TEST_QC], base_config=TEST_QC)
@@ -76,8 +76,6 @@ def test_tp_model_show(self):
         with tpm:
             a = schema.OperatorsSet("opA")
 
-        tpm.show()
-
 
 class OpsetTest(unittest.TestCase):
 
@@ -114,7 +112,7 @@ def test_opset_concat(self):
             b = schema.OperatorsSet('opset_B',
                                     get_default_quantization_config_options().clone_and_edit(activation_n_bits=2))
             schema.OperatorsSet('opset_C')  # Just add it without using it in concat
-            schema.OperatorSetConcat(a, b)
+            schema.OperatorSetConcat([a, b])
         self.assertEqual(len(hm.operator_set), 4)
         self.assertTrue(hm.is_opset_in_model("opset_A_opset_B"))
         self.assertTrue(hm.get_config_options_by_operators_set('opset_A_opset_B') is None)
@@ -136,14 +134,14 @@ def test_non_unique_opset(self):
 
 class QCOptionsTest(unittest.TestCase):
 
     def test_empty_qc_options(self):
-        with self.assertRaises(AssertionError) as e:
+        with self.assertRaises(Exception) as e:
             schema.QuantizationConfigOptions([])
         self.assertEqual(
             "'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.",
             str(e.exception))
 
     def test_list_of_no_qc(self):
-        with self.assertRaises(AssertionError) as e:
+        with self.assertRaises(Exception) as e:
             schema.QuantizationConfigOptions([TEST_QC, 3])
         self.assertEqual(
             'Each option must be an instance of \'OpQuantizationConfig\', but found an object of type: <class \'int\'>.',
             str(e.exception))
@@ -186,7 +184,7 @@ def test_fusing_single_opset(self):
         add = schema.OperatorsSet("add")
         with self.assertRaises(Exception) as e:
             schema.Fusing([add])
-        self.assertEqual('Fusing can not be created for a single operators group', str(e.exception))
+        self.assertEqual('Fusing cannot be created for a single operator.', str(e.exception))
 
     def test_fusing_contains(self):
         hm = schema.TargetPlatformModel(
@@ -220,7 +218,7 @@ def test_fusing_contains_with_opset_concat(self):
             conv = schema.OperatorsSet("conv")
             add = schema.OperatorsSet("add")
             tanh = schema.OperatorsSet("tanh")
-            add_tanh = schema.OperatorSetConcat(add, tanh)
+            add_tanh = schema.OperatorSetConcat([add, tanh])
             schema.Fusing([conv, add])
             schema.Fusing([conv, add_tanh])
             schema.Fusing([conv, add, tanh])
diff --git a/tests/keras_tests/exporter_tests/tflite_int8/imx500_int8_tp_model.py b/tests/keras_tests/exporter_tests/tflite_int8/imx500_int8_tp_model.py
index 8e8f2eac4..209287fbf 100644
--- a/tests/keras_tests/exporter_tests/tflite_int8/imx500_int8_tp_model.py
+++ b/tests/keras_tests/exporter_tests/tflite_int8/imx500_int8_tp_model.py
@@ -95,9 +95,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     swish = schema.OperatorsSet("Swish")
     sigmoid = schema.OperatorsSet("Sigmoid")
     tanh = schema.OperatorsSet("Tanh")
-    activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
-    activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
-    any_binary = schema.OperatorSetConcat(add, sub, mul, div)
+    activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
+    activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
+    any_binary = schema.OperatorSetConcat([add, sub, mul, div])
     schema.Fusing([conv, activations_after_conv_to_fuse])
     schema.Fusing([fc, activations_after_fc_to_fuse])
     schema.Fusing([any_binary, any_relu])
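
The rewritten expectation in `test_immutable_tp` is the stock message CPython produces when assigning to a field of a frozen dataclass; immutability now comes from `@dataclass(frozen=True)` rather than the hand-rolled `ImmutableClass`. A standalone reproduction (not the MCT classes):

    from dataclasses import dataclass, FrozenInstanceError

    @dataclass(frozen=True)
    class Model:
        operator_set: tuple = ()

    m = Model()
    try:
        m.operator_set = []  # any assignment to a frozen-dataclass field is rejected
    except FrozenInstanceError as e:
        print(e)  # cannot assign to field 'operator_set'
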
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py
index 7b4e86d05..2218a8d16 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from dataclasses import replace
+
 import numpy as np
 import tensorflow as tf
 
@@ -34,8 +36,8 @@ def get_tpc(self):
         tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4')
         # Force Mul base_config to 16bit only
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
+        base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config)
         return tpc
 
     def create_networks(self):
@@ -67,8 +69,8 @@ class Activation16BitMixedPrecisionTest(Activation16BitTest):
     def get_tpc(self):
         tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3')
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
+        base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config)
         mul_op_set.qc_options.quantization_config_list.extend(
             [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4),
              mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)])
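
With frozen configs, the test above no longer mutates `qc_options.base_config` or the `layer2qco` entry in place; it builds a replacement object with `dataclasses.replace`, which copies a dataclass while overriding selected fields, and rebinds the dictionary entry. The same pattern in miniature (the `QCOptions` stand-in is an assumption, not the MCT class):

    from dataclasses import dataclass, replace
    from typing import Tuple

    @dataclass(frozen=True)
    class QCOptions:
        base_config: int
        quantization_config_list: Tuple[int, ...] = ()

    layer2qco = {'mul': QCOptions(base_config=8, quantization_config_list=(8, 16))}
    # Select the 16-bit option, then derive an otherwise-identical copy with it as base.
    base_config = [b for b in layer2qco['mul'].quantization_config_list if b == 16][0]
    layer2qco['mul'] = replace(layer2qco['mul'], base_config=base_config)
    assert layer2qco['mul'].base_config == 16
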
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py b/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py
index 9c35e1582..243316a21 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from dataclasses import replace
+
 import numpy as np
 import tensorflow as tf
 
@@ -133,9 +135,8 @@ def get_tpc(self):
         tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3')
         # Force Mul base_config to 16bit only
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = \
-            [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
+        base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config)
         return tpc
 
     def create_networks(self):
@@ -159,9 +160,8 @@ class Manual16BitWidthSelectionMixedPrecisionTest(Manual16BitWidthSelectionTest):
     def get_tpc(self):
         tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3')
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = \
-            [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
+        base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config)
         mul_op_set.qc_options.quantization_config_list.extend(
             [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4),
              mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)])
diff --git a/tests/keras_tests/function_tests/test_layer_fusing.py b/tests/keras_tests/function_tests/test_layer_fusing.py
index 0c8a5b2e6..f55c31d4f 100644
--- a/tests/keras_tests/function_tests/test_layer_fusing.py
+++ b/tests/keras_tests/function_tests/test_layer_fusing.py
@@ -120,7 +120,7 @@ def get_tpc_2():
     swish = schema.OperatorsSet("Swish")
     sigmoid = schema.OperatorsSet("Sigmoid")
     tanh = schema.OperatorsSet("Tanh")
-    activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid, tanh)
+    activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid, tanh])
 
     # Define fusions
     schema.Fusing([conv, activations_after_conv_to_fuse])
@@ -161,7 +161,7 @@ def get_tpc_4():
     any_relu = schema.OperatorsSet("AnyReLU")
     add = schema.OperatorsSet("Add")
     swish = schema.OperatorsSet("Swish")
-    activations_to_fuse = schema.OperatorSetConcat(any_relu, swish)
+    activations_to_fuse = schema.OperatorSetConcat([any_relu, swish])
     # Define fusions
     schema.Fusing([conv, activations_to_fuse])
     schema.Fusing([conv, add, activations_to_fuse])
diff --git a/tests/keras_tests/function_tests/test_quant_config_filtering.py b/tests/keras_tests/function_tests/test_quant_config_filtering.py
index c9365c103..6e5c3c871 100644
--- a/tests/keras_tests/function_tests/test_quant_config_filtering.py
+++ b/tests/keras_tests/function_tests/test_quant_config_filtering.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+from dataclasses import replace
 import unittest
 
 import numpy as np
@@ -44,8 +44,8 @@ def get_tpc_default_16bit():
         tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3')
         # Force Mul base_config to 16bit only
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
+        base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config)
         return tpc
 
     def test_config_filtering(self):
diff --git a/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py b/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py
index 6f4478aff..add49fd26 100644
--- a/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py
+++ b/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py
@@ -130,7 +130,7 @@ def test_get_layers_by_opconcat(self):
         with hm:
             op_obj_a = schema.OperatorsSet('opsetA')
             op_obj_b = schema.OperatorsSet('opsetB')
-            op_concat = schema.OperatorSetConcat(op_obj_a, op_obj_b)
+            op_concat = schema.OperatorSetConcat([op_obj_a, op_obj_b])
 
         fw_tp = TargetPlatformCapabilities(hm)
         with fw_tp:
diff --git a/tests/pytorch_tests/function_tests/layer_fusing_test.py b/tests/pytorch_tests/function_tests/layer_fusing_test.py
index ccf131ddd..6ecdca713 100644
--- a/tests/pytorch_tests/function_tests/layer_fusing_test.py
+++ b/tests/pytorch_tests/function_tests/layer_fusing_test.py
@@ -229,7 +229,7 @@ def get_tpc(self):
             any_relu = schema.OperatorsSet("AnyReLU")
             add = schema.OperatorsSet("Add")
             swish = schema.OperatorsSet("Swish")
-            activations_to_fuse = schema.OperatorSetConcat(any_relu, swish)
+            activations_to_fuse = schema.OperatorSetConcat([any_relu, swish])
             # Define fusions
             schema.Fusing([conv, activations_to_fuse])
             schema.Fusing([conv, add, activations_to_fuse])
diff --git a/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py b/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
index cb7c7647d..68c597f13 100644
--- a/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
+++ b/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
@@ -168,7 +168,7 @@ def test_get_layers_by_opconcat(self):
         with hm:
             op_obj_a = schema.OperatorsSet('opsetA')
             op_obj_b = schema.OperatorsSet('opsetB')
-            op_concat = schema.OperatorSetConcat(op_obj_a, op_obj_b)
+            op_concat = schema.OperatorSetConcat([op_obj_a, op_obj_b])
 
         fw_tp = TargetPlatformCapabilities(hm)
         with fw_tp:
diff --git a/tests/pytorch_tests/function_tests/test_quant_config_filtering.py b/tests/pytorch_tests/function_tests/test_quant_config_filtering.py
index e2754302e..d26bfe3f9 100644
--- a/tests/pytorch_tests/function_tests/test_quant_config_filtering.py
+++ b/tests/pytorch_tests/function_tests/test_quant_config_filtering.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+from dataclasses import replace
 import unittest
 
 import model_compression_toolkit as mct
@@ -34,8 +34,8 @@ def get_tpc_default_16bit():
         tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3')
         # Force Mul base_config to 16bit only
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[torch.multiply].base_config = mul_op_set.qc_options.base_config
+        base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[torch.multiply] = replace(tpc.layer2qco[torch.multiply], base_config=base_config)
         return tpc
 
     def test_config_filtering(self):
diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py
index cfc6fa2e8..6d2196053 100644
--- a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from dataclasses import replace
+
 from operator import mul
 
 import torch
@@ -62,10 +64,9 @@ def forward(self, x):
 
 
 def set_16bit_as_default(tpc, required_op_set, required_ops_list):
-    op_set = get_op_set(required_op_set, tpc.tp_model.operator_set)
-    op_set.qc_options.base_config = [l for l in op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
     for op in required_ops_list:
-        tpc.layer2qco[op].base_config = [l for l in tpc.layer2qco[op].quantization_config_list if l.activation_n_bits == 16][0]
+        base_config = [l for l in tpc.layer2qco[op].quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[op] = replace(tpc.layer2qco[op], base_config=base_config)
 
 
 class Activation16BitTest(BasePytorchFeatureNetworkTest):
@@ -106,9 +107,9 @@ class Activation16BitMixedPrecisionTest(Activation16BitTest):
     def get_tpc(self):
        tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3')
        mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-       mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-       tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config
-       tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config
+       base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+       tpc.layer2qco[torch.mul] = replace(tpc.layer2qco[torch.mul], base_config=base_config)
+       tpc.layer2qco[mul] = replace(tpc.layer2qco[mul], base_config=base_config)
        mul_op_set.qc_options.quantization_config_list.extend(
            [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4),
             mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)])
diff --git a/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py b/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py
index 8d2207974..3178785f2 100644
--- a/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py
+++ b/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from dataclasses import replace
+
 from operator import mul
 
 import inspect
@@ -186,9 +188,9 @@ class Manual16BitTest(ManualBitWidthByLayerNameTest):
     def get_tpc(self):
         tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3')
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config
-        tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config
+        base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[torch.mul] = replace(tpc.layer2qco[torch.mul], base_config=base_config)
+        tpc.layer2qco[mul] = replace(tpc.layer2qco[mul], base_config=base_config)
         return {'mixed_precision_activation_model': tpc}
 
     def create_feature_network(self, input_shape):
@@ -200,9 +202,9 @@ class Manual16BitTestMixedPrecisionTest(ManualBitWidthByLayerNameTest):
     def get_tpc(self):
         tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3')
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
-        mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
-        tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config
-        tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config
+        base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[torch.mul] = replace(tpc.layer2qco[torch.mul], base_config=base_config)
+        tpc.layer2qco[mul] = replace(tpc.layer2qco[mul], base_config=base_config)
         mul_op_set.qc_options.quantization_config_list.extend(
             [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4),
              mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)])