fix: adopt latest state_dict processing
ruixin31 committed Jun 12, 2024
1 parent c65b430 commit b47d01e
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions open_lm/utils/transformers/hf_model.py
@@ -172,11 +172,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             checkpoint_path = kwargs["config"].checkpoint_file
             checkpoint = torch.load(checkpoint_path)
 
-            state_dict = checkpoint["state_dict"]
-            state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
-            state_dict = {f"model.{x}": y for x, y in state_dict.items()}
-
-            return super().from_pretrained(None, state_dict=state_dict, **kwargs)
+            sd = checkpoint["state_dict"]
+            if next(iter(sd.items()))[0].startswith("module"):
+                sd = {k[len("module.") :]: v for k, v in sd.items()}
+            if "_orig_mod" in next(iter(sd.items()))[0]:
+                sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
+            sd = {f"model.{x}": y for x, y in sd.items()}
+
+            return super().from_pretrained(None, state_dict=sd, **kwargs)
         else:
             return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

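The change replaces the unconditional "module." stripping with prefix-aware handling: a leading "module." (left by DistributedDataParallel checkpoints) is removed only when present, an "_orig_mod." prefix (left by torch.compile) is now stripped as well, and the keys are then prefixed with "model." to match the HF wrapper. A minimal sketch of that normalization follows; it is not part of the commit, and the helper name and sample key are hypothetical.

# Illustrative sketch only: normalize_state_dict_keys and the sample key
# below are hypothetical, used to show the renaming steps in this commit.
import torch


def normalize_state_dict_keys(state_dict):
    # Strip the "module." prefix left by DistributedDataParallel checkpoints.
    if next(iter(state_dict)).startswith("module"):
        state_dict = {k[len("module.") :]: v for k, v in state_dict.items()}
    # Strip the "_orig_mod." prefix left by torch.compile checkpoints.
    if "_orig_mod" in next(iter(state_dict)):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    # Add the "model." prefix expected by the HF wrapper around open_lm.
    return {f"model.{k}": v for k, v in state_dict.items()}


# Example: a key as saved from a DDP-wrapped, torch.compile'd model.
sd = {"module._orig_mod.tok_embeddings.weight": torch.zeros(2, 2)}
print(list(normalize_state_dict_keys(sd)))  # -> ['model.tok_embeddings.weight']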
