diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py
index ecad233dd..b47a03f1f 100644
--- a/model_compression_toolkit/core/common/graph/base_node.py
+++ b/model_compression_toolkit/core/common/graph/base_node.py
@@ -24,6 +24,7 @@
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
     OpQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams
@@ -585,7 +586,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
         _node_qc_options = node_qc_options.quantization_config_list
         if len(next_nodes):
             next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
-            next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
+            next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
                                                        for qc_opts in next_nodes_qc_options
                                                        for op_cfg in qc_opts.quantization_config_list])
@@ -596,7 +597,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
                 Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover

         # Verify base config match
-        if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
+        if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
                 for qc_opt in next_nodes_qc_options]):
             # base_config activation bits doesn't match next node supported input bit-width -> replace with
             # a qco from quantization_config_list with maximum activation bit-width.
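Reviewer note: the two hunks above swap the removed `max_input_activation_n_bits` property for the new module-level helper. For intuition, here is a minimal, self-contained sketch of the compatibility rule they implement; the `_OpCfg` class and the bit-width values are hypothetical stand-ins, not MCT code or any real TPC:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class _OpCfg:
    # Hypothetical stand-in for OpQuantizationConfig, for illustration only.
    activation_n_bits: int
    supported_input_activation_n_bits: Tuple[int, ...]


def max_input_activation_n_bits(op_cfg: _OpCfg) -> int:
    # Mirrors the new schema_functions helper: the widest input a config accepts.
    return max(op_cfg.supported_input_activation_n_bits)


node_cfg = _OpCfg(activation_n_bits=16, supported_input_activation_n_bits=(8, 16))
next_cfgs = [_OpCfg(8, (2, 4, 8)), _OpCfg(8, (4, 8))]

# A node's output activation bit-width must fit the most restrictive successor.
supported = min(max_input_activation_n_bits(c) for c in next_cfgs)  # -> 8
print(node_cfg.activation_n_bits <= supported)  # False: a 16-bit option is filtered out
```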
diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
index 2e6cc8d9d..5d4d18441 100644
--- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
+++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
@@ -32,6 +32,7 @@
     get_activation_quantization_params_fn, get_weights_quantization_params_fn
 from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
     get_weights_quantization_fn
+from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
     QuantizationConfigOptions
@@ -117,7 +118,7 @@ def filter_node_qco_by_graph(node: BaseNode,

     if len(next_nodes):
         next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
-        next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
+        next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
                                                    for qc_opts in next_nodes_qc_options
                                                    for op_cfg in qc_opts.quantization_config_list])
@@ -128,7 +129,7 @@ def filter_node_qco_by_graph(node: BaseNode,
             Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")

     # Verify base config match
-    if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
+    if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
             for qc_opt in next_nodes_qc_options]):
         # base_config activation bits doesn't match next node supported input bit-width -> replace with
         # a qco from quantization_config_list with maximum activation bit-width.
diff --git a/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py b/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py
index 105136647..03b26e2d9 100644
--- a/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py
+++ b/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from model_compression_toolkit.logger import Logger
 import copy
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
+    TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase


 def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
@@ -35,3 +39,90 @@ def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
             f'but {k} is not a parameter of {obj_copy}.'
         setattr(obj_copy, k, v)
     return obj_copy
+ """ + for op_set in self.operator_set: + if operators_set_name == op_set.name: + return op_set.qc_options + return self.default_qco + + +def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int: + """ + Get the maximum supported input bit-width. + + Args: + op_quantization_config (OpQuantizationConfig): The configuration object from which to retrieve the maximum supported input bit-width. + + Returns: + int: Maximum supported input bit-width. + """ + return max(op_quantization_config.supported_input_activation_n_bits) + + +def get_config_options_by_operators_set(tp_model: TargetPlatformModel, + operators_set_name: str) -> QuantizationConfigOptions: + """ + Get the QuantizationConfigOptions of an OperatorsSet by its name. + + Args: + tp_model (TargetPlatformModel): The target platform model containing the operator sets and their configurations. + operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved. + + Returns: + QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet, + or the default quantization configuration options if the OperatorsSet is not found. + """ + for op_set in tp_model.operator_set: + if operators_set_name == op_set.name: + return op_set.qc_options + return tp_model.default_qco + + +def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuantizationConfig: + """ + Get the default OpQuantizationConfig of the TargetPlatformModel. + + Args: + tp_model (TargetPlatformModel): The target platform model containing the default quantization configuration. + + Returns: + OpQuantizationConfig: The default quantization configuration. + + Raises: + AssertionError: If the default quantization configuration list contains more than one configuration option. + """ + assert len(tp_model.default_qco.quantization_config_list) == 1, \ + f"Default quantization configuration options must contain only one option, " \ + f"but found {len(tp_model.default_qco.quantization_config_list)} configurations." + return tp_model.default_qco.quantization_config_list[0] + + +def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool: + """ + Check whether an OperatorsSet is defined in the model. + + Args: + tp_model (TargetPlatformModel): The target platform model containing the list of operator sets. + opset_name (str): The name of the OperatorsSet to check for existence. + + Returns: + bool: True if an OperatorsSet with the given name exists in the target platform model, + otherwise False. + """ + return opset_name in [x.name for x in tp_model.operator_set] + + +def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]: + """ + Get an OperatorsSet object from the model by its name. + + Args: + tp_model (TargetPlatformModel): The target platform model containing the list of operator sets. + opset_name (str): The name of the OperatorsSet to be retrieved. + + Returns: + Optional[OperatorsSetBase]: The OperatorsSet object with the specified name if found. + If no operator set with the specified name is found, None is returned. + + Raises: + A critical log message if multiple operator sets with the same name are found. 
+ """ + opset_list = [x for x in tp_model.operator_set if x.name == opset_name] + if len(opset_list) > 1: + Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") + return opset_list[0] if opset_list else None diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index be4954b7a..4353a7d98 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -103,14 +103,14 @@ class AttributeQuantizationConfig: weights_n_bits (int): Number of bits to quantize the coefficients. weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor. enable_weights_quantization (bool): Indicates whether to quantize the model weights or not. - lut_values_bitwidth (Union[int, None]): Number of bits to use when quantizing in a look-up table. + lut_values_bitwidth (Optional[int]): Number of bits to use when quantizing in a look-up table. If None, defaults to 8 in hptq; otherwise, it uses the provided value. """ weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO weights_n_bits: int = FLOAT_BITWIDTH weights_per_channel_threshold: bool = False enable_weights_quantization: bool = False - lut_values_bitwidth: Union[int, None] = None + lut_values_bitwidth: Optional[int] = None def __post_init__(self): """ @@ -170,7 +170,7 @@ class OpQuantizationConfig: simd_size: int signedness: Signedness - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization processing for input validation. @@ -218,16 +218,6 @@ def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs) # Return a new instance with the updated attribute mapping return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping) - @property - def max_input_activation_n_bits(self) -> int: - """ - Get the maximum supported input bit-width. - - Returns: - int: Maximum supported input bit-width. - """ - return max(self.supported_input_activation_n_bits) - @dataclass(frozen=True) class QuantizationConfigOptions: @@ -236,12 +226,12 @@ class QuantizationConfigOptions: Attributes: quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather. - base_config (Union[OpQuantizationConfig, None]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner. + base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner. """ quantization_config_list: List[OpQuantizationConfig] - base_config: Union[OpQuantizationConfig, None] = None + base_config: Optional[OpQuantizationConfig] = None - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization processing for input validation. 
diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py
index be4954b7a..4353a7d98 100644
--- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py
+++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py
@@ -103,14 +103,14 @@ class AttributeQuantizationConfig:
         weights_n_bits (int): Number of bits to quantize the coefficients.
         weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor.
         enable_weights_quantization (bool): Indicates whether to quantize the model weights or not.
-        lut_values_bitwidth (Union[int, None]): Number of bits to use when quantizing in a look-up table.
+        lut_values_bitwidth (Optional[int]): Number of bits to use when quantizing in a look-up table.
                                                If None, defaults to 8 in hptq; otherwise, it uses the provided value.
     """
     weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO
     weights_n_bits: int = FLOAT_BITWIDTH
     weights_per_channel_threshold: bool = False
     enable_weights_quantization: bool = False
-    lut_values_bitwidth: Union[int, None] = None
+    lut_values_bitwidth: Optional[int] = None

     def __post_init__(self):
         """
         Post-initialization processing for input validation.
@@ -170,7 +170,7 @@ class OpQuantizationConfig:
     simd_size: int
     signedness: Signedness

-    def __post_init__(self) -> None:
+    def __post_init__(self):
         """
         Post-initialization processing for input validation.
@@ -218,16 +218,6 @@ def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs)
         # Return a new instance with the updated attribute mapping
         return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping)

-    @property
-    def max_input_activation_n_bits(self) -> int:
-        """
-        Get the maximum supported input bit-width.
-
-        Returns:
-            int: Maximum supported input bit-width.
-        """
-        return max(self.supported_input_activation_n_bits)
-

 @dataclass(frozen=True)
 class QuantizationConfigOptions:
@@ -236,12 +226,12 @@ class QuantizationConfigOptions:
     Attributes:
         quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather.
-        base_config (Union[OpQuantizationConfig, None]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
+        base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
     """
     quantization_config_list: List[OpQuantizationConfig]
-    base_config: Union[OpQuantizationConfig, None] = None
+    base_config: Optional[OpQuantizationConfig] = None

-    def __post_init__(self) -> None:
+    def __post_init__(self):
         """
         Post-initialization processing for input validation.
@@ -320,12 +310,12 @@ def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) ->
             updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
         return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs)

-    def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Union[Dict[str, str], None]) -> 'QuantizationConfigOptions':
+    def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions':
         """
         Clones the quantization configurations and updates keys in attribute config mappings.

         Args:
-            layer_attrs_mapping (Union[Dict[str, str], None]): A mapping between attribute names.
+            layer_attrs_mapping (Optional[Dict[str, str]]): A mapping between attribute names.

         Returns:
             QuantizationConfigOptions: A new instance of QuantizationConfigOptions with updated attribute keys.
@@ -361,7 +351,7 @@ class TargetPlatformModelComponent:
     Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
     """

-    def __post_init__(self) -> None:
+    def __post_init__(self):
         """
         Post-initialization to register the component with the current TargetPlatformModel.
         """
@@ -384,7 +374,7 @@ class OperatorsSetBase(TargetPlatformModelComponent):
     Base class to represent a set of a target platform model component of operator set types.
     Inherits from TargetPlatformModelComponent.
     """
-    def __post_init__(self) -> None:
+    def __post_init__(self):
         """
         Post-initialization to ensure the component is registered with the TargetPlatformModel.
         Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel.
@@ -407,7 +397,7 @@ class OperatorsSet(OperatorsSetBase):
     name: str
     qc_options: QuantizationConfigOptions = None

-    def __post_init__(self) -> None:
+    def __post_init__(self):
         """
         Post-initialization processing to mark the operator set as default if applicable.
@@ -447,7 +437,7 @@ class OperatorSetConcat(OperatorsSetBase):
     qc_options: None = field(default=None, init=False)
     name: str = None

-    def __post_init__(self) -> None:
+    def __post_init__(self):
         """
         Post-initialization processing to generate the concatenated name and set it as the `name` attribute.
@@ -486,7 +476,7 @@ class Fusing(TargetPlatformModelComponent):
     operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]]
     name: str = None

-    def __post_init__(self) -> None:
+    def __post_init__(self):
         """
         Post-initialization processing for input validation and name generation.
@@ -504,7 +494,6 @@ def __post_init__(self) -> None:
         if len(self.operator_groups_list) < 2:
             Logger.critical("Fusing cannot be created for a single operator.")  # pragma: no cover

-        #
         if self.name is None:
             # Generate the name from the operator groups if not provided
             generated_name = '_'.join([x.name for x in self.operator_groups_list])
             object.__setattr__(self, 'name', generated_name)
@@ -578,7 +567,7 @@ class TargetPlatformModel:

     SCHEMA_VERSION: int = 1

-    def __post_init__(self) -> None:
+    def __post_init__(self):
         """
         Post-initialization processing for input validation.
@@ -592,62 +581,7 @@ def __post_init__(self) -> None:
         if len(self.default_qco.quantization_config_list) != 1:
             Logger.critical("Default QuantizationConfigOptions must contain exactly one option.")  # pragma: no cover

-    def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions:
-        """
-        Get the QuantizationConfigOptions of an OperatorsSet by its name.
-
-        Args:
-            operators_set_name (str): Name of the OperatorsSet to get.
-
-        Returns:
-            QuantizationConfigOptions: Quantization configuration options for the given OperatorsSet name.
-        """
-        for op_set in self.operator_set:
-            if operators_set_name == op_set.name:
-                return op_set.qc_options
-        return self.default_qco
-
-    def get_default_op_quantization_config(self) -> OpQuantizationConfig:
-        """
-        Get the default OpQuantizationConfig of the TargetPlatformModel.
-
-        Returns:
-            OpQuantizationConfig: The default quantization configuration.
-        """
-        assert len(self.default_qco.quantization_config_list) == 1, \
-            f"Default quantization configuration options must contain only one option, " \
-            f"but found {len(self.default_qco.quantization_config_list)} configurations."
-        return self.default_qco.quantization_config_list[0]
-
-    def is_opset_in_model(self, opset_name: str) -> bool:
-        """
-        Check whether an OperatorsSet is defined in the model.
-
-        Args:
-            opset_name (str): Name of the OperatorsSet to check.
-
-        Returns:
-            bool: True if the OperatorsSet exists, False otherwise.
-        """
-        return opset_name in [x.name for x in self.operator_set]
-
-    def get_opset_by_name(self, opset_name: str) -> Optional[OperatorsSetBase]:
-        """
-        Get an OperatorsSet object from the model by its name.
-
-        Args:
-            opset_name (str): Name of the OperatorsSet to retrieve.
-
-        Returns:
-            Optional[OperatorsSetBase]: The OperatorsSet object with the given name,
-            or None if not found in the model.
-        """
-        opset_list = [x for x in self.operator_set if x.name == opset_name]
-        if len(opset_list) > 1:
-            Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.")
-        return opset_list[0] if opset_list else None
-
-    def append_component(self, tp_model_component: TargetPlatformModelComponent) -> None:
+    def append_component(self, tp_model_component: TargetPlatformModelComponent):
         """
         Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet).
@@ -674,12 +608,11 @@ def get_info(self) -> Dict[str, Any]:
         """
         return {
             "Model name": self.name,
-            "Default quantization config": self.get_default_op_quantization_config().get_info(),
             "Operators sets": [o.get_info() for o in self.operator_set],
             "Fusing patterns": [f.get_info() for f in self.fusing_patterns],
         }

-    def __validate_model(self) -> None:
+    def __validate_model(self):
         """
         Validate the model's configuration to ensure its integrity.
@@ -721,7 +654,7 @@ def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel':
         _current_tp_model.reset()
         return self

-    def show(self) -> None:
+    def show(self):
         """
         Display the TargetPlatformModel.
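The `v1.py` removals above pair one-to-one with the `schema_functions.py` additions. The implied call-site migration, sketched below (the `_migration_examples` wrapper and the 'Conv' name are placeholders, not code from this PR):

```python
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import (
    get_default_op_quantization_config, get_opset_by_name, is_opset_in_model,
    max_input_activation_n_bits)


def _migration_examples(tp_model, op_cfg):
    # tp_model.is_opset_in_model(name)              -> is_opset_in_model(tp_model, name)
    exists = is_opset_in_model(tp_model, 'Conv')
    # tp_model.get_opset_by_name(name)              -> get_opset_by_name(tp_model, name)
    opset = get_opset_by_name(tp_model, 'Conv') if exists else None
    # tp_model.get_default_op_quantization_config() -> get_default_op_quantization_config(tp_model)
    default_cfg = get_default_op_quantization_config(tp_model)
    # op_cfg.max_input_activation_n_bits (property) -> max_input_activation_n_bits(op_cfg)
    return opset, default_cfg, max_input_activation_n_bits(op_cfg)
```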
diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py
index 669a068a7..aa378ff16 100644
--- a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py
+++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py
@@ -16,6 +16,8 @@
 from typing import List, Any, Dict

 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
+    get_config_options_by_operators_set, is_opset_in_model
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetConcat
@@ -137,14 +139,14 @@ def validate_op_sets(self):
                 f'is of type {type(ops2layers)}'

             # Assert that opset in the current TargetPlatformCapabilities and has a unique name.
-            is_opset_in_model = _current_tpc.get().tp_model.is_opset_in_model(ops2layers.name)
-            assert is_opset_in_model, f'{ops2layers.name} is not defined in the target platform model that is associated with the target platform capabilities.'
+            opset_in_model = is_opset_in_model(_current_tpc.get().tp_model, ops2layers.name)
+            assert opset_in_model, f'{ops2layers.name} is not defined in the target platform model that is associated with the target platform capabilities.'
             assert not (ops2layers.name in existing_opset_names), f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
             existing_opset_names.append(ops2layers.name)

             # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformModel.
             for layer in ops2layers.layers:
-                qco_by_opset_name = _current_tpc.get().tp_model.get_config_options_by_operators_set(ops2layers.name)
+                qco_by_opset_name = get_config_options_by_operators_set(_current_tpc.get().tp_model, ops2layers.name)
                 if layer in existing_layers:
                     Logger.critical(f'Found layer {layer.__name__} in more than one '
                                     f'OperatorsSet')  # pragma: no cover
diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py
index ef0cd5713..924069c82 100644
--- a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py
+++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py
@@ -19,6 +19,8 @@
 from typing import List, Any, Dict, Tuple

 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
+    get_config_options_by_operators_set, get_default_op_quantization_config, get_opset_by_name
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
     OperationsToLayers, OperationsSetToLayers
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
@@ -64,7 +66,7 @@ def get_layers_by_opset_name(self, opset_name: str) -> List[Any]:
         Returns:
             List of layers/LayerFilterParams that are attached to the opset name.
         """
-        opset = self.tp_model.get_opset_by_name(opset_name)
+        opset = get_opset_by_name(self.tp_model, opset_name)
         if opset is None:
             Logger.warning(f'{opset_name} was not found in TargetPlatformCapabilities.')
             return None
@@ -165,7 +167,7 @@ def get_default_op_qc(self) -> OpQuantizationConfig:
             to the TargetPlatformCapabilities.
         """
-        return self.tp_model.get_default_op_quantization_config()
+        return get_default_op_quantization_config(self.tp_model)

     def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions],
@@ -181,7 +183,7 @@ def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions],
         filterlayer2qco = {}
         for op2layers in self.op_sets_to_layers.op_sets_to_layers:
             for l in op2layers.layers:
-                qco = self.tp_model.get_config_options_by_operators_set(op2layers.name)
+                qco = get_config_options_by_operators_set(self.tp_model, op2layers.name)
                 if qco is None:
                     qco = self.tp_model.default_qco
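Note that the public `TargetPlatformCapabilities` surface is unchanged by this file; its methods simply delegate to the new free functions. A small sketch of caller code that keeps working as before (assuming `tpc` is an initialized `TargetPlatformCapabilities` instance and 'Conv' is a placeholder opset name):

```python
def unchanged_usage(tpc):
    # Internally this now calls get_default_op_quantization_config(tpc.tp_model).
    default_qc = tpc.get_default_op_qc()
    # Internally this now calls get_opset_by_name(tpc.tp_model, 'Conv').
    layers = tpc.get_layers_by_opset_name('Conv')
    return default_qc, layers
```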
""" - return self.tp_model.get_default_op_quantization_config() + return get_default_op_quantization_config(self.tp_model) def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions], @@ -181,7 +183,7 @@ def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptio filterlayer2qco = {} for op2layers in self.op_sets_to_layers.op_sets_to_layers: for l in op2layers.layers: - qco = self.tp_model.get_config_options_by_operators_set(op2layers.name) + qco = get_config_options_by_operators_set(self.tp_model, op2layers.name) if qco is None: qco = self.tp_model.default_qco diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py index 2f658d2f8..9ca5f4643 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py @@ -271,7 +271,7 @@ def generate_tp_model(default_config: OpQuantizationConfig, base_config=base_config) # Define operator sets that use mixed_precision_configuration_options: - conv = schema.OperatorsSet(OPSET_CONV, mixed_precision_configuration_options) + conv = schema.OperatorsSet(schema.OPS_SET_LIST.OPSET_CONV, mixed_precision_configuration_options) fc = schema.OperatorsSet(OPSET_FULLY_CONNECTED, mixed_precision_configuration_options) schema.OperatorsSet(OPSET_BATCH_NORM, default_config_options_16bit) diff --git a/tests/common_tests/test_tp_model.py b/tests/common_tests/test_tp_model.py index 5b1cd5799..4e96a13df 100644 --- a/tests/common_tests/test_tp_model.py +++ b/tests/common_tests/test_tp_model.py @@ -20,6 +20,8 @@ from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR +from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \ + get_config_options_by_operators_set, is_opset_in_model from model_compression_toolkit.target_platform_capabilities.target_platform import \ get_default_quantization_config_options from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, generate_test_op_qc @@ -92,13 +94,13 @@ def test_opset_qco(self): qco_3bit = get_default_quantization_config_options().clone_and_edit(activation_n_bits=3) schema.OperatorsSet(opset_name, qco_3bit) - for op_qc in hm.get_config_options_by_operators_set(opset_name).quantization_config_list: + for op_qc in get_config_options_by_operators_set(hm, opset_name).quantization_config_list: self.assertEqual(op_qc.activation_n_bits, 3) - self.assertTrue(hm.is_opset_in_model(opset_name)) - self.assertFalse(hm.is_opset_in_model("ShouldNotBeInModel")) - self.assertEqual(hm.get_config_options_by_operators_set(opset_name), qco_3bit) - self.assertEqual(hm.get_config_options_by_operators_set("ShouldNotBeInModel"), + self.assertTrue(is_opset_in_model(hm, opset_name)) + self.assertFalse(is_opset_in_model(hm, "ShouldNotBeInModel")) + self.assertEqual(get_config_options_by_operators_set(hm, opset_name), qco_3bit) + self.assertEqual(get_config_options_by_operators_set(hm, "ShouldNotBeInModel"), hm.default_qco) def test_opset_concat(self): @@ -115,8 +117,8 @@ def test_opset_concat(self): schema.OperatorsSet('opset_C') # Just add it without using it in concat schema.OperatorSetConcat([a, b]) 
         self.assertEqual(len(hm.operator_set), 4)
-        self.assertTrue(hm.is_opset_in_model("opset_A_opset_B"))
-        self.assertTrue(hm.get_config_options_by_operators_set('opset_A_opset_B') is None)
+        self.assertTrue(is_opset_in_model(hm, "opset_A_opset_B"))
+        self.assertTrue(get_config_options_by_operators_set(hm, 'opset_A_opset_B') is None)

     def test_non_unique_opset(self):
         hm = schema.TargetPlatformModel(