Support and fix merge layers (e.g. concat) (#1215)
* Support Keras Merge layers (Add, Concatenate, etc.) in model builder & quantization.
* Fix for TF Merge functions (e.g. tf.concat).
* Add PyTorch tests and fix the model reader and builder to support quantized merge ops (e.g. torch.cat).
elad-c authored Sep 17, 2024
1 parent 4dac6a3 commit 2fd976c
Showing 13 changed files with 201 additions and 43 deletions.
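For orientation, a minimal sketch (mine, not part of this commit) of the two patterns these changes target: Keras merge layers and TF merge functions, both called with a list of inputs that may contain constants.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Merge layers take their inputs as a single list; constants inside that list
# must be picked up as positional weights by the quantization pipeline.
inputs = layers.Input(shape=(32, 32, 3))
const = np.random.random((1, 32, 32, 3)).astype(np.float32)
x = layers.Concatenate()([inputs, const])  # Keras _Merge subclass: list input
x = tf.concat([x, x], axis=-1)             # TF merge function: list input
model = tf.keras.Model(inputs=inputs, outputs=x)
```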
3 changes: 3 additions & 0 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -40,6 +40,7 @@ def __init__(self,
layer_class: type,
reuse: bool = False,
reuse_group: str = None,
inputs_as_list: bool = False,
quantization_attr: Dict[str, Any] = None,
has_activation: bool = True,
is_custom: bool = False
@@ -58,6 +59,7 @@ def __init__(self,
layer_class: Class path of the layer this node represents.
reuse: Whether this node was duplicated and represents a reused layer.
reuse_group: Name of group of nodes from the same reused layer.
            inputs_as_list: Whether the node's input tensors should be passed as a single list when calling the layer.
quantization_attr: Attributes the node holds regarding how it should be quantized.
has_activation: Whether the node has activations that we might want to quantize.
            is_custom: Whether the node is a custom layer or not.
@@ -71,6 +73,7 @@ def __init__(self,
self.layer_class = layer_class
self.reuse = reuse
self.reuse_group = reuse_group
self.inputs_as_list = inputs_as_list
self.final_weights_quantization_cfg = None
self.final_activation_quantization_cfg = None
self.candidates_quantization_cfg = None
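To illustrate the distinction the new `inputs_as_list` flag records (a sketch of mine, not code from this commit): merge layers are called with one list argument, while most layers are called with a tensor per argument.

```python
import tensorflow as tf
from tensorflow.keras import layers

a, b = tf.ones((1, 4)), tf.ones((1, 4))

out_merge = layers.Add()([a, b])  # _Merge layer: single list argument -> inputs_as_list=True
out_dense = layers.Dense(2)(a)    # regular layer: single tensor argument -> inputs_as_list=False
```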
@@ -55,13 +55,13 @@ def __init__(self,
layer_class,
reuse,
reuse_group,
inputs_as_list,
quantization_attr,
has_activation=has_activation)

self.op_call_kwargs = op_call_kwargs
self.op_call_args = list(op_call_args)
self.functional_op = functional_op
-        self.inputs_as_list = inputs_as_list
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

@property
@@ -308,7 +308,7 @@ def _run_operation(self,
else:
                # If the operator expects a single input tensor, it cannot be passed a list,
                # since a tensor input is expected to have a dtype field.
-                if len(input_tensors) == 1:
                if len(input_tensors) == 1 and not n.inputs_as_list:
input_tensors = input_tensors[0]
out_tensors_of_n_float = op_func(input_tensors)

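A quick sketch (mine) of why the added `not n.inputs_as_list` guard matters: list-taking ops can legitimately receive a one-element list, which must not be unpacked into a bare tensor before the call.

```python
import tensorflow as tf

t = tf.ones((2, 3))
# Both ops accept a single-element list; unpacking it to a bare tensor
# would break the call signature they expect.
print(tf.concat([t], axis=0).shape)  # (2, 3)
print(tf.add_n([t]).shape)           # (2, 3)
```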
24 changes: 23 additions & 1 deletion model_compression_toolkit/core/keras/reader/node_builder.py
@@ -30,10 +30,12 @@
from keras.src.layers.core import TFOpLambda, SlicingOpLambda
from keras.src.engine.keras_tensor import KerasTensor
from keras.src.engine.node import Node as KerasNode
from keras.src.layers.merging.base_merge import _Merge
else:
from keras.layers.core import TFOpLambda, SlicingOpLambda
from keras.engine.keras_tensor import KerasTensor
from keras.engine.node import Node as KerasNode
from keras.layers.merging.base_merge import _Merge

from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
Expand Down Expand Up @@ -287,6 +289,7 @@ def build_node(node: KerasNode,
for i, arg in enumerate(op_call_args[0]):
if is_const(arg):
weights.update({i: to_numpy(arg, is_single_tensor=True)})
inputs_as_list = __is_node_inputs_a_list(op_call_args, keras_layer)

node = BaseNode(node_name,
layer_config,
@@ -296,6 +299,7 @@
layer_class,
is_reused,
reuse_group,
inputs_as_list,
is_custom=is_keras_custom_layer(layer_class))

node_name_to_node[node_name] = node
@@ -316,6 +320,24 @@ def __is_functional_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
"""

return (keras_layer.symbol in
-            [TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol,TFOpLambda(tf.add_n).symbol] and
            [TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol, TFOpLambda(tf.add_n).symbol] and
len(op_call_args) > 0 and
isinstance(op_call_args[0], list))


def __is_node_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
"""
    Check whether the layer's input tensors should be passed as a list. This is relevant
    only for layers that inherit from _Merge, such as Concatenate and Add.
Args:
op_call_args: Arguments list to check.
keras_layer: Keras layer.
Returns:
Whether the input tensors should be passed as a list or not.
"""

return (isinstance(keras_layer, _Merge) and
len(op_call_args) > 0 and
isinstance(op_call_args[0], (list, tuple)))
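For reference, the new predicate boils down to an isinstance check against Keras' merge base class. A small sketch (the `_Merge` import path depends on the installed Keras version, as the branch at the top of this file shows):

```python
from keras.layers.merging.base_merge import _Merge  # older-Keras path; see version branch above
from tensorflow.keras import layers

print(isinstance(layers.Concatenate(), _Merge))  # True
print(isinstance(layers.Add(), _Merge))          # True
print(isinstance(layers.Dense(4), _Merge))       # False
```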
@@ -139,7 +139,11 @@ def _run_operation(n: BaseNode,
_tensor_input_allocs = None

if isinstance(n, FunctionalNode) and n.inputs_as_list:
-        out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
        if isinstance(op_func, PytorchQuantizationWrapper):
            # In wrapped nodes, the op call args & kwargs are already stored in the
            # PytorchQuantizationWrapper, so only the input tensors are passed.
            out_tensors_of_n_float = op_func(*input_tensors)
        else:
            out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
tensor_input_allocs=_tensor_input_allocs)
17 changes: 13 additions & 4 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -232,10 +232,19 @@ def nodes_builder(model: GraphModule,

# Add constants to weights dictionary.
if node.op != PLACEHOLDER:
-        for i, input_node in enumerate(node.all_input_nodes):
-            if input_node in consts_dict:
-                used_consts.add(input_node)
-                weights.update({i: consts_dict[input_node]})
        if len(node.args) and isinstance(node.args[0], (list, tuple)):
            # Handle weights in nodes with a list input, especially when the same tensor appears
            # more than once in the input list (e.g. torch.concat([const1, x, const2, x, const3], 1)).
            for input_node in node.all_input_nodes:
                for i, input_arg in enumerate(node.args[0]):
                    if input_node is input_arg and input_node in consts_dict:
                        used_consts.add(input_node)
                        weights.update({i: consts_dict[input_node]})
        else:
            for i, input_node in enumerate(node.all_input_nodes):
                if input_node in consts_dict:
                    used_consts.add(input_node)
                    weights.update({i: consts_dict[input_node]})

# Extract input and output shapes of the node.
input_shape, output_shape = _extract_input_and_output_shapes(node)
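The duplicate-input case the new branch handles is visible directly in torch.fx: `node.all_input_nodes` de-duplicates inputs, so positional weight indices must be read from `node.args[0]` instead. A standalone sketch (mine, not from the commit):

```python
import torch
from torch import nn
from torch.fx import symbolic_trace

class ConcatModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('c1', torch.randn(1, 3, 8, 8))
        self.register_buffer('c2', torch.randn(1, 3, 8, 8))

    def forward(self, x):
        # x appears twice in the list input of torch.cat.
        return torch.cat([self.c1, x, self.c2, x], 1)

gm = symbolic_trace(ConcatModel())
cat_node = [n for n in gm.graph.nodes if n.op == 'call_function'][0]
print(cat_node.all_input_nodes)  # [c1, x, c2] - de-duplicated
print(cat_node.args[0])          # [c1, x, c2, x] - call order, duplicates kept
```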
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

-from typing import Tuple, Callable
from typing import Tuple, Callable, Union
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.verify_packages import FOUND_TF
@@ -25,10 +25,12 @@
import tensorflow as tf
from tensorflow.keras.layers import Layer
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from mct_quantizers import KerasQuantizationWrapper
from mct_quantizers import KerasActivationQuantizationHolder
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS

-def _get_wrapper(node: common.BaseNode,
def _get_wrapper(node: Union[common.BaseNode, FunctionalNode],
layer: Layer,
fw_impl=None) -> Layer:
"""
@@ -45,9 +47,16 @@ def _get_wrapper(node: common.BaseNode,
# for positional weights we need to extract the weight's value.
weights_values = {attr: node.get_weights_by_keys(attr)
for attr in weights_quantizers if isinstance(attr, int)}
        # When wrapping functional nodes, the op call args/kwargs must be set in the wrapper,
        # because they are used in the wrapper's call method.
        func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
                            OP_CALL_KWARGS: node.op_call_kwargs
                            } if isinstance(node, FunctionalNode) else {}
        return KerasQuantizationWrapper(layer,
                                        weights_quantizers,
-                                        weights_values)
                                        weights_values,
                                        is_inputs_as_list=node.inputs_as_list,
                                        **func_node_kwargs)
return layer


@@ -24,7 +24,9 @@
if FOUND_TORCH:
import torch
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode


def fully_quantized_wrapper(node: common.BaseNode,
@@ -46,7 +48,14 @@ def fully_quantized_wrapper(node: common.BaseNode,
# for positional weights we need to extract the weight's value.
weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
for attr in weight_quantizers if isinstance(attr, int)}
-    return PytorchQuantizationWrapper(module, weight_quantizers, weights_values)
    # When wrapping functional nodes, the op call args/kwargs must be set in the wrapper,
    # because they are used in the wrapper's call method.
    func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
                        OP_CALL_KWARGS: node.op_call_kwargs
                        } if isinstance(node, FunctionalNode) else {}
    return PytorchQuantizationWrapper(module, weight_quantizers, weights_values,
                                      is_inputs_as_list=node.inputs_as_list,
                                      **func_node_kwargs)
return module


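Loosely, the wrapper then replays the original functional call around the (possibly quantized) inputs. A hypothetical construction sketch, assuming the wrapper accepts a bare function and empty weight dicts as placeholders (the real quantizer/value dicts come from the node, as in the diff above):

```python
import torch
from mct_quantizers import PytorchQuantizationWrapper
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS

# The concat dim travels in OP_CALL_ARGS; is_inputs_as_list makes the wrapper
# re-pack its positional inputs into one list before calling torch.cat.
func_node_kwargs = {OP_CALL_ARGS: [1], OP_CALL_KWARGS: {}}
wrapped = PytorchQuantizationWrapper(torch.cat,
                                     {},  # weight quantizers (placeholder)
                                     {},  # positional weight values (placeholder)
                                     is_inputs_as_list=True,
                                     **func_node_kwargs)
x, y = torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)
out = wrapped(x, y)  # presumably replayed as torch.cat([x, y], 1)
```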
@@ -18,13 +18,14 @@

import model_compression_toolkit as mct
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tp_model import generate_tp_model, \
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import generate_tp_model, \
    get_op_quantization_configs
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import generate_keras_tpc
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import generate_keras_tpc
from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, DEFAULT_WEIGHT_ATTR_CONFIG, \
generate_test_tp_model, generate_custom_test_tp_model
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
from tests.common_tests.helpers.tensors_compare import cosine_similarity
from tests.keras_tests.utils import get_layers_from_model_by_type
from mct_quantizers import KerasQuantizationWrapper, QuantizationMethod

from model_compression_toolkit.constants import TENSORFLOW
@@ -35,6 +36,39 @@
tp = mct.target_platform


def create_const_quant_tpc(qmethod):
name = "const_quant_tpc"
base_cfg, mp_op_cfg_list, default_cfg = get_op_quantization_configs()
base_tp_model = generate_tp_model(default_config=default_cfg,
base_config=base_cfg,
mixed_precision_cfg_list=mp_op_cfg_list,
name=name)

const_config = default_cfg.clone_and_edit(
default_weight_attr_config=default_cfg.default_weight_attr_config.clone_and_edit(
enable_weights_quantization=True, weights_per_channel_threshold=True,
weights_n_bits=16, weights_quantization_method=qmethod))
const_configuration_options = tp.QuantizationConfigOptions([const_config])
const_merge_config = default_cfg.clone_and_edit(
default_weight_attr_config=default_cfg.default_weight_attr_config.clone_and_edit(
weights_per_channel_threshold=False))
const_merge_configuration_options = tp.QuantizationConfigOptions([const_merge_config])

operator_sets_dict = {}
operator_sets_dict["Add"] = const_configuration_options
operator_sets_dict["Sub"] = const_configuration_options
operator_sets_dict["Mul"] = const_configuration_options
operator_sets_dict["Div"] = const_configuration_options
operator_sets_dict["MergeOps"] = const_merge_configuration_options

tp_model = generate_custom_test_tp_model(name=name,
base_cfg=base_cfg,
base_tp_model=base_tp_model,
operator_sets_dict=operator_sets_dict)

return generate_keras_tpc(name="const_quant_tpc", tp_model=tp_model)


class ConstQuantizationTest(BaseKerasFeatureNetworkTest):

def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwargs=False,
@@ -58,31 +92,7 @@ def get_quantization_config(self):
return mct.core.QuantizationConfig(weights_error_method=self.error_method)

def get_tpc(self):
name = "const_quant_tpc"
base_cfg, mp_op_cfg_list, default_cfg = get_op_quantization_configs()
base_tp_model = generate_tp_model(default_config=default_cfg,
base_config=base_cfg,
mixed_precision_cfg_list=mp_op_cfg_list,
name=name)

const_config = default_cfg.clone_and_edit(
default_weight_attr_config=default_cfg.default_weight_attr_config.clone_and_edit(
enable_weights_quantization=True, weights_per_channel_threshold=True,
weights_quantization_method=self.qmethod))
const_configuration_options = tp.QuantizationConfigOptions([const_config])

operator_sets_dict = {}
operator_sets_dict["Add"] = const_configuration_options
operator_sets_dict["Sub"] = const_configuration_options
operator_sets_dict["Mul"] = const_configuration_options
operator_sets_dict["Div"] = const_configuration_options

tp_model = generate_custom_test_tp_model(name=name,
base_cfg=base_cfg,
base_tp_model=base_tp_model,
operator_sets_dict=operator_sets_dict)

return generate_keras_tpc(name="const_quant_tpc", tp_model=tp_model)
return create_const_quant_tpc(self.qmethod)

def create_networks(self):
inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
@@ -159,3 +169,37 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
msg='TFOpLambda should be quantized')
self.unit_test.assertTrue((quantized_model.layers[5].weight_values[1] == self.const).all(),
msg='Constant value should not change')


class ConstQuantizationMultiInputTest(BaseKerasFeatureNetworkTest):

def __init__(self, unit_test, input_shape=(32, 32, 16)):
super(ConstQuantizationMultiInputTest, self).__init__(unit_test=unit_test, input_shape=input_shape)

def get_tpc(self):
return mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v4")

def create_networks(self):
as_const = lambda v: np.random.random(v.shape.as_list()).astype(np.float32)
inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
x = layers.Concatenate()([inputs, np.random.random((1, 32, 32, 3)),
inputs, np.random.random((1, 32, 32, 3))])
x1 = layers.Add()([np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
x2 = layers.Multiply()([x, np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
x3 = tf.add_n([x1, as_const(x), x2])
x1 = tf.reshape(tf.stack([as_const(x1), x1, as_const(x1)], axis=1), (-1, 3*x1.shape[1], x1.shape[2], x1.shape[3]))
x = tf.concat([x1, x2, as_const(x3), x3], 1)
return tf.keras.models.Model(inputs=inputs, outputs=x)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
cs = cosine_similarity(y, y_hat)
self.unit_test.assertTrue(np.isclose(cs, 1, atol=1e-2), msg=f'fail cosine similarity check:{cs}')

# check quantization layers:
for op in [tf.concat, tf.stack, layers.Add, layers.Multiply, layers.Concatenate]:
for qlayer in get_layers_from_model_by_type(quantized_model, op):
self.unit_test.assertTrue(isinstance(qlayer, KerasQuantizationWrapper),
msg=f"{op} should be quantized.")
@@ -153,12 +153,14 @@ def get_tpc(self):
return generate_keras_tpc(name="const_representation_test", tp_model=tp)

def create_networks(self):
as_const = lambda v: np.random.random(v.shape.as_list()).astype(np.float32)
inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
x = layers.Concatenate()([inputs, np.random.random((1, 32, 32, 3)), inputs, np.random.random((1, 32, 32, 3))])
x1 = layers.Add()([np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
x2 = layers.Multiply()([x, np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
-        x3 = tf.add_n([x1, np.random.random(x.shape.as_list()).astype(np.float32), x2])
-        x = tf.concat([x1, x2, np.random.random(x3.shape.as_list()).astype(np.float32), x3], 1)
        x3 = tf.add_n([x1, as_const(x), x2])
        x1 = tf.reshape(tf.stack([as_const(x1), x1, as_const(x1)], axis=1), (-1, 3*x1.shape[1], x1.shape[2], x1.shape[3]))
        x = tf.concat([x1, x2, as_const(x3), x3], 1)
return tf.keras.models.Model(inputs=inputs, outputs=x)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
@@ -147,7 +147,7 @@
ConstRepresentationMultiInputTest, ConstRepresentationMatMulTest, ConstRepresentationListTypeArgsTest
from tests.keras_tests.feature_networks_tests.feature_networks.concatination_threshold_update import ConcatThresholdtest
from tests.keras_tests.feature_networks_tests.feature_networks.const_quantization_test import ConstQuantizationTest, \
-    AdvancedConstQuantizationTest
    AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest
from tests.keras_tests.feature_networks_tests.feature_networks.activation_16bit_test import Activation16BitTest, \
Activation16BitMixedPrecisionTest
from tests.keras_tests.feature_networks_tests.feature_networks.sigmoid_mul_substitution_test import SigMulSubstitutionTest
@@ -588,6 +588,7 @@ def test_const_quantization(self):
ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod, error_method=error_method).run_test()

AdvancedConstQuantizationTest(self).run_test()
ConstQuantizationMultiInputTest(self).run_test()

def test_const_representation(self):
c = (np.ones((16,)) + np.random.random((16,))).astype(np.float32)