From b04c1b70c51a3f56c2d9c5d89e78a02d10c56db6 Mon Sep 17 00:00:00 2001
From: yan-gao-GY
Date: Fri, 1 Dec 2023 11:16:36 +0000
Subject: [PATCH] Solve format issue of loading pre-trained model weights

---
 baselines/fedvssl/fedvssl/finetune_preprocess.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/baselines/fedvssl/fedvssl/finetune_preprocess.py b/baselines/fedvssl/fedvssl/finetune_preprocess.py
index 1f40c2b79a3e..415324c49f7e 100644
--- a/baselines/fedvssl/fedvssl/finetune_preprocess.py
+++ b/baselines/fedvssl/fedvssl/finetune_preprocess.py
@@ -43,17 +43,17 @@ def args_parser():
 # Conversion of the format of pre-trained SSL model from .npz files to .pth format.
 params = np.load(args.pretrained_model_path, allow_pickle=True)
-try:
+
+if params["arr_0"].shape == ():
+    # For the cases where weights are stored as Parameters
     params = params["arr_0"].item()
-except:
-    # For the cases where the weights are stored as NumPy arrays instead of parameters
+    params = parameters_to_ndarrays(params)
+else:
+    # For the cases where weights are stored as NumPy arrays
     params = [
         np.array(v) for v in list(params["arr_0"])
     ]
-try:
-    params = parameters_to_ndarrays(params)
-except:
-    pass
+
 params_dict = zip(model.state_dict().keys(), params)
 state_dict = {
     "state_dict": OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})