Skip to content

Commit

Permalink
Change TPC dictionary by fw into function (#1152)
Browse files Browse the repository at this point in the history
* Change tpc fw dicts into func
  • Loading branch information
lapid92 authored Aug 7, 2024
1 parent c70c464 commit 7cd7690
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities

from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.target_platform_capabilities import \
tpc_dict as imx500_tpc_dict
get_tpc_dict_by_fw as get_imx500_tpc
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.target_platform_capabilities import \
tpc_dict as tflite_tpc_dict
get_tpc_dict_by_fw as get_tflite_tpc
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.target_platform_capabilities import \
tpc_dict as qnnpack_tpc_dict
get_tpc_dict_by_fw as get_qnnpack_tpc
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL, LATEST

tpc_dict = {DEFAULT_TP_MODEL: imx500_tpc_dict,
IMX500_TP_MODEL: imx500_tpc_dict,
TFLITE_TP_MODEL: tflite_tpc_dict,
QNNPACK_TP_MODEL: qnnpack_tpc_dict}
tpc_dict = {DEFAULT_TP_MODEL: get_imx500_tpc,
IMX500_TP_MODEL: get_imx500_tpc,
TFLITE_TP_MODEL: get_tflite_tpc,
QNNPACK_TP_MODEL: get_qnnpack_tpc}


def get_target_platform_capabilities(fw_name: str,
Expand All @@ -47,13 +47,10 @@ def get_target_platform_capabilities(fw_name: str,
"""
assert target_platform_name in tpc_dict, f'Target platform {target_platform_name} is not defined!'
fw_tpc = tpc_dict.get(target_platform_name)
assert fw_name in fw_tpc, f'Framework {fw_name} is not supported in {target_platform_name}. Please make sure the relevant ' \
f'packages are installed when using MCT for optimizing a {fw_name} model. ' \
f'For Tensorflow, please install tensorflow. ' \
f'For PyTorch, please install torch.'
tpc_versions = fw_tpc.get(fw_name)
tpc_versions = fw_tpc(fw_name)
if target_platform_version is None:
target_platform_version = LATEST
else:
assert target_platform_version in tpc_versions, f'TPC version {target_platform_version} is not supported for framework {fw_name}.'
assert target_platform_version in tpc_versions, (f'TPC version {target_platform_version} is not supported for '
f'framework {fw_name}.')
return tpc_versions[target_platform_version]()
Original file line number Diff line number Diff line change
Expand Up @@ -12,61 +12,74 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.logger import Logger

from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
from model_compression_toolkit.target_platform_capabilities.constants import LATEST

###############################
# Build Tensorflow TPC models
###############################
keras_tpc_models_dict = None
if FOUND_TF:
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_keras_tpc_latest
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v1_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import get_keras_tpc as get_keras_tpc_v1_pot
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_v2
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v2_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import get_keras_tpc as get_keras_tpc_v3
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v3_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_v4

# Keras: TPC versioning
keras_tpc_models_dict = {'v1': get_keras_tpc_v1,
'v1_lut': get_keras_tpc_v1_lut,
'v1_pot': get_keras_tpc_v1_pot,
'v2': get_keras_tpc_v2,
'v2_lut': get_keras_tpc_v2_lut,
'v3': get_keras_tpc_v3,
'v3_lut': get_keras_tpc_v3_lut,
'v4': get_keras_tpc_v4,
LATEST: get_keras_tpc_latest}
def get_tpc_dict_by_fw(fw_name):
tpc_models_dict = None
if fw_name == TENSORFLOW:
###############################
# Build Tensorflow TPC models
###############################
if FOUND_TF:
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
get_keras_tpc_latest
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import \
get_keras_tpc as get_keras_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import \
get_keras_tpc as get_keras_tpc_v1_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import \
get_keras_tpc as get_keras_tpc_v1_pot
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import \
get_keras_tpc as get_keras_tpc_v2
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import \
get_keras_tpc as get_keras_tpc_v2_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import \
get_keras_tpc as get_keras_tpc_v3
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import \
get_keras_tpc as get_keras_tpc_v3_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import \
get_keras_tpc as get_keras_tpc_v4

###############################
# Build Pytorch TPC models
###############################
pytorch_tpc_models_dict = None
if FOUND_TORCH:
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_pytorch_tpc_latest
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1_pot
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v2
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v2_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v3
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v3_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v4
# Keras: TPC versioning
tpc_models_dict = {'v1': get_keras_tpc_v1,
'v1_lut': get_keras_tpc_v1_lut,
'v1_pot': get_keras_tpc_v1_pot,
'v2': get_keras_tpc_v2,
'v2_lut': get_keras_tpc_v2_lut,
'v3': get_keras_tpc_v3,
'v3_lut': get_keras_tpc_v3_lut,
'v4': get_keras_tpc_v4,
LATEST: get_keras_tpc_latest}
elif fw_name == PYTORCH:
###############################
# Build Pytorch TPC models
###############################
if FOUND_TORCH:
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
get_pytorch_tpc_latest
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1_pot
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v2
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v2_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v3
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v3_lut
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v4

# Pytorch: TPC versioning
pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1,
# Pytorch: TPC versioning
tpc_models_dict = {'v1': get_pytorch_tpc_v1,
'v1_lut': get_pytorch_tpc_v1_lut,
'v1_pot': get_pytorch_tpc_v1_pot,
'v2': get_pytorch_tpc_v2,
Expand All @@ -75,7 +88,10 @@
'v3_lut': get_pytorch_tpc_v3_lut,
'v4': get_pytorch_tpc_v4,
LATEST: get_pytorch_tpc_latest}

tpc_dict = {TENSORFLOW: keras_tpc_models_dict,
PYTORCH: pytorch_tpc_models_dict}

if tpc_models_dict is not None:
return tpc_models_dict
else:
Logger.critical(f'Framework {fw_name} is not supported in imx500 or the relevant packages are not '
f'installed. Please make sure the relevant packages are installed when using MCT for optimizing'
f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install '
f'torch.') # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,41 @@
# ==============================================================================

from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.constants import LATEST


###############################
# Build Tensorflow TPC models
###############################
keras_tpc_models_dict = None
if FOUND_TF:
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import get_keras_tpc_latest

# Keras: TPC versioning
keras_tpc_models_dict = {'v1': get_keras_tpc_v1,
LATEST: get_keras_tpc_latest}

###############################
# Build Pytorch TPC models
###############################
pytorch_tpc_models_dict = None
if FOUND_TORCH:
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import get_pytorch_tpc_latest

# Pytorch: TPC versioning
pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1,
def get_tpc_dict_by_fw(fw_name):
tpc_models_dict = None
if fw_name == TENSORFLOW:
###############################
# Build Tensorflow TPC models
###############################
if FOUND_TF:
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import \
get_keras_tpc as get_keras_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import \
get_keras_tpc_latest

# Keras: TPC versioning
tpc_models_dict = {'v1': get_keras_tpc_v1,
LATEST: get_keras_tpc_latest}
elif fw_name == PYTORCH:
###############################
# Build Pytorch TPC models
###############################
if FOUND_TORCH:
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import \
get_pytorch_tpc_latest

# Pytorch: TPC versioning
tpc_models_dict = {'v1': get_pytorch_tpc_v1,
LATEST: get_pytorch_tpc_latest}

tpc_dict = {TENSORFLOW: keras_tpc_models_dict,
PYTORCH: pytorch_tpc_models_dict}


if tpc_models_dict is not None:
return tpc_models_dict
else:
Logger.critical(f'Framework {fw_name} is not supported in imx500 or the relevant packages are not '
f'installed. Please make sure the relevant packages are installed when using MCT for optimizing'
f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install '
f'torch.') # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,41 @@
# ==============================================================================

from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.constants import LATEST


###############################
# Build Tensorflow TPC models
###############################
keras_tpc_models_dict = None
if FOUND_TF:
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import get_keras_tpc_latest

# Keras: TPC versioning
keras_tpc_models_dict = {'v1': get_keras_tpc_v1,
LATEST: get_keras_tpc_latest}

###############################
# Build Pytorch TPC models
###############################
pytorch_tpc_models_dict = None
if FOUND_TORCH:
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import get_pytorch_tpc_latest

# Pytorch: TPC versioning
pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1,
def get_tpc_dict_by_fw(fw_name):
tpc_models_dict = None
if fw_name == TENSORFLOW:
###############################
# Build Tensorflow TPC models
###############################
if FOUND_TF:
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import \
get_keras_tpc as get_keras_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import \
get_keras_tpc_latest

# Keras: TPC versioning
tpc_models_dict = {'v1': get_keras_tpc_v1,
LATEST: get_keras_tpc_latest}
elif fw_name == PYTORCH:
###############################
# Build Pytorch TPC models
###############################
if FOUND_TORCH:
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import \
get_pytorch_tpc as get_pytorch_tpc_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import \
get_pytorch_tpc_latest

# Pytorch: TPC versioning
tpc_models_dict = {'v1': get_pytorch_tpc_v1,
LATEST: get_pytorch_tpc_latest}

tpc_dict = {TENSORFLOW: keras_tpc_models_dict,
PYTORCH: pytorch_tpc_models_dict}

if tpc_models_dict is not None:
return tpc_models_dict
else:
Logger.critical(f'Framework {fw_name} is not supported in imx500 or the relevant packages are not '
f'installed. Please make sure the relevant packages are installed when using MCT for optimizing'
f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install '
f'torch.') # pragma: no cover

0 comments on commit 7cd7690

Please sign in to comment.