Skip to content

Commit

Permalink
Solve format issue of loading pre-trained model weights
Browse files Browse the repository at this point in the history
  • Loading branch information
yan-gao-GY committed Dec 1, 2023
1 parent 2a6d15c commit b04c1b7
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions baselines/fedvssl/fedvssl/finetune_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ def args_parser():

# Conversion of the format of pre-trained SSL model from .npz files to .pth format.
params = np.load(args.pretrained_model_path, allow_pickle=True)
try:

if params["arr_0"].shape == ():
# For the cases where weights are stored as Parameters
params = params["arr_0"].item()
except:
# For the cases where the weights are stored as NumPy arrays instead of parameters
params = parameters_to_ndarrays(params)
else:
# For the cases where weights are stored as NumPy arrays
params = [
np.array(v) for v in list(params["arr_0"])
]
try:
params = parameters_to_ndarrays(params)
except:
pass

params_dict = zip(model.state_dict().keys(), params)
state_dict = {
"state_dict": OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
Expand Down

0 comments on commit b04c1b7

Please sign in to comment.