From 4dac6a37c31590dbd24b84a7112301837d67e7b7 Mon Sep 17 00:00:00 2001 From: Elad Cohen <78862769+elad-c@users.noreply.github.com> Date: Sun, 15 Sep 2024 20:07:15 +0300 Subject: [PATCH] Add TPC.v4 with quantization preservation. (#1214) * Add TPC.v4 with quantization preservation. Handle quantization preservation in quantization options filtering. --- .../set_node_quantization_config.py | 68 ++++- .../target_platform_capabilities.py | 6 + .../tpc_models/imx500_tpc/v4/__init__.py | 16 ++ .../tpc_models/imx500_tpc/v4/tp_model.py | 258 ++++++++++++++++++ .../tpc_models/imx500_tpc/v4/tpc_keras.py | 133 +++++++++ .../tpc_models/imx500_tpc/v4/tpc_pytorch.py | 113 ++++++++ .../feature_networks/activation_16bit_test.py | 4 +- .../test_features_runner.py | 1 + .../feature_models/activation_16bit_test.py | 13 +- .../model_tests/test_feature_models_runner.py | 1 + 10 files changed, 605 insertions(+), 8 deletions(-) create mode 100644 model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py create mode 100644 model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py create mode 100644 model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py create mode 100644 model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py index 5dccef862..b85fc8571 100644 --- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py @@ -79,6 +79,72 @@ def set_quantization_configuration_to_graph(graph: Graph, return graph +def filter_node_qco_by_graph(node: BaseNode, + tpc: TargetPlatformCapabilities, + graph: Graph, + node_qc_options: QuantizationConfigOptions + ) -> 
Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: + """ + Filter quantization config options that don't match the graph. + A node may have several quantization config options with 'activation_n_bits' values, and + the next nodes in the graph may support different bit-width as input activation. This function + filters out quantization config that don't comply to these attributes. + + Args: + node: Node for filtering. + tpc: TPC to extract the QuantizationConfigOptions for the next nodes. + graph: Graph object. + node_qc_options: Node's QuantizationConfigOptions. + + Returns: + A base config (OpQuantizationConfig) and a config options list (list of OpQuantizationConfig) + that are compatible with next nodes supported input bit-widths. + + """ + # Filter quantization config options that don't match the graph. + _base_config = node_qc_options.base_config + _node_qc_options = node_qc_options.quantization_config_list + + # Build next_nodes list by appending to the node's next nodes list all nodes that are quantization preserving. + _next_nodes = graph.get_next_nodes(node) + next_nodes = [] + while len(_next_nodes): + n = _next_nodes.pop(0) + qco = n.get_qco(tpc) + qp = [qc.quantization_preserving for qc in qco.quantization_config_list] + if not all(qp) and any(qp): + Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.') + if qp[0]: + _next_nodes.extend(graph.get_next_nodes(n)) + next_nodes.append(n) + + if len(next_nodes): + next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes] + next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits + for qc_opts in next_nodes_qc_options + for op_cfg in qc_opts.quantization_config_list]) + + # Filter node's QC options that match next nodes input bit-width. 
+ _node_qc_options = [_option for _option in _node_qc_options + if _option.activation_n_bits <= next_nodes_supported_input_bitwidth] + if len(_node_qc_options) == 0: + Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.") + + # Verify base config match + if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits + for qc_opt in next_nodes_qc_options]): + # base_config activation bits doesn't match next node supported input bit-width -> replace with + # a qco from quantization_config_list with maximum activation bit-width. + if len(_node_qc_options) > 0: + output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)} + _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]] + Logger.warning(f"Node {node} base quantization config changed to match Graph and TPC configuration.\nCause: {node} -> {next_nodes}.") + else: + Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.") # pragma: no cover + + return _base_config, _node_qc_options + + def set_quantization_configs_to_node(node: BaseNode, graph: Graph, quant_config: QuantizationConfig, @@ -99,7 +165,7 @@ def set_quantization_configs_to_node(node: BaseNode, manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None. """ node_qc_options = node.get_qco(tpc) - base_config, node_qc_options_list = node.filter_node_qco_by_graph(tpc, graph.get_next_nodes(node), node_qc_options) + base_config, node_qc_options_list = filter_node_qco_by_graph(node, tpc, graph, node_qc_options) # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override, # and update base_config accordingly. 
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py index e7e08734e..f08ec5e14 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py @@ -42,6 +42,8 @@ def get_tpc_dict_by_fw(fw_name): get_keras_tpc as get_keras_tpc_v3 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import \ get_keras_tpc as get_keras_tpc_v3_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v4 # Keras: TPC versioning tpc_models_dict = {'v1': get_keras_tpc_v1, @@ -51,6 +53,7 @@ def get_tpc_dict_by_fw(fw_name): 'v2_lut': get_keras_tpc_v2_lut, 'v3': get_keras_tpc_v3, 'v3_lut': get_keras_tpc_v3_lut, + 'v4': get_keras_tpc_v4, LATEST: get_keras_tpc_latest} elif fw_name == PYTORCH: ############################### @@ -73,6 +76,8 @@ def get_tpc_dict_by_fw(fw_name): get_pytorch_tpc as get_pytorch_tpc_v3 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \ get_pytorch_tpc as get_pytorch_tpc_v3_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v4 # Pytorch: TPC versioning tpc_models_dict = {'v1': get_pytorch_tpc_v1, @@ -82,6 +87,7 @@ def get_tpc_dict_by_fw(fw_name): 'v2_lut': get_pytorch_tpc_v2_lut, 'v3': get_pytorch_tpc_v3, 'v3_lut': get_pytorch_tpc_v3_lut, + 'v4': get_pytorch_tpc_v4, LATEST: get_pytorch_tpc_latest} if tpc_models_dict is not None: return tpc_models_dict diff --git 
a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py new file mode 100644 index 000000000..a9b845dfa --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +__version__ = 'v4' diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py new file mode 100644 index 000000000..72ea15029 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py @@ -0,0 +1,258 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import List, Tuple + +import model_compression_toolkit as mct +from model_compression_toolkit.constants import FLOAT_BITWIDTH +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS +from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \ + TargetPlatformModel, Signedness +from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \ + AttributeQuantizationConfig + +tp = mct.target_platform + + +def get_tp_model() -> TargetPlatformModel: + """ + A method that generates a default target platform model, with base 8-bit quantization configuration and 8, 4, 2 + bits configuration list for mixed-precision quantization. + NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets + (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the + 'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations. + This version enables metadata by default. + + Returns: A TargetPlatformModel object. + + """ + base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() + return generate_tp_model(default_config=default_config, + base_config=base_config, + mixed_precision_cfg_list=mixed_precision_cfg_list, + name='imx500_tp_model') + + +def get_op_quantization_configs() -> \ + Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]: + """ + Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel. 
+ In addition, creates a list of default configuration objects (with 8, 4 and 2 bit quantization) to be used as + default configuration for mixed-precision quantization. + + Returns: A base OpQuantizationConfig, a list of OpQuantizationConfig objects for mixed-precision and a default OpQuantizationConfig. + + """ + + # TODO: currently, we don't want to quantize any attribute but the kernel by default, + # to preserve the current behavior of MCT, so quantization is disabled for all other attributes. + # Other quantization parameters are set to what we eventually want to quantize by default + # when we enable multi-attributes quantization - THIS NEEDS TO BE MODIFIED IN ALL TP MODELS! + + # define a default quantization config for all non-specified weights attributes. + default_weight_attr_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + weights_n_bits=8, + weights_per_channel_threshold=False, + enable_weights_quantization=False, # TODO: this will be changed to True once multi-attributes quantization is implemented + lut_values_bitwidth=None) + + # define a quantization config to quantize the kernel (for layers where there is a kernel attribute). + kernel_base_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.SYMMETRIC, + weights_n_bits=8, + weights_per_channel_threshold=True, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + # define a quantization config to quantize the bias (for layers where there is a bias attribute). + bias_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + weights_n_bits=FLOAT_BITWIDTH, + weights_per_channel_threshold=False, + enable_weights_quantization=False, + lut_values_bitwidth=None) + + # Create a quantization config. + # A quantization configuration defines how an operator + # should be quantized on the modeled hardware: + + # We define a default config for operations without a kernel attribute.
+ # This is the default config that should be used for non-linear operations. + eight_bits_default = tp.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={}, + activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + supported_input_activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32, + signedness=Signedness.AUTO) + + # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes. + linear_eight_bits = tp.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, + activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + supported_input_activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32, + signedness=Signedness.AUTO) + + # To quantize a model using mixed-precision, create + # a list with more than one OpQuantizationConfig. + # In this example, we quantize some operations' weights + # using 2, 4 or 8 bits, and when using 2 or 4 bits, it's possible + # to quantize the operations' activations using LUT. 
+ four_bits = linear_eight_bits.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4}}, + simd_size=linear_eight_bits.simd_size * 2) + two_bits = linear_eight_bits.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 2}}, + simd_size=linear_eight_bits.simd_size * 4) + + mixed_precision_cfg_list = [linear_eight_bits, four_bits, two_bits] + + return linear_eight_bits, mixed_precision_cfg_list, eight_bits_default + + +def generate_tp_model(default_config: OpQuantizationConfig, + base_config: OpQuantizationConfig, + mixed_precision_cfg_list: List[OpQuantizationConfig], + name: str) -> TargetPlatformModel: + """ + Generates TargetPlatformModel with default defined Operators Sets, based on the given base configuration and + mixed-precision configuration options list. + + Args: + default_config: A default OpQuantizationConfig to set as the TP model default configuration. + base_config: An OpQuantizationConfig to set as the TargetPlatformModel base configuration for mixed-precision purposes only. + mixed_precision_cfg_list: A list of OpQuantizationConfig to be used as the TP model mixed-precision + quantization configuration options. + name: The name of the TargetPlatformModel. + + Returns: A TargetPlatformModel object. + + """ + # Create a QuantizationConfigOptions, which defines a set + # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example). + # If the QuantizationConfigOptions contains only one configuration, + # this configuration will be used for the operation quantization: + default_configuration_options = tp.QuantizationConfigOptions([default_config]) + + # Create a QuantizationConfigOptions for quantizing constants in functional ops. + # Constant configuration is similar to the default eight bit configuration except for PoT + # quantization method for the constant.
+ # Since the constants are not named attributes of the layer, we use the default_weight_attr_config to + # define the desired quantization properties for them. + const_config = default_config.clone_and_edit( + default_weight_attr_config=default_config.default_weight_attr_config.clone_and_edit( + enable_weights_quantization=True, weights_per_channel_threshold=True, + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO)) + const_configuration_options = tp.QuantizationConfigOptions([const_config]) + + # 16 bits inputs and outputs. Currently, only defined for consts since they are used in operators that + # support 16 bit as input and output. + const_config_input16 = const_config.clone_and_edit( + supported_input_activation_n_bits=(8, 16)) + const_config_input16_output16 = const_config_input16.clone_and_edit( + activation_n_bits=16, signedness=Signedness.SIGNED) + const_configuration_options_inout16 = tp.QuantizationConfigOptions([const_config_input16_output16, + const_config_input16], + base_config=const_config_input16) + + const_config_input16_per_tensor = const_config.clone_and_edit( + supported_input_activation_n_bits=(8, 16), + default_weight_attr_config=default_config.default_weight_attr_config.clone_and_edit( + enable_weights_quantization=True, weights_per_channel_threshold=True, + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO) + ) + const_config_input16_output16_per_tensor = const_config_input16_per_tensor.clone_and_edit( + activation_n_bits=16, signedness=Signedness.SIGNED) + const_configuration_options_inout16_per_tensor = tp.QuantizationConfigOptions([const_config_input16_output16_per_tensor, + const_config_input16_per_tensor], + base_config=const_config_input16_per_tensor) + + # Create a TargetPlatformModel and set its default quantization config. 
+ # This default configuration will be used for all operations + # unless specified otherwise (see OperatorsSet, for example): + generated_tpm = tp.TargetPlatformModel(default_configuration_options, add_metadata=True, name=name) + + # To start defining the model's components (such as operator sets, and fusing patterns), + # use 'with' the TargetPlatformModel instance, and create them as below: + with generated_tpm: + # Create an OperatorsSet to represent a set of operations. + # Each OperatorsSet has a unique label. + # If a quantization configuration options is passed, these options will + # be used for operations that will be attached to this set's label. + # Otherwise, it will be a configure-less set (used in fusing): + + generated_tpm.set_simd_padding(is_simd_padding=True) + + # May suit for operations like: Dropout, Reshape, etc. + default_qco = tp.get_default_quantization_config_options() + tp.OperatorsSet("NoQuantization", + default_qco.clone_and_edit(enable_activation_quantization=False) + .clone_and_edit_weight_attribute(enable_weights_quantization=False)) + tp.OperatorsSet("QuantizationPreserving", + default_qco.clone_and_edit(enable_activation_quantization=False, + quantization_preserving=True) + .clone_and_edit_weight_attribute(enable_weights_quantization=False)) + tp.OperatorsSet("DimensionManipulationOps", + default_qco.clone_and_edit(enable_activation_quantization=False, + quantization_preserving=True, + supported_input_activation_n_bits=(8, 16)) + .clone_and_edit_weight_attribute(enable_weights_quantization=False)) + tp.OperatorsSet("MergeOps", const_configuration_options_inout16_per_tensor) + + # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects + mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list, + base_config=base_config) + + # Define operator sets that use mixed_precision_configuration_options: + conv = tp.OperatorsSet("Conv", 
mixed_precision_configuration_options) + fc = tp.OperatorsSet("FullyConnected", mixed_precision_configuration_options) + + # Define additional operations sets; some are created without quantization configuration + # options (useful for creating fusing patterns, for example): + any_relu = tp.OperatorsSet("AnyReLU") + add = tp.OperatorsSet("Add", const_configuration_options_inout16) + sub = tp.OperatorsSet("Sub", const_configuration_options_inout16) + mul = tp.OperatorsSet("Mul", const_configuration_options_inout16) + div = tp.OperatorsSet("Div", const_configuration_options) + prelu = tp.OperatorsSet("PReLU") + swish = tp.OperatorsSet("Swish") + sigmoid = tp.OperatorsSet("Sigmoid") + tanh = tp.OperatorsSet("Tanh") + + # Combine multiple operators into a single operator to avoid quantization between + # them. To do this we define fusing patterns using the OperatorsSets that were created. + # To group multiple sets with regard to fusing, an OperatorSetConcat can be created + activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) + activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid) + any_binary = tp.OperatorSetConcat(add, sub, mul, div) + + # ------------------- # + # Fusions + # ------------------- # + tp.Fusing([conv, activations_after_conv_to_fuse]) + tp.Fusing([fc, activations_after_fc_to_fuse]) + tp.Fusing([any_binary, any_relu]) + + return generated_tpm diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py new file mode 100644 index 000000000..b403d6453 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py @@ -0,0 +1,133 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import tensorflow as tf +from packaging import version + +from model_compression_toolkit.defaultdict import DefaultDict +from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_DEPTHWISE_KERNEL, \ + KERAS_KERNEL, BIAS_ATTR, BIAS + +if FOUND_SONY_CUSTOM_LAYERS: + from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess + +if version.parse(tf.__version__) >= version.parse("2.13"): + from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ + MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \ + Conv2DTranspose, Identity, Concatenate +else: + from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ + MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \ + Conv2DTranspose, Identity, Concatenate + +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import get_tp_model +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4 import __version__ as TPC_VERSION + +tp = mct.target_platform + + +def get_keras_tpc() 
-> tp.TargetPlatformCapabilities: + """ + get a Keras TargetPlatformCapabilities object with default operation sets to layers mapping. + + Returns: a Keras TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + imx500_tpc_tp_model = get_tp_model() + return generate_keras_tpc(name='imx500_tpc_keras_tpc', tp_model=imx500_tpc_tp_model) + + +def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel): + """ + Generates a TargetPlatformCapabilities object with default operation sets to layers mapping. + + Args: + name: Name of the TargetPlatformCapabilities. + tp_model: TargetPlatformModel object. + + Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + + keras_tpc = tp.TargetPlatformCapabilities(tp_model, name=name, version=TPC_VERSION) + + no_quant_list = [tf.quantization.fake_quant_with_min_max_vars, + tf.math.argmax, + tf.shape, + tf.math.equal, + tf.nn.top_k, + tf.image.combined_non_max_suppression, + tf.compat.v1.shape] + quantization_preserving = [Cropping2D, + ZeroPadding2D, + Dropout, + MaxPooling2D, + tf.split, + tf.gather, + tf.cast, + tf.unstack, + tf.compat.v1.gather, + tf.__operators__.getitem, + tf.strided_slice] + quantization_preserving_list_16bit_input = [Reshape, + tf.reshape, + Permute, + tf.transpose, + Flatten] + + if FOUND_SONY_CUSTOM_LAYERS: + no_quant_list.append(SSDPostProcess) + + with keras_tpc: + tp.OperationsSetToLayers("NoQuantization", no_quant_list) + tp.OperationsSetToLayers("QuantizationPreserving", quantization_preserving) + tp.OperationsSetToLayers("DimensionManipulationOps", quantization_preserving_list_16bit_input) + tp.OperationsSetToLayers("MergeOps", [tf.stack, tf.concat, Concatenate]) + tp.OperationsSetToLayers("Conv", + [Conv2D, + DepthwiseConv2D, + Conv2DTranspose, + tf.nn.conv2d, + tf.nn.depthwise_conv2d, + tf.nn.conv2d_transpose], + # we provide attributes mapping that maps each layer type in the operations set + # that has weights attributes with provided 
quantization config (in the tp model) to + # its framework-specific attribute name. + # note that a DefaultDict should be provided if not all the layer types in the + # operation set are provided separately in the mapping. + attr_mapping={ + KERNEL_ATTR: DefaultDict({ + DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL, + tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}) + tp.OperationsSetToLayers("FullyConnected", [Dense], + attr_mapping={KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}) + tp.OperationsSetToLayers("AnyReLU", [tf.nn.relu, + tf.nn.relu6, + tf.nn.leaky_relu, + ReLU, + LeakyReLU, + tp.LayerFilterParams(Activation, activation="relu"), + tp.LayerFilterParams(Activation, activation="leaky_relu")]) + tp.OperationsSetToLayers("Add", [tf.add, Add]) + tp.OperationsSetToLayers("Sub", [tf.subtract, Subtract]) + tp.OperationsSetToLayers("Mul", [tf.math.multiply, Multiply]) + tp.OperationsSetToLayers("Div", [tf.math.divide, tf.math.truediv]) + tp.OperationsSetToLayers("PReLU", [PReLU]) + tp.OperationsSetToLayers("Swish", [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")]) + tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")]) + tp.OperationsSetToLayers("Tanh", [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")]) + + return keras_tpc diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py new file mode 100644 index 000000000..9b4bf4e91 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py @@ -0,0 +1,113 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import operator + +import torch +from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \ + chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract +from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d +from torch.nn import Dropout, Flatten, Hardtanh, Identity +from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU +from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu + +from model_compression_toolkit.defaultdict import DefaultDict +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, PYTORCH_KERNEL, \ + BIAS +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import get_tp_model +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4 import __version__ as TPC_VERSION + +tp = mct.target_platform + + +def get_pytorch_tpc() -> tp.TargetPlatformCapabilities: + """ + Get a PyTorch TargetPlatformCapabilities object with the default operation sets to layers mapping, built on the imx500 v4 TargetPlatformModel returned by get_tp_model(). + + Returns: a PyTorch TargetPlatformCapabilities object named 'imx500_tpc_pytorch_tpc'.
+ """ + imx500_tpc_tp_model = get_tp_model() + return generate_pytorch_tpc(name='imx500_tpc_pytorch_tpc', tp_model=imx500_tpc_tp_model) + + +def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel): + """ + Generates a TargetPlatformCapabilities object on top of tp_model with the default mapping from operation-set names to PyTorch layers and functions. + Args: + name: Name given to the generated TargetPlatformCapabilities. + tp_model: TargetPlatformModel object the capabilities are built on. + Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel, versioned with TPC_VERSION. + """ + + pytorch_tpc = tp.TargetPlatformCapabilities(tp_model, + name=name, + version=TPC_VERSION) + + # we provide attributes mapping that maps each layer type in the operations set + # that has weights attributes with provided quantization config (in the tp model) to + # its framework-specific attribute name. + # note that a DefaultDict should be provided if not all the layer types in the + # operation set are provided separately in the mapping. + pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)} + + with pytorch_tpc: + tp.OperationsSetToLayers("NoQuantization", [torch.Tensor.size, + equal, + argmax, + topk]) + tp.OperationsSetToLayers("QuantizationPreserving", [Dropout, + dropout, + split, + chunk, + unbind, + gather, + MaxPool2d]) + tp.OperationsSetToLayers("DimensionManipulationOps", [Flatten, + flatten, + operator.getitem, + reshape, + unsqueeze, + squeeze, + permute, + transpose]) + tp.OperationsSetToLayers("MergeOps", + [torch.stack, torch.cat, torch.concat, torch.concatenate]) + + tp.OperationsSetToLayers("Conv", [Conv2d, ConvTranspose2d], + attr_mapping=pytorch_linear_attr_mapping) + tp.OperationsSetToLayers("FullyConnected", [Linear], + attr_mapping=pytorch_linear_attr_mapping) + tp.OperationsSetToLayers("AnyReLU", [torch.relu, + ReLU, + ReLU6, + LeakyReLU, + relu, + relu6, + leaky_relu, + tp.LayerFilterParams(Hardtanh, min_val=0), + tp.LayerFilterParams(hardtanh, min_val=0)]) + +
tp.OperationsSetToLayers("Add", [operator.add, add]) + tp.OperationsSetToLayers("Sub", [operator.sub, sub, subtract]) + tp.OperationsSetToLayers("Mul", [operator.mul, mul, multiply]) + tp.OperationsSetToLayers("Div", [operator.truediv, div, divide]) + tp.OperationsSetToLayers("PReLU", [PReLU, prelu]) + tp.OperationsSetToLayers("Swish", [SiLU, silu, Hardswish, hardswish]) + tp.OperationsSetToLayers("Sigmoid", [Sigmoid, sigmoid]) + tp.OperationsSetToLayers("Tanh", [Tanh, tanh]) + + return pytorch_tpc diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py index b36714131..7b4e86d05 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py @@ -31,7 +31,7 @@ class Activation16BitTest(BaseKerasFeatureNetworkTest): def get_tpc(self): - tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3') + tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4') # Force Mul base_config to 16bit only mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] @@ -45,6 +45,8 @@ def create_networks(self): x = tf.add(x, np.ones((3,), dtype=np.float32)) x1 = tf.subtract(x, np.ones((3,), dtype=np.float32)) x = tf.multiply(x, x1) + x = tf.reshape(x, (-1, 4, 4, 8, 3)) + x = tf.reshape(x, (-1, 16, 8, 3)) x = tf.keras.layers.Conv2D(3, 1)(x) outputs = tf.divide(x, 2*np.ones((3,), dtype=np.float32)) return keras.Model(inputs=inputs, outputs=outputs) diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 10a5a8594..ec1fffdf3 100644 --- 
a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -813,6 +813,7 @@ def test_keras_tpcs(self): TpcTest(f'{C.IMX500_TP_MODEL}.v2_lut', self).run_test() TpcTest(f'{C.IMX500_TP_MODEL}.v3', self).run_test() TpcTest(f'{C.IMX500_TP_MODEL}.v3_lut', self).run_test() + TpcTest(f'{C.IMX500_TP_MODEL}.v4', self).run_test() TpcTest(f'{C.TFLITE_TP_MODEL}.v1', self).run_test() TpcTest(f'{C.QNNPACK_TP_MODEL}.v1', self).run_test() diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py index 44b660092..bf1f3bf74 100644 --- a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py @@ -19,7 +19,7 @@ from model_compression_toolkit.constants import PYTORCH from model_compression_toolkit.core import MixedPrecisionQuantizationConfig from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL -from model_compression_toolkit.core.pytorch.constants import CPU +from model_compression_toolkit.core.pytorch.utils import get_working_device from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest @@ -43,6 +43,8 @@ def forward(self, x): x1 = torch.add(x, self.add_const) x = torch.sub(x, self.sub_const) x = torch.mul(x, x1) + x = torch.reshape(x, (-1, 3, 2*(1+int(self.use_concat)), 4, 8)) + x = torch.reshape(x, (-1, 3, 8*(1+int(self.use_concat)), 8)) x = self.conv(x) x = torch.divide(x, self.div_const) return x @@ -51,7 +53,7 @@ def forward(self, x): class Activation16BitTest(BasePytorchFeatureNetworkTest): def get_tpc(self): - tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3') + tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v4') mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) 
mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config @@ -62,10 +64,9 @@ def create_networks(self): return Activation16BitNet() def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - x = torch.from_numpy(input_x[0].astype('float32')) - out_f = float_model(x) - quantized_model = quantized_model.to(CPU) - out_q = quantized_model(x.to(CPU)) + x = torch.from_numpy(input_x[0].astype('float32')).to(get_working_device()) + out_f = float_model.to(get_working_device())(x) + out_q = quantized_model(x) self.unit_test.assertTrue(out_f.shape == out_q.shape, "Output shape mismatch.") mul1_act_quant = quantized_model.mul_activation_holder_quantizer diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 804958ee5..d489825eb 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -682,6 +682,7 @@ def test_torch_tpcs(self): TpcTest(f'{C.IMX500_TP_MODEL}.v2_lut', self).run_test() TpcTest(f'{C.IMX500_TP_MODEL}.v3', self).run_test() TpcTest(f'{C.IMX500_TP_MODEL}.v3_lut', self).run_test() + TpcTest(f'{C.IMX500_TP_MODEL}.v4', self).run_test() TpcTest(f'{C.TFLITE_TP_MODEL}.v1', self).run_test() TpcTest(f'{C.QNNPACK_TP_MODEL}.v1', self).run_test()