make conversion auto_map compatible
AkshitaB committed Dec 11, 2023
1 parent 6462c3b commit 4609f13
Showing 3 changed files with 33 additions and 18 deletions.
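
For context on the title: making the conversion "auto_map compatible" means a converted checkpoint directory can be loaded straight through the Hugging Face auto classes. A minimal loading sketch, assuming a locally converted checkpoint (the path is a placeholder) and that accelerate is installed for device_map="auto":

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint_dir = "path/to/converted-checkpoint"  # placeholder path

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_dir,
    trust_remote_code=True,   # resolve the OLMo classes registered via auto_map
    device_map="auto",        # shard layers across available devices (needs accelerate)
    torch_dtype=torch.float16,
)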
48 changes: 30 additions & 18 deletions hf_olmo/convert_olmo_to_hf.py
@@ -3,7 +3,10 @@
 import os
 import shutil
 
+import torch
+
 from hf_olmo.configuration_olmo import OLMoConfig
+from hf_olmo.modeling_olmo import OLMoForCausalLM
 from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
 from olmo import ModelConfig, Tokenizer

@@ -12,11 +15,10 @@
 def write_config(checkpoint_dir: str):
     # save config as HF config
-    from cached_path import cached_path
-
     logger.info(f"Loading checkpoint from {checkpoint_dir}")
 
-    config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
+    config_path = os.path.join(checkpoint_dir, "config.yaml")
     model_config = ModelConfig.load(config_path, key="model")
     config_kwargs = model_config.asdict()
     config_kwargs["use_cache"] = True
@@ -26,15 +28,19 @@ def write_config(checkpoint_dir: str):
     config.save_pretrained(checkpoint_dir)
 
 
-def write_model(checkpoint_dir: str, soft_link: bool = True):
-    if soft_link:
-        try:
-            os.symlink("model.pt", os.path.join(checkpoint_dir, "pytorch_model.bin"))
-        except FileExistsError:
-            pass
-    else:
-        if not os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
-            os.rename(os.path.join(checkpoint_dir, "model.pt"), os.path.join(checkpoint_dir, "pytorch_model.bin"))
+def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+    # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
+    # So, we explicitly store the model with the expected prefix.
+
+    old_model_path = os.path.join(checkpoint_dir, "model.pt")
+    new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
+
+    state_dict = torch.load(old_model_path)
+    new_state_dict = {f"{OLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()}
+    torch.save(new_state_dict, new_model_path)
+
+    if ignore_olmo_compatibility:
+        os.remove(old_model_path)
 
 
 def write_tokenizer(checkpoint_dir: str):
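
The comment inside the new write_model is the core of the change: Hugging Face's from_pretrained matches checkpoint keys against attribute paths on the wrapper class, so raw OLMo keys need the wrapper's base_model_prefix ("model.") prepended to line up with the OLMoForCausalLM.model attribute. A toy sketch of the same renaming, using an illustrative key name rather than a real OLMo parameter:

import torch

state_dict = {"transformer.wte.weight": torch.randn(4, 4)}  # illustrative key
prefixed = {f"model.{key}": val for key, val in state_dict.items()}
print(list(prefixed))  # ['model.transformer.wte.weight']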
@@ -52,6 +58,16 @@ def write_tokenizer(checkpoint_dir: str):
     tokenizer.save_pretrained(checkpoint_dir)
 
 
+def convert_checkpoint(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+    write_config(checkpoint_dir)
+    write_model(checkpoint_dir, ignore_olmo_compatibility=ignore_olmo_compatibility)
+    write_tokenizer(checkpoint_dir)
+
+    # Cannot remove it before writing the tokenizer
+    if ignore_olmo_compatibility:
+        os.remove(os.path.join(checkpoint_dir, "config.yaml"))
+
+
 def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
     from cached_path import cached_path

@@ -71,9 +87,7 @@ def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
     else:
         logger.info(f"File already present at {final_location}")
 
-    write_config(local_model_path)
-    write_model(local_model_path, soft_link=False)
-    write_tokenizer(local_model_path)
+    convert_checkpoint(local_model_path)
     return local_model_path


Expand All @@ -91,13 +105,11 @@ def main():
"--ignore-olmo-compatibility",
action="store_true",
help="Ignore compatibility with the olmo codebase. "
"This will rename model.pt --> pytorch_model.bin instead of creating a symlink.",
"This will remove files that are needed specifically for olmo codebase, eg. config.yaml, etc.",
)

args = parser.parse_args()
write_config(checkpoint_dir=args.checkpoint_dir)
write_model(checkpoint_dir=args.checkpoint_dir, soft_link=not args.ignore_olmo_compatibility)
write_tokenizer(checkpoint_dir=args.checkpoint_dir)
convert_checkpoint(args.checkpoint_dir, args.ignore_olmo_compatibility)


 if __name__ == "__main__":
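
Taken together, the refactor funnels the three write steps through the new convert_checkpoint helper, and main() now just parses flags and delegates to it. A minimal sketch of calling the helper directly, assuming a local unsharded OLMo checkpoint (the path is a placeholder):

from hf_olmo.convert_olmo_to_hf import convert_checkpoint

# Keep model.pt and config.yaml alongside the converted files:
convert_checkpoint("/path/to/olmo-checkpoint")

# Or drop the olmo-only files (model.pt, config.yaml) after converting:
convert_checkpoint("/path/to/olmo-checkpoint", ignore_olmo_compatibility=True)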
3 changes: 3 additions & 0 deletions hf_olmo/modeling_olmo.py
@@ -31,12 +31,15 @@ class OLMoForCausalLM(PreTrainedModel):
 
     config_class = OLMoConfig
     base_model_prefix = "model"
+    _no_split_modules = ["OLMoBlock"]
 
     def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
         super().__init__(config)
 
         if not model:
             model_config = create_model_config_from_pretrained_config(config)
+            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
+            model_config.init_device = "cpu"
             self.model = Olmo(model_config, init_params=True)
         else:
             self.model = model
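
The new _no_split_modules attribute is what makes device_map="auto" safe for this model: it tells accelerate's placement planner that an OLMoBlock must sit entirely on one device. A rough sketch of the mechanism it feeds into, using accelerate's infer_auto_device_map (from_pretrained normally runs this internally; shown standalone only for illustration):

from accelerate import infer_auto_device_map

# `model` is assumed to be an already-instantiated OLMoForCausalLM.
device_map = infer_auto_device_map(model, no_split_module_classes=["OLMoBlock"])
print(device_map)  # maps submodule names to devices; blocks are never split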
1 change: 0 additions & 1 deletion test_fixtures/test-olmo-model/pytorch_model.bin

Binary file replaced: the old test-fixture weights were deleted and a new binary was added (contents not shown).
