Skip to content

Commit

Permalink
Check for TF version smaller than 2.16 (#1154)
Browse files Browse the repository at this point in the history
* Check tf version is smaller than 2.16

* Move FOUND_PKG variables from constants file
  • Loading branch information
lapid92 authored Aug 12, 2024
1 parent 002568b commit 77406ae
Show file tree
Hide file tree
Showing 52 changed files with 152 additions and 92 deletions.
7 changes: 0 additions & 7 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,10 @@
# limitations under the License.
# ==============================================================================

import importlib

# Supported frameworks in MCT:
TENSORFLOW = 'tensorflow'
PYTORCH = 'pytorch'
FOUND_TF = importlib.util.find_spec(TENSORFLOW) is not None
FOUND_TORCH = importlib.util.find_spec("torch") is not None
FOUND_TORCHVISION = importlib.util.find_spec("torchvision") is not None
FOUND_ONNX = importlib.util.find_spec("onnx") is not None
FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
FOUND_SONY_CUSTOM_LAYERS = importlib.util.find_spec('sony_custom_layers') is not None

# Metadata fields
MCT_VERSION = 'mct_version'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, Any, Tuple, Type, List, Union

from model_compression_toolkit.constants import FOUND_TF
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.graph.base_node import BaseNode
import numpy as np

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
from model_compression_toolkit.constants import FOUND_TF
from model_compression_toolkit.verify_packages import FOUND_TF

if FOUND_TF:
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
Expand Down Expand Up @@ -89,5 +89,6 @@ def keras_resource_utilization_data(in_model: Model,
# If tensorflow is not installed,
# we raise an exception when trying to use this function.
def keras_resource_utilization_data(*args, **kwargs):
Logger.critical("Tensorflow must be installed to use keras_resource_utilization_data. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"keras_resource_utilization_data. The 'tensorflow' package is either not installed or is "
"installed with a version higher than 2.15.") # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
from model_compression_toolkit.constants import FOUND_TORCH
from model_compression_toolkit.verify_packages import FOUND_TORCH

if FOUND_TORCH:
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/data_generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF, FOUND_TORCHVISION
from model_compression_toolkit.verify_packages import FOUND_TORCHVISION, FOUND_TORCH, FOUND_TF
from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
from model_compression_toolkit.data_generation.common.enums import ImageGranularity, DataInitType, SchedulerType, BNLayerWeightingType, OutputLossType, BatchNormAlignemntLossType, ImagePipelineType, ImageNormalizationType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Callable, Tuple, List, Dict, Union
from tqdm import tqdm

from model_compression_toolkit.constants import FOUND_TF
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.data_generation.common.constants import DEFAULT_N_ITER, DEFAULT_DATA_GEN_BS
from model_compression_toolkit.data_generation.common.data_generation import get_data_generation_classes
from model_compression_toolkit.data_generation.common.image_pipeline import image_normalization_dict
Expand Down Expand Up @@ -349,8 +349,12 @@ def keras_compute_grads(imgs_to_optimize: tf.Tensor,
else:
def get_keras_data_generation_config(*args, **kwargs):
Logger.critical(
"Tensorflow must be installed to use get_tensorflow_data_generation_config. The 'tensorflow' package is missing.") # pragma: no cover
"Tensorflow must be installed with a version of 2.15 or lower to use "
"get_tensorflow_data_generation_config. The 'tensorflow' package is missing or is installed with a "
"version higher than 2.15.") # pragma: no cover


def keras_data_generation_experimental(*args, **kwargs):
Logger.critical("Tensorflow must be installed to use tensorflow_data_generation_experimental. The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"tensorflow_data_generation_experimental. The 'tensorflow' package is missing or is installed "
"with a version higher than 2.15.") # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from tqdm import tqdm

from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TORCHVISION
from model_compression_toolkit.verify_packages import FOUND_TORCHVISION, FOUND_TORCH
from model_compression_toolkit.core.pytorch.utils import set_model
from model_compression_toolkit.data_generation.common.constants import DEFAULT_N_ITER, DEFAULT_DATA_GEN_BS
from model_compression_toolkit.data_generation.common.data_generation import get_data_generation_classes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
from typing import Callable, Dict

from model_compression_toolkit.constants import FOUND_TF
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.exporter.model_exporter.fw_agonstic.quantization_format import QuantizationFormat
from model_compression_toolkit.logger import Logger

Expand Down Expand Up @@ -101,5 +101,6 @@ def keras_export_model(model: keras.models.Model,
return exporter.get_custom_objects()
else:
def keras_export_model(*args, **kwargs):
Logger.critical("Tensorflow must be installed to use keras_export_model. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use keras_export_model."
"The 'tensorflow' package is missing or is installed "
"with a version higher than 2.15.") # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch.nn

from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
from model_compression_toolkit.constants import FOUND_ONNX
from model_compression_toolkit.verify_packages import FOUND_ONNX
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
from typing import Callable

from model_compression_toolkit.constants import FOUND_TORCH
from model_compression_toolkit.verify_packages import FOUND_TORCH
from model_compression_toolkit.exporter.model_exporter.fw_agonstic.quantization_format import QuantizationFormat
from model_compression_toolkit.exporter.model_exporter.pytorch.export_serialization_format import \
PytorchExportSerializationFormat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Tuple, Callable
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.constants import FOUND_TF
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.logger import Logger
import model_compression_toolkit.core as C
Expand Down Expand Up @@ -101,5 +101,6 @@ def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, Use
return exportable_model, user_info
else:
def get_exportable_keras_model(*args, **kwargs): # pragma: no cover
Logger.critical("Tensorflow must be installed to use get_exportable_keras_model. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"get_exportable_keras_model. The 'tensorflow' package is missing or is installed with a "
"version higher than 2.15.") # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any

from mct_quantizers import BaseInferableQuantizer, KerasActivationQuantizationHolder
from model_compression_toolkit.constants import FOUND_TF
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.logger import Logger

if FOUND_TF:
Expand Down Expand Up @@ -76,5 +76,6 @@ def is_keras_layer_exportable(layer: Any) -> bool:
return True
else:
def is_keras_layer_exportable(*args, **kwargs): # pragma: no cover
Logger.critical("Tensorflow must be installed to use is_keras_layer_exportable. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"is_keras_layer_exportable. The 'tensorflow' package is missing or is installed with a "
"version higher than 2.15.") # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Union, Callable
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.constants import FOUND_TORCH
from model_compression_toolkit.verify_packages import FOUND_TORCH
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common import BaseNode
import model_compression_toolkit.core as C
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from typing import Any

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import FOUND_TORCH

from model_compression_toolkit.verify_packages import FOUND_TORCH

if FOUND_TORCH:
import torch.nn as nn
Expand Down
13 changes: 8 additions & 5 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
Expand Down Expand Up @@ -251,10 +252,12 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
# If tensorflow is not installed,
# we raise an exception when trying to use these functions.
def get_keras_gptq_config(*args, **kwargs):
Logger.critical("Tensorflow must be installed to use get_keras_gptq_config. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"get_keras_gptq_config. The 'tensorflow' package is missing or is "
"installed with a version higher than 2.15.") # pragma: no cover


def keras_gradient_post_training_quantization(*args, **kwargs):
Logger.critical("Tensorflow must be installed to use keras_gradient_post_training_quantization. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"keras_gradient_post_training_quantization. The 'tensorflow' package is missing or is "
"installed with a version higher than 2.15.") # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Union, Dict, List

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import FOUND_TF
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS

from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
Expand Down Expand Up @@ -105,5 +105,6 @@ def get_quant_config(self):
else:
class BaseKerasGPTQTrainableQuantizer: # pragma: no cover
def __init__(self, *args, **kwargs):
Logger.critical("Tensorflow must be installed to use BaseKerasGPTQTrainableQuantizer. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"BaseKerasGPTQTrainableQuantizer. The 'tensorflow' package is missing or is "
"installed with a version higher than 2.15.") # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from typing import Callable
from model_compression_toolkit.core import common
from model_compression_toolkit.constants import FOUND_TORCH, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.verify_packages import FOUND_TORCH
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
from model_compression_toolkit.logger import Logger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Union, Dict

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import FOUND_TORCH
from model_compression_toolkit.verify_packages import FOUND_TORCH
from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS

from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
Expand Down
8 changes: 5 additions & 3 deletions model_compression_toolkit/pruning/keras/pruning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from typing import Callable, Tuple

from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.pruning.pruner import Pruner
from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
Expand Down Expand Up @@ -149,5 +150,6 @@ def keras_pruning_experimental(model: Model,
# If tensorflow is not installed,
# we raise an exception when trying to use these functions.
def keras_pruning_experimental(*args, **kwargs):
Logger.critical("Tensorflow must be installed to use keras_pruning_experimental. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"keras_pruning_experimental. The 'tensorflow' package is missing or is "
"installed with a version higher than 2.15.") # pragma: no cover
3 changes: 2 additions & 1 deletion model_compression_toolkit/pruning/pytorch/pruning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from typing import Callable, Tuple
from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import FOUND_TORCH, PYTORCH
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.verify_packages import FOUND_TORCH
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.pruning.pruner import Pruner
from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
Expand Down
8 changes: 5 additions & 3 deletions model_compression_toolkit/ptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfig
Expand Down Expand Up @@ -178,5 +179,6 @@ def keras_post_training_quantization(in_model: Model,
# If tensorflow is not installed,
# we raise an exception when trying to use these functions.
def keras_post_training_quantization(*args, **kwargs):
Logger.critical("Tensorflow must be installed to use keras_post_training_quantization. "
"The 'tensorflow' package is missing.") # pragma: no cover
Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
"keras_post_training_quantization. The 'tensorflow' package is missing or is "
"installed with a version higher than 2.15.") # pragma: no cover
3 changes: 2 additions & 1 deletion model_compression_toolkit/ptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.verify_packages import FOUND_TORCH
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core import CoreConfig
Expand Down
Loading

0 comments on commit 77406ae

Please sign in to comment.