From 7cd769074beada424291ba917b1f92e8dc4201b2 Mon Sep 17 00:00:00 2001 From: Ariel Lapid <57916763+lapid92@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:41:17 +0300 Subject: [PATCH] Change TPC dictionary by fw into function (#1152) * Change tpc fw dicts into func --- .../get_target_platform_capabilities.py | 23 ++-- .../target_platform_capabilities.py | 120 ++++++++++-------- .../target_platform_capabilities.py | 64 +++++----- .../target_platform_capabilities.py | 63 +++++---- 4 files changed, 148 insertions(+), 122 deletions(-) diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py index 1aff6c6bb..289d80e40 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py @@ -16,17 +16,17 @@ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.target_platform_capabilities import \ - tpc_dict as imx500_tpc_dict + get_tpc_dict_by_fw as get_imx500_tpc from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.target_platform_capabilities import \ - tpc_dict as tflite_tpc_dict + get_tpc_dict_by_fw as get_tflite_tpc from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.target_platform_capabilities import \ - tpc_dict as qnnpack_tpc_dict + get_tpc_dict_by_fw as get_qnnpack_tpc from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL, LATEST -tpc_dict = {DEFAULT_TP_MODEL: imx500_tpc_dict, - IMX500_TP_MODEL: imx500_tpc_dict, - TFLITE_TP_MODEL: tflite_tpc_dict, - QNNPACK_TP_MODEL: qnnpack_tpc_dict} +tpc_dict = {DEFAULT_TP_MODEL: get_imx500_tpc, + IMX500_TP_MODEL: get_imx500_tpc, + TFLITE_TP_MODEL: get_tflite_tpc, + QNNPACK_TP_MODEL: get_qnnpack_tpc} def get_target_platform_capabilities(fw_name: str, @@ -47,13 +47,10 @@ def get_target_platform_capabilities(fw_name: str, """ assert target_platform_name in tpc_dict, f'Target platform {target_platform_name} is not defined!' fw_tpc = tpc_dict.get(target_platform_name) - assert fw_name in fw_tpc, f'Framework {fw_name} is not supported in {target_platform_name}. Please make sure the relevant ' \ - f'packages are installed when using MCT for optimizing a {fw_name} model. ' \ - f'For Tensorflow, please install tensorflow. ' \ - f'For PyTorch, please install torch.' - tpc_versions = fw_tpc.get(fw_name) + tpc_versions = fw_tpc(fw_name) if target_platform_version is None: target_platform_version = LATEST else: - assert target_platform_version in tpc_versions, f'TPC version {target_platform_version} is not supported for framework {fw_name}.' + assert target_platform_version in tpc_versions, (f'TPC version {target_platform_version} is not supported for ' + f'framework {fw_name}.') return tpc_versions[target_platform_version]() diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py index 00cc06403..e48b4548c 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py @@ -12,61 +12,74 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH from model_compression_toolkit.target_platform_capabilities.constants import LATEST -############################### -# Build Tensorflow TPC models -############################### -keras_tpc_models_dict = None -if FOUND_TF: - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_keras_tpc_latest - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1 - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v1_lut - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import get_keras_tpc as get_keras_tpc_v1_pot - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_v2 - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v2_lut - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import get_keras_tpc as get_keras_tpc_v3 - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v3_lut - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_v4 - # Keras: TPC versioning - keras_tpc_models_dict = {'v1': get_keras_tpc_v1, - 'v1_lut': get_keras_tpc_v1_lut, - 'v1_pot': get_keras_tpc_v1_pot, - 'v2': get_keras_tpc_v2, - 'v2_lut': get_keras_tpc_v2_lut, - 'v3': get_keras_tpc_v3, - 'v3_lut': get_keras_tpc_v3_lut, - 'v4': get_keras_tpc_v4, - LATEST: get_keras_tpc_latest} +def get_tpc_dict_by_fw(fw_name): + tpc_models_dict = None + if fw_name == TENSORFLOW: + ############################### + # Build Tensorflow TPC models + ############################### + if FOUND_TF: + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \ + get_keras_tpc_latest + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v1 + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v1_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v1_pot + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v2 + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v2_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v3 + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v3_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v4 -############################### -# Build Pytorch TPC models -############################### -pytorch_tpc_models_dict = None -if FOUND_TORCH: - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_pytorch_tpc_latest - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v1 - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v1_pot - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v1_lut - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v2 - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v2_lut - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v3 - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v3_lut - from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v4 + # Keras: TPC versioning + tpc_models_dict = {'v1': get_keras_tpc_v1, + 'v1_lut': get_keras_tpc_v1_lut, + 'v1_pot': get_keras_tpc_v1_pot, + 'v2': get_keras_tpc_v2, + 'v2_lut': get_keras_tpc_v2_lut, + 'v3': get_keras_tpc_v3, + 'v3_lut': get_keras_tpc_v3_lut, + 'v4': get_keras_tpc_v4, + LATEST: get_keras_tpc_latest} + elif fw_name == PYTORCH: + ############################### + # Build Pytorch TPC models + ############################### + if FOUND_TORCH: + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \ + get_pytorch_tpc_latest + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v1 + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v1_pot + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v1_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v2 + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v2_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v3 + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v3_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v4 - # Pytorch: TPC versioning - pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1, + # Pytorch: TPC versioning + tpc_models_dict = {'v1': get_pytorch_tpc_v1, 'v1_lut': get_pytorch_tpc_v1_lut, 'v1_pot': get_pytorch_tpc_v1_pot, 'v2': get_pytorch_tpc_v2, @@ -75,7 +88,10 @@ 'v3_lut': get_pytorch_tpc_v3_lut, 'v4': get_pytorch_tpc_v4, LATEST: get_pytorch_tpc_latest} - -tpc_dict = {TENSORFLOW: keras_tpc_models_dict, - PYTORCH: pytorch_tpc_models_dict} - + if tpc_models_dict is not None: + return tpc_models_dict + else: + Logger.critical(f'Framework {fw_name} is not supported in imx500 or the relevant packages are not ' + f'installed. Please make sure the relevant packages are installed when using MCT for optimizing' + f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install ' + f'torch.') # pragma: no cover diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py index 4bd771c1b..fba93e091 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py @@ -14,35 +14,41 @@ # ============================================================================== from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH +from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.constants import LATEST - -############################### -# Build Tensorflow TPC models -############################### -keras_tpc_models_dict = None -if FOUND_TF: - from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1 - from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import get_keras_tpc_latest - - # Keras: TPC versioning - keras_tpc_models_dict = {'v1': get_keras_tpc_v1, - LATEST: get_keras_tpc_latest} - -############################### -# Build Pytorch TPC models -############################### -pytorch_tpc_models_dict = None -if FOUND_TORCH: - from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v1 - from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import get_pytorch_tpc_latest - - # Pytorch: TPC versioning - pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1, +def get_tpc_dict_by_fw(fw_name): + tpc_models_dict = None + if fw_name == TENSORFLOW: + ############################### + # Build Tensorflow TPC models + ############################### + if FOUND_TF: + from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v1 + from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import \ + get_keras_tpc_latest + + # Keras: TPC versioning + tpc_models_dict = {'v1': get_keras_tpc_v1, + LATEST: get_keras_tpc_latest} + elif fw_name == PYTORCH: + ############################### + # Build Pytorch TPC models + ############################### + if FOUND_TORCH: + from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v1 + from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import \ + get_pytorch_tpc_latest + + # Pytorch: TPC versioning + tpc_models_dict = {'v1': get_pytorch_tpc_v1, LATEST: get_pytorch_tpc_latest} - -tpc_dict = {TENSORFLOW: keras_tpc_models_dict, - PYTORCH: pytorch_tpc_models_dict} - - + if tpc_models_dict is not None: + return tpc_models_dict + else: + Logger.critical(f'Framework {fw_name} is not supported in imx500 or the relevant packages are not ' + f'installed. Please make sure the relevant packages are installed when using MCT for optimizing' + f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install ' + f'torch.') # pragma: no cover diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py index c543bfd2c..592591510 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py @@ -14,34 +14,41 @@ # ============================================================================== from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH +from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.constants import LATEST - -############################### -# Build Tensorflow TPC models -############################### -keras_tpc_models_dict = None -if FOUND_TF: - from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1 - from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import get_keras_tpc_latest - - # Keras: TPC versioning - keras_tpc_models_dict = {'v1': get_keras_tpc_v1, - LATEST: get_keras_tpc_latest} - -############################### -# Build Pytorch TPC models -############################### -pytorch_tpc_models_dict = None -if FOUND_TORCH: - from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import \ - get_pytorch_tpc as get_pytorch_tpc_v1 - from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import get_pytorch_tpc_latest - - # Pytorch: TPC versioning - pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1, +def get_tpc_dict_by_fw(fw_name): + tpc_models_dict = None + if fw_name == TENSORFLOW: + ############################### + # Build Tensorflow TPC models + ############################### + if FOUND_TF: + from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import \ + get_keras_tpc as get_keras_tpc_v1 + from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import \ + get_keras_tpc_latest + + # Keras: TPC versioning + tpc_models_dict = {'v1': get_keras_tpc_v1, + LATEST: get_keras_tpc_latest} + elif fw_name == PYTORCH: + ############################### + # Build Pytorch TPC models + ############################### + if FOUND_TORCH: + from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v1 + from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import \ + get_pytorch_tpc_latest + + # Pytorch: TPC versioning + tpc_models_dict = {'v1': get_pytorch_tpc_v1, LATEST: get_pytorch_tpc_latest} - -tpc_dict = {TENSORFLOW: keras_tpc_models_dict, - PYTORCH: pytorch_tpc_models_dict} - + if tpc_models_dict is not None: + return tpc_models_dict + else: + Logger.critical(f'Framework {fw_name} is not supported in imx500 or the relevant packages are not ' + f'installed. Please make sure the relevant packages are installed when using MCT for optimizing' + f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install ' + f'torch.') # pragma: no cover