Skip to content

Commit

Permalink
**Refactor Target Platform Capabilities - Phase 2**
Browse files Browse the repository at this point in the history
- Convert all schema classes to immutable dataclasses, replacing existing methods with equivalent dataclass methods (e.g., `replace`).
- Ensure all schema classes are strictly immutable to enhance reliability and maintain consistency.
- Update target platform model versions to align with the new class structure.
- Refactor tests to support and validate the updated class types and functionality.
  • Loading branch information
liord committed Dec 10, 2024
1 parent e538597 commit 951568c
Show file tree
Hide file tree
Showing 22 changed files with 525 additions and 534 deletions.
889 changes: 440 additions & 449 deletions model_compression_toolkit/target_platform_capabilities/schema/v1.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def generate_tp_model(default_config: OpQuantizationConfig,
tpc_patch_version=0,
tpc_platform_type=IMX500_TP_MODEL,
name=name,
add_metadata=False)
add_metadata=False,
is_simd_padding=True)

# To start defining the model's components (such as operator sets, and fusing patterns),
# use 'with' the TargetPlatformModel instance, and create them as below:
Expand All @@ -175,8 +176,6 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# be used for operations that will be attached to this set's label.
# Otherwise, it will be a configure-less set (used in fusing):

generated_tpc.set_simd_padding(is_simd_padding=True)

# May suit for operations like: Dropout, Reshape, etc.
default_qco = tp.get_default_quantization_config_options()
schema.OperatorsSet("NoQuantization",
Expand Down Expand Up @@ -206,9 +205,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])

# ------------------- #
# Fusions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])

# ------------------- #
# Fusions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])

# ------------------- #
# Fusions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def generate_tp_model(default_config: OpQuantizationConfig,
tpc_patch_version=0,
tpc_platform_type=IMX500_TP_MODEL,
add_metadata=True,
name=name)
name=name,
is_simd_padding=True)

# To start defining the model's components (such as operator sets, and fusing patterns),
# use 'with' the TargetPlatformModel instance, and create them as below:
Expand All @@ -177,8 +178,6 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# be used for operations that will be attached to this set's label.
# Otherwise, it will be a configure-less set (used in fusing):

generated_tpm.set_simd_padding(is_simd_padding=True)

# May suit for operations like: Dropout, Reshape, etc.
default_qco = tp.get_default_quantization_config_options()
schema.OperatorsSet("NoQuantization",
Expand Down Expand Up @@ -208,9 +207,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])

# ------------------- #
# Fusions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])

# ------------------- #
# Fusions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def generate_tp_model(default_config: OpQuantizationConfig,
tpc_patch_version=0,
tpc_platform_type=IMX500_TP_MODEL,
add_metadata=True,
name=name)
name=name,
is_simd_padding=True)

# To start defining the model's components (such as operator sets, and fusing patterns),
# use 'with' the TargetPlatformModel instance, and create them as below:
Expand All @@ -198,8 +199,6 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# be used for operations that will be attached to this set's label.
# Otherwise, it will be a configure-less set (used in fusing):

generated_tpm.set_simd_padding(is_simd_padding=True)

# May suit for operations like: Dropout, Reshape, etc.
default_qco = tp.get_default_quantization_config_options()
schema.OperatorsSet("NoQuantization",
Expand Down Expand Up @@ -231,9 +230,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])

# ------------------- #
# Fusions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])

# ------------------- #
# Fusions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import List, Tuple

import model_compression_toolkit as mct
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS, \
IMX500_TP_MODEL
Expand Down Expand Up @@ -235,7 +235,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
tpc_minor_version=4,
tpc_patch_version=0,
tpc_platform_type=IMX500_TP_MODEL,
add_metadata=True, name=name)
add_metadata=True,
name=name,
is_simd_padding=True)

# To start defining the model's components (such as operator sets, and fusing patterns),
# use 'with' the TargetPlatformModel instance, and create them as below:
Expand All @@ -246,8 +248,6 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# be used for operations that will be attached to this set's label.
# Otherwise, it will be a configure-less set (used in fusing):

generated_tpm.set_simd_padding(is_simd_padding=True)

# May suit for operations like: Dropout, Reshape, etc.
default_qco = tp.get_default_quantization_config_options()
schema.OperatorsSet(OPSET_NO_QUANTIZATION,
Expand Down Expand Up @@ -294,11 +294,11 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid,
tanh, gelu, hardswish, hardsigmoid)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid, tanh, gelu,
hardswish, hardsigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid,
tanh, gelu, hardswish, hardsigmoid])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid, tanh, gelu,
hardswish, hardsigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])

# ------------------- #
# Fusions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,11 @@ def generate_tp_model(default_config: OpQuantizationConfig,
fixed_zero_point=-128, fixed_scale=1 / 256))

conv2d = schema.OperatorsSet("Conv2d")
kernel = schema.OperatorSetConcat(conv2d, fc)
kernel = schema.OperatorSetConcat([conv2d, fc])

relu = schema.OperatorsSet("Relu")
elu = schema.OperatorsSet("Elu")
activations_to_fuse = schema.OperatorSetConcat(relu, elu)
activations_to_fuse = schema.OperatorSetConcat([relu, elu])

batch_norm = schema.OperatorsSet("BatchNorm")
bias_add = schema.OperatorsSet("BiasAdd")
Expand Down
14 changes: 6 additions & 8 deletions tests/common_tests/test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_immutable_tp(self):
with model:
schema.OperatorsSet("opset")
model.operator_set = []
self.assertEqual('Immutable class. Can\'t edit attributes.', str(e.exception))
self.assertEqual("cannot assign to field 'operator_set'", str(e.exception))

def test_default_options_more_than_single_qc(self):
test_qco = schema.QuantizationConfigOptions([TEST_QC, TEST_QC], base_config=TEST_QC)
Expand All @@ -76,8 +76,6 @@ def test_tp_model_show(self):
with tpm:
a = schema.OperatorsSet("opA")

tpm.show()


class OpsetTest(unittest.TestCase):

Expand Down Expand Up @@ -114,7 +112,7 @@ def test_opset_concat(self):
b = schema.OperatorsSet('opset_B',
get_default_quantization_config_options().clone_and_edit(activation_n_bits=2))
schema.OperatorsSet('opset_C') # Just add it without using it in concat
schema.OperatorSetConcat(a, b)
schema.OperatorSetConcat([a, b])
self.assertEqual(len(hm.operator_set), 4)
self.assertTrue(hm.is_opset_in_model("opset_A_opset_B"))
self.assertTrue(hm.get_config_options_by_operators_set('opset_A_opset_B') is None)
Expand All @@ -136,14 +134,14 @@ def test_non_unique_opset(self):
class QCOptionsTest(unittest.TestCase):

def test_empty_qc_options(self):
with self.assertRaises(AssertionError) as e:
with self.assertRaises(Exception) as e:
schema.QuantizationConfigOptions([])
self.assertEqual(
"'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.",
str(e.exception))

def test_list_of_no_qc(self):
with self.assertRaises(AssertionError) as e:
with self.assertRaises(Exception) as e:
schema.QuantizationConfigOptions([TEST_QC, 3])
self.assertEqual(
'Each option must be an instance of \'OpQuantizationConfig\', but found an object of type: <class \'int\'>.',
Expand Down Expand Up @@ -186,7 +184,7 @@ def test_fusing_single_opset(self):
add = schema.OperatorsSet("add")
with self.assertRaises(Exception) as e:
schema.Fusing([add])
self.assertEqual('Fusing can not be created for a single operators group', str(e.exception))
self.assertEqual('Fusing cannot be created for a single operator.', str(e.exception))

def test_fusing_contains(self):
hm = schema.TargetPlatformModel(
Expand Down Expand Up @@ -220,7 +218,7 @@ def test_fusing_contains_with_opset_concat(self):
conv = schema.OperatorsSet("conv")
add = schema.OperatorsSet("add")
tanh = schema.OperatorsSet("tanh")
add_tanh = schema.OperatorSetConcat(add, tanh)
add_tanh = schema.OperatorSetConcat([add, tanh])
schema.Fusing([conv, add])
schema.Fusing([conv, add_tanh])
schema.Fusing([conv, add, tanh])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
swish = schema.OperatorsSet("Swish")
sigmoid = schema.OperatorsSet("Sigmoid")
tanh = schema.OperatorsSet("Tanh")
activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid)
any_binary = schema.OperatorSetConcat(add, sub, mul, div)
activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
any_binary = schema.OperatorSetConcat([add, sub, mul, div])
schema.Fusing([conv, activations_after_conv_to_fuse])
schema.Fusing([fc, activations_after_fc_to_fuse])
schema.Fusing([any_binary, any_relu])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import replace

import numpy as np
import tensorflow as tf

Expand All @@ -34,8 +36,8 @@ def get_tpc(self):
tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4')
# Force Mul base_config to 16bit only
mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config)
return tpc

def create_networks(self):
Expand Down Expand Up @@ -67,8 +69,8 @@ class Activation16BitMixedPrecisionTest(Activation16BitTest):
def get_tpc(self):
tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3')
mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config)
mul_op_set.qc_options.quantization_config_list.extend(
[mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4),
mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)])
Expand Down
Loading

0 comments on commit 951568c

Please sign in to comment.