Skip to content

Commit

Permalink
Refactor Target Platform Capabilities Design
Browse files Browse the repository at this point in the history
- Create a new `schema` package to house all target platform modeling classes
- Introduce a new versioning system with minor and patch versions

Additional Changes:
- Update existing target platform models to adhere to the new versioning convention
- Add necessary metadata (the TPC schema version is now recorded in the model metadata alongside the MCT and TPC versions)
- Correct all import statements
- Update and enhance tests to reflect the design changes
  • Loading branch information
liord committed Nov 24, 2024
1 parent 432ae5b commit 2f794ad
Show file tree
Hide file tree
Showing 69 changed files with 1,750 additions and 1,517 deletions.
1 change: 1 addition & 0 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# Metadata fields
MCT_VERSION = 'mct_version'
TPC_VERSION = 'tpc_version'
TPC_SCHEMA = 'tpc_schema'

WEIGHTS_SIGNED = True
# Minimal threshold to use for quantization ranges:
Expand Down
5 changes: 3 additions & 2 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \
TargetPlatformCapabilities, LayerFilterParams, OpQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams
from model_compression_toolkit.target_platform_capabilities.schema.v1 import OpQuantizationConfig, \
QuantizationConfigOptions


class BaseNode:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
QuantizationConfigOptions
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.schema.v1 import QuantizationConfigOptions


def compute_resource_utilization_data(in_model: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig, \
OpQuantizationConfig
from model_compression_toolkit.logger import Logger


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
QuantizationErrorMethod
from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig, \
OpQuantizationConfig


##########################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import numpy as np
from typing import Dict, Union

from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod, Signedness
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.v1 import Signedness
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
from model_compression_toolkit.core.common.quantization import quantization_params_generation
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
get_weights_quantization_fn
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
from model_compression_toolkit.target_platform_capabilities.schema.v1 import OpQuantizationConfig, \
QuantizationConfigOptions


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig


def apply_activation_bias_correction_to_graph(graph: Graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig


def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod, \
AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig


class BatchNormalizationReconstruction(common.BaseSubstitution):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod, \
AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
set_quantization_configs_to_node
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
Expand Down
11 changes: 7 additions & 4 deletions model_compression_toolkit/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from typing import Dict, Any
from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION, OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, \
CUTS, MAX_CUT, OP_ORDER, OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
CUTS, MAX_CUT, OP_ORDER, OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS, TPC_SCHEMA
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities

Expand Down Expand Up @@ -43,13 +43,16 @@ def create_model_metadata(tpc: TargetPlatformCapabilities,
def get_versions_dict(tpc) -> Dict:
"""
Returns: A dictionary with TPC and MCT versions.
Returns: A dictionary with TPC, MCT and TPC-Schema versions.
"""
# imported inside to avoid circular import error
from model_compression_toolkit import __version__ as mct_version
tpc_version = f'{tpc.name}.{tpc.version}'
return {MCT_VERSION: mct_version, TPC_VERSION: tpc_version}
tpc_version = f'{tpc.tp_model.tpc_minor_version}.{tpc.tp_model.tpc_patch_version}'
tpc_schema = f'{tpc.tp_model.SCHEMA_VERSION}'
return {MCT_VERSION: mct_version,
TPC_VERSION: tpc_version,
TPC_SCHEMA: tpc_schema}


def get_scheduler_metadata(scheduler_info: SchedulerInfo) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Loading

0 comments on commit 2f794ad

Please sign in to comment.