Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TPC.v4 with quantization preservation. #1214

Merged
merged 4 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,72 @@ def set_quantization_configuration_to_graph(graph: Graph,
return graph


def filter_node_qco_by_graph(node: BaseNode,
tpc: TargetPlatformCapabilities,
graph: Graph,
node_qc_options: QuantizationConfigOptions
) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
elad-c marked this conversation as resolved.
Show resolved Hide resolved
"""
Filter quantization config options that don't match the graph.
A node may have several quantization config options with 'activation_n_bits' values, and
the next nodes in the graph may support different bit-width as input activation. This function
filters out quantization config that don't comply to these attributes.

Args:
node: Node for filtering.
tpc: TPC to extract the QuantizationConfigOptions for the next nodes.
graph: Graph object.
node_qc_options: Node's QuantizationConfigOptions.

Returns:
A base config (OpQuantizationConfig) and a config options list (list of OpQuantizationConfig)
that are compatible with next nodes supported input bit-widths.

"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list

# Build next_nodes list by appending to the node's next nodes list all nodes that are quantization preserving.
_next_nodes = graph.get_next_nodes(node)
next_nodes = []
while len(_next_nodes):
n = _next_nodes.pop(0)
qco = n.get_qco(tpc)
qp = [qc.quantization_preserving for qc in qco.quantization_config_list]
if not all(qp) and any(qp):
Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
if qp[0]:
_next_nodes.extend(graph.get_next_nodes(n))
next_nodes.append(n)

if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
if len(_node_qc_options) == 0:
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
Logger.warning(f"Node {node} base quantization config changed to match Graph and TPC configuration.\nCause: {node} -> {next_nodes}.")
else:
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.") # pragma: no cover

return _base_config, _node_qc_options


def set_quantization_configs_to_node(node: BaseNode,
graph: Graph,
quant_config: QuantizationConfig,
Expand All @@ -99,7 +165,7 @@ def set_quantization_configs_to_node(node: BaseNode,
manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
"""
node_qc_options = node.get_qco(tpc)
base_config, node_qc_options_list = node.filter_node_qco_by_graph(tpc, graph.get_next_nodes(node), node_qc_options)
base_config, node_qc_options_list = filter_node_qco_by_graph(node, tpc, graph, node_qc_options)

# If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
# and update base_config accordingly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def get_tpc_dict_by_fw(fw_name):
get_keras_tpc as get_keras_tpc_v3
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import \
get_keras_tpc as get_keras_tpc_v3_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import \
get_keras_tpc as get_keras_tpc_v4

# Keras: TPC versioning
tpc_models_dict = {'v1': get_keras_tpc_v1,
Expand All @@ -51,6 +53,7 @@ def get_tpc_dict_by_fw(fw_name):
'v2_lut': get_keras_tpc_v2_lut,
'v3': get_keras_tpc_v3,
'v3_lut': get_keras_tpc_v3_lut,
'v4': get_keras_tpc_v4,
LATEST: get_keras_tpc_latest}
elif fw_name == PYTORCH:
###############################
Expand All @@ -73,6 +76,8 @@ def get_tpc_dict_by_fw(fw_name):
get_pytorch_tpc as get_pytorch_tpc_v3
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v3_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v4

# Pytorch: TPC versioning
tpc_models_dict = {'v1': get_pytorch_tpc_v1,
Expand All @@ -82,6 +87,7 @@ def get_tpc_dict_by_fw(fw_name):
'v2_lut': get_pytorch_tpc_v2_lut,
'v3': get_pytorch_tpc_v3,
'v3_lut': get_pytorch_tpc_v3_lut,
'v4': get_pytorch_tpc_v4,
LATEST: get_pytorch_tpc_latest}
if tpc_models_dict is not None:
return tpc_models_dict
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

__version__ = 'v4'
Loading
Loading