Skip to content

Commit

Permalink
Refactor Target Platform Capabilities - Phase 2 (#1290)
Browse files Browse the repository at this point in the history
* **Refactor Target Platform Capabilities - Phase 2**

- Convert all schema classes to immutable dataclasses, replacing existing methods with equivalent dataclass methods (e.g., `replace`).
- Ensure all schema classes are immutable to enhance reliability and maintain consistency.
- Update target platform model versions to align with the new class structure.
- Refactor tests to support and validate the updated class types and functionality.
- Remove functionality from schema to schema_functions


---------

Co-authored-by: liord <[email protected]>
  • Loading branch information
lior-dikstein and liord authored Dec 12, 2024
1 parent 262db8a commit dbc93b7
Show file tree
Hide file tree
Showing 27 changed files with 600 additions and 590 deletions.
5 changes: 3 additions & 2 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
OpQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams


Expand Down Expand Up @@ -584,7 +585,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
_node_qc_options = node_qc_options.quantization_config_list
if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

Expand All @@ -595,7 +596,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
get_activation_quantization_params_fn, get_weights_quantization_params_fn
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
get_weights_quantization_fn
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
QuantizationConfigOptions
Expand Down Expand Up @@ -117,7 +118,7 @@ def filter_node_qco_by_graph(node: BaseNode,

if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

Expand All @@ -128,7 +129,7 @@ def filter_node_qco_by_graph(node: BaseNode,
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,95 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Any, Dict
from logging import Logger
from typing import Optional

from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase

def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:

def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int:
"""
Get the maximum supported input bit-width.
Args:
op_quantization_config (OpQuantizationConfig): The configuration object from which to retrieve the maximum supported input bit-width.
Returns:
int: Maximum supported input bit-width.
"""
return max(op_quantization_config.supported_input_activation_n_bits)


def get_config_options_by_operators_set(tp_model: TargetPlatformModel,
operators_set_name: str) -> QuantizationConfigOptions:
"""
Get the QuantizationConfigOptions of an OperatorsSet by its name.
Args:
tp_model (TargetPlatformModel): The target platform model containing the operator sets and their configurations.
operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved.
Returns:
QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet,
or the default quantization configuration options if the OperatorsSet is not found.
"""
for op_set in tp_model.operator_set:
if operators_set_name == op_set.name:
return op_set.qc_options
return tp_model.default_qco


def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuantizationConfig:
"""
Clones the given object and edit some of its parameters.
Get the default OpQuantizationConfig of the TargetPlatformModel.
Args:
obj: An object to clone.
**kwargs: Keyword arguments to edit in the cloned object.
tp_model (TargetPlatformModel): The target platform model containing the default quantization configuration.
Returns:
Edited copy of the given object.
OpQuantizationConfig: The default quantization configuration.
Raises:
AssertionError: If the default quantization configuration list contains more than one configuration option.
"""
assert len(tp_model.default_qco.quantization_config_list) == 1, \
f"Default quantization configuration options must contain only one option, " \
f"but found {len(tp_model.default_qco.quantization_config_list)} configurations." # pragma: no cover
return tp_model.default_qco.quantization_config_list[0]


def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
"""
Check whether an OperatorsSet is defined in the model.
Args:
tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
opset_name (str): The name of the OperatorsSet to check for existence.
Returns:
bool: True if an OperatorsSet with the given name exists in the target platform model,
otherwise False.
"""
return opset_name in [x.name for x in tp_model.operator_set]

obj_copy = copy.deepcopy(obj)
for k, v in kwargs.items():
assert hasattr(obj_copy,
k), f'Edit parameter is possible only for existing parameters in the given object, ' \
f'but {k} is not a parameter of {obj_copy}.'
setattr(obj_copy, k, v)
return obj_copy

def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]:
"""
Get an OperatorsSet object from the model by its name.
Args:
tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
opset_name (str): The name of the OperatorsSet to be retrieved.
Returns:
Optional[OperatorsSetBase]: The OperatorsSet object with the specified name if found.
If no operator set with the specified name is found, None is returned.
Raises:
A critical log message if multiple operator sets with the same name are found.
"""
opset_list = [x for x in tp_model.operator_set if x.name == opset_name]
if len(opset_list) > 1:
Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") # pragma: no cover
return opset_list[0] if opset_list else None
Loading

0 comments on commit dbc93b7

Please sign in to comment.