Move functionality from schema to schema_functions
liord committed Dec 11, 2024
1 parent 52b00d0 commit 0e71a87
Showing 8 changed files with 151 additions and 103 deletions.
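Summary of the refactor: helpers that previously lived on the schema classes (the max_input_activation_n_bits property of OpQuantizationConfig and the get_config_options_by_operators_set, get_default_op_quantization_config, is_opset_in_model and get_opset_by_name methods of TargetPlatformModel) become standalone functions in schema.schema_functions, so call sites now pass the schema object explicitly. A rough before/after sketch; op_cfg, tp_model and the "Conv" opset name are placeholders for illustration, not taken from the diff:

from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
    max_input_activation_n_bits, get_opset_by_name

def example_usage(op_cfg, tp_model):
    # op_cfg: OpQuantizationConfig, tp_model: TargetPlatformModel (built elsewhere).
    # Before this commit:
    #   bits = op_cfg.max_input_activation_n_bits
    #   opset = tp_model.get_opset_by_name("Conv")
    # After this commit, the same logic is a module-level function taking the schema object:
    bits = max_input_activation_n_bits(op_cfg)    # max of op_cfg.supported_input_activation_n_bits
    opset = get_opset_by_name(tp_model, "Conv")   # None if no OperatorsSet named "Conv" exists
    return bits, opset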
5 changes: 3 additions & 2 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -24,6 +24,7 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
OpQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams


@@ -585,7 +586,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
_node_qc_options = node_qc_options.quantization_config_list
if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

@@ -596,7 +597,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
@@ -32,6 +32,7 @@
get_activation_quantization_params_fn, get_weights_quantization_params_fn
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
get_weights_quantization_fn
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
QuantizationConfigOptions
@@ -117,7 +118,7 @@ def filter_node_qco_by_graph(node: BaseNode,

if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

@@ -128,7 +129,7 @@ def filter_node_qco_by_graph(node: BaseNode,
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.logger import Logger
import copy
from typing import Any, Dict
from typing import Any, Dict, Optional

from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase


def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
@@ -35,3 +39,106 @@ def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
f'but {k} is not a parameter of {obj_copy}.'
setattr(obj_copy, k, v)
return obj_copy


def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int:
"""
Get the maximum supported input bit-width.
Args:
op_quantization_config (OpQuantizationConfig): The configuration object from which to retrieve the maximum supported input bit-width.
Returns:
int: Maximum supported input bit-width.
"""
return max(op_quantization_config.supported_input_activation_n_bits)


def get_config_options_by_operators_set(tp_model: TargetPlatformModel,
operators_set_name: str) -> QuantizationConfigOptions:
"""
Get the QuantizationConfigOptions of an OperatorsSet by its name.
Args:
tp_model (TargetPlatformModel): The target platform model containing the operator sets and their configurations.
operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved.
Returns:
QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet,
or the default quantization configuration options if the OperatorsSet is not found.
"""
for op_set in tp_model.operator_set:
if operators_set_name == op_set.name:
return op_set.qc_options
return tp_model.default_qco


def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuantizationConfig:
"""
Get the default OpQuantizationConfig of the TargetPlatformModel.
Args:
tp_model (TargetPlatformModel): The target platform model containing the default quantization configuration.
Returns:
OpQuantizationConfig: The default quantization configuration.
Raises:
AssertionError: If the default quantization configuration list contains more than one configuration option.
"""
assert len(tp_model.default_qco.quantization_config_list) == 1, \
f"Default quantization configuration options must contain only one option, " \
f"but found {len(tp_model.default_qco.quantization_config_list)} configurations."
return tp_model.default_qco.quantization_config_list[0]


def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
"""
Check whether an OperatorsSet is defined in the model.
Args:
tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
opset_name (str): The name of the OperatorsSet to check for existence.
Returns:
bool: True if an OperatorsSet with the given name exists in the target platform model,
otherwise False.
"""
return opset_name in [x.name for x in tp_model.operator_set]


def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]:
"""
Get an OperatorsSet object from the model by its name.
Args:
tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
opset_name (str): The name of the OperatorsSet to be retrieved.
Returns:
Optional[OperatorsSetBase]: The OperatorsSet object with the specified name if found.
If no operator set with the specified name is found, None is returned.
Raises:
An error, raised via Logger.critical, if more than one OperatorsSet with the specified name is found.
"""
opset_list = [x for x in tp_model.operator_set if x.name == opset_name]
if len(opset_list) > 1:
Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.")
return opset_list[0] if opset_list else None
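
The helpers above mirror, one-to-one, the TargetPlatformModel methods and the OpQuantizationConfig property removed from schema/v1.py below. A minimal usage sketch, assuming tp_model is an already-built TargetPlatformModel and "Conv" is a hypothetical OperatorsSet name:

def describe_opset(tp_model: TargetPlatformModel, opset_name: str = "Conv"):
    # Default config of the model (asserts the default QCO holds exactly one option).
    default_cfg = get_default_op_quantization_config(tp_model)
    # Options for a named OperatorsSet; falls back to tp_model.default_qco if the name is unknown.
    options = get_config_options_by_operators_set(tp_model, opset_name)
    # Existence check and lookup of the OperatorsSet object itself (None when not found).
    exists = is_opset_in_model(tp_model, opset_name)
    opset = get_opset_by_name(tp_model, opset_name) if exists else None
    return default_cfg, options, opset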
101 changes: 17 additions & 84 deletions model_compression_toolkit/target_platform_capabilities/schema/v1.py
@@ -103,14 +103,14 @@ class AttributeQuantizationConfig:
weights_n_bits (int): Number of bits to quantize the coefficients.
weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor.
enable_weights_quantization (bool): Indicates whether to quantize the model weights or not.
lut_values_bitwidth (Union[int, None]): Number of bits to use when quantizing in a look-up table.
lut_values_bitwidth (Optional[int]): Number of bits to use when quantizing in a look-up table.
If None, defaults to 8 in hptq; otherwise, it uses the provided value.
"""
weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO
weights_n_bits: int = FLOAT_BITWIDTH
weights_per_channel_threshold: bool = False
enable_weights_quantization: bool = False
lut_values_bitwidth: Union[int, None] = None
lut_values_bitwidth: Optional[int] = None

def __post_init__(self):
"""
@@ -170,7 +170,7 @@ class OpQuantizationConfig:
simd_size: int
signedness: Signedness

def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization processing for input validation.
@@ -218,16 +218,6 @@ def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs)
# Return a new instance with the updated attribute mapping
return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping)

@property
def max_input_activation_n_bits(self) -> int:
"""
Get the maximum supported input bit-width.
Returns:
int: Maximum supported input bit-width.
"""
return max(self.supported_input_activation_n_bits)


@dataclass(frozen=True)
class QuantizationConfigOptions:
@@ -236,12 +226,12 @@ class QuantizationConfigOptions:
Attributes:
quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather.
base_config (Union[OpQuantizationConfig, None]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
"""
quantization_config_list: List[OpQuantizationConfig]
base_config: Union[OpQuantizationConfig, None] = None
base_config: Optional[OpQuantizationConfig] = None

def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization processing for input validation.
@@ -320,12 +310,12 @@ def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) ->
updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs)

def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Union[Dict[str, str], None]) -> 'QuantizationConfigOptions':
def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions':
"""
Clones the quantization configurations and updates keys in attribute config mappings.
Args:
layer_attrs_mapping (Union[Dict[str, str], None]): A mapping between attribute names.
layer_attrs_mapping (Optional[Dict[str, str]]): A mapping between attribute names.
Returns:
QuantizationConfigOptions: A new instance of QuantizationConfigOptions with updated attribute keys.
@@ -361,7 +351,7 @@ class TargetPlatformModelComponent:
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
"""

def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization to register the component with the current TargetPlatformModel.
"""
@@ -384,7 +374,7 @@ class OperatorsSetBase(TargetPlatformModelComponent):
Base class to represent a set of a target platform model component of operator set types.
Inherits from TargetPlatformModelComponent.
"""
def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization to ensure the component is registered with the TargetPlatformModel.
Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel.
@@ -407,7 +397,7 @@ class OperatorsSet(OperatorsSetBase):
name: str
qc_options: QuantizationConfigOptions = None

def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization processing to mark the operator set as default if applicable.
@@ -447,7 +437,7 @@ class OperatorSetConcat(OperatorsSetBase):
qc_options: None = field(default=None, init=False)
name: str = None

def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization processing to generate the concatenated name and set it as the `name` attribute.
@@ -486,7 +476,7 @@ class Fusing(TargetPlatformModelComponent):
operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]]
name: str = None

def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization processing for input validation and name generation.
@@ -504,7 +494,6 @@ def __post_init__(self) -> None:
if len(self.operator_groups_list) < 2:
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover

# if self.name is None:
# Generate the name from the operator groups if not provided
generated_name = '_'.join([x.name for x in self.operator_groups_list])
object.__setattr__(self, 'name', generated_name)
@@ -578,7 +567,7 @@ class TargetPlatformModel:

SCHEMA_VERSION: int = 1

def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization processing for input validation.
@@ -592,62 +581,7 @@ def __post_init__(self) -> None:
if len(self.default_qco.quantization_config_list) != 1:
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover

def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions:
"""
Get the QuantizationConfigOptions of an OperatorsSet by its name.
Args:
operators_set_name (str): Name of the OperatorsSet to get.
Returns:
QuantizationConfigOptions: Quantization configuration options for the given OperatorsSet name.
"""
for op_set in self.operator_set:
if operators_set_name == op_set.name:
return op_set.qc_options
return self.default_qco

def get_default_op_quantization_config(self) -> OpQuantizationConfig:
"""
Get the default OpQuantizationConfig of the TargetPlatformModel.
Returns:
OpQuantizationConfig: The default quantization configuration.
"""
assert len(self.default_qco.quantization_config_list) == 1, \
f"Default quantization configuration options must contain only one option, " \
f"but found {len(self.default_qco.quantization_config_list)} configurations."
return self.default_qco.quantization_config_list[0]

def is_opset_in_model(self, opset_name: str) -> bool:
"""
Check whether an OperatorsSet is defined in the model.
Args:
opset_name (str): Name of the OperatorsSet to check.
Returns:
bool: True if the OperatorsSet exists, False otherwise.
"""
return opset_name in [x.name for x in self.operator_set]

def get_opset_by_name(self, opset_name: str) -> Optional[OperatorsSetBase]:
"""
Get an OperatorsSet object from the model by its name.
Args:
opset_name (str): Name of the OperatorsSet to retrieve.
Returns:
Optional[OperatorsSetBase]: The OperatorsSet object with the given name,
or None if not found in the model.
"""
opset_list = [x for x in self.operator_set if x.name == opset_name]
if len(opset_list) > 1:
Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.")
return opset_list[0] if opset_list else None

def append_component(self, tp_model_component: TargetPlatformModelComponent) -> None:
def append_component(self, tp_model_component: TargetPlatformModelComponent):
"""
Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet).
Expand All @@ -674,12 +608,11 @@ def get_info(self) -> Dict[str, Any]:
"""
return {
"Model name": self.name,
"Default quantization config": self.get_default_op_quantization_config().get_info(),
"Operators sets": [o.get_info() for o in self.operator_set],
"Fusing patterns": [f.get_info() for f in self.fusing_patterns],
}

def __validate_model(self) -> None:
def __validate_model(self):
"""
Validate the model's configuration to ensure its integrity.
@@ -721,7 +654,7 @@ def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel':
_current_tp_model.reset()
return self

def show(self) -> None:
def show(self):
"""
Display the TargetPlatformModel.