From 963eba727698ccdc14a23d71ac906d6054dce8e9 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Thu, 26 Oct 2023 10:41:24 +0300
Subject: [PATCH] Fixed issue with torch 1.12 where _scale_fn_ref is missing in
 CyclicLR

---
 src/super_gradients/training/utils/checkpoint_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index c524462c9e..afbf098eee 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1643,5 +1643,7 @@ def get_scheduler_state(scheduler) -> Dict[str, Tensor]:
     state = scheduler.state_dict()
 
     if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
-        del state["_scale_fn_ref"]
+        # A check is needed since torch 1.12 does not have the _scale_fn_ref attribute, while other versions do
+        if "_scale_fn_ref" in state:
+            del state["_scale_fn_ref"]
     return state
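
For reference, the snippet below is a minimal, self-contained sketch of the patched behavior, not the actual checkpoint_utils module: the torch_version_is_greater_or_equal helper is stubbed here purely for illustration, and the real function's docstring and surrounding imports are omitted.

# Standalone sketch of the guarded _scale_fn_ref removal shown in the patch above.
# The version-check helper below is a simplified stand-in, not the library's implementation.
from typing import Dict

import torch
from torch import Tensor
from torch.optim.lr_scheduler import CyclicLR


def torch_version_is_greater_or_equal(major: int, minor: int) -> bool:
    # Simplified stand-in: compare the installed torch major/minor version.
    installed = tuple(int(part) for part in torch.__version__.split(".")[:2])
    return installed >= (major, minor)


def get_scheduler_state(scheduler) -> Dict[str, Tensor]:
    state = scheduler.state_dict()

    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
        # On torch < 2.0 the CyclicLR state dict may carry a non-serializable
        # _scale_fn_ref entry, but torch 1.12 does not include it at all,
        # so the key is deleted only when it is actually present.
        if "_scale_fn_ref" in state:
            del state["_scale_fn_ref"]
    return state


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = CyclicLR(optimizer, base_lr=0.001, max_lr=0.01)
    # Works across torch versions: the key is removed only if present.
    print(sorted(get_scheduler_state(scheduler).keys()))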