diff --git a/baselines/fedvssl/fedvssl/finetune_preprocess.py b/baselines/fedvssl/fedvssl/finetune_preprocess.py index 1f40c2b79a3e..415324c49f7e 100644 --- a/baselines/fedvssl/fedvssl/finetune_preprocess.py +++ b/baselines/fedvssl/fedvssl/finetune_preprocess.py @@ -43,17 +43,17 @@ def args_parser(): # Conversion of the format of pre-trained SSL model from .npz files to .pth format. params = np.load(args.pretrained_model_path, allow_pickle=True) -try: + +if params["arr_0"].shape == (): + # For the cases where weights are stored as Parameters params = params["arr_0"].item() -except: - # For the cases where the weights are stored as NumPy arrays instead of parameters + params = parameters_to_ndarrays(params) +else: + # For the cases where weights are stored as NumPy arrays params = [ np.array(v) for v in list(params["arr_0"]) ] -try: - params = parameters_to_ndarrays(params) -except: - pass + params_dict = zip(model.state_dict().keys(), params) state_dict = { "state_dict": OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})