diff --git a/model_compression_toolkit/trainable_infrastructure/pytorch/util.py b/model_compression_toolkit/trainable_infrastructure/pytorch/util.py index fec5062e5..cca32abd8 100644 --- a/model_compression_toolkit/trainable_infrastructure/pytorch/util.py +++ b/model_compression_toolkit/trainable_infrastructure/pytorch/util.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from functools import cache +from functools import lru_cache from typing import Callable from tqdm import tqdm -@cache +@lru_cache def get_total_grad_steps(representative_data_gen: Callable) -> int: # dry run on the representative dataset to count number of batches num_batches = 0