Merge pull request #393 from allenai/hf-olmo-auto-map
Make hf_olmo device_map compatible
AkshitaB authored Dec 11, 2023
2 parents 6462c3b + fadd450 commit 364e21e
Showing 4 changed files with 45 additions and 18 deletions.
hf_olmo/convert_olmo_to_hf.py (30 additions, 18 deletions)
@@ -3,7 +3,10 @@
 import os
 import shutil
 
+import torch
+
 from hf_olmo.configuration_olmo import OLMoConfig
+from hf_olmo.modeling_olmo import OLMoForCausalLM
 from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
 from olmo import ModelConfig, Tokenizer
@@ -12,11 +15,10 @@
 
 def write_config(checkpoint_dir: str):
     # save config as HF config
-    from cached_path import cached_path
 
     logger.info(f"Loading checkpoint from {checkpoint_dir}")
 
-    config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
+    config_path = os.path.join(checkpoint_dir, "config.yaml")
     model_config = ModelConfig.load(config_path, key="model")
     config_kwargs = model_config.asdict()
     config_kwargs["use_cache"] = True
@@ -26,15 +28,19 @@ def write_config(checkpoint_dir: str):
     config.save_pretrained(checkpoint_dir)
 
 
-def write_model(checkpoint_dir: str, soft_link: bool = True):
-    if soft_link:
-        try:
-            os.symlink("model.pt", os.path.join(checkpoint_dir, "pytorch_model.bin"))
-        except FileExistsError:
-            pass
-    else:
-        if not os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
-            os.rename(os.path.join(checkpoint_dir, "model.pt"), os.path.join(checkpoint_dir, "pytorch_model.bin"))
+def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+    # With device_map="auto", etc., the model is loaded in a way that does not compute start_prefix
+    # correctly, so we explicitly store the model with the expected prefix.
+
+    old_model_path = os.path.join(checkpoint_dir, "model.pt")
+    new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
+
+    state_dict = torch.load(old_model_path)
+    new_state_dict = {f"{OLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()}
+    torch.save(new_state_dict, new_model_path)
+
+    if ignore_olmo_compatibility:
+        os.remove(old_model_path)
 
 
 def write_tokenizer(checkpoint_dir: str):
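For intuition, here is a minimal sketch of the remapping write_model now performs (not part of this diff; the key name is an illustrative OLMo-style key). Every state-dict key gets OLMoForCausalLM.base_model_prefix ("model") prepended, so Hugging Face's weight matching, including the device_map="auto" path, resolves keys without relying on start_prefix:

import torch

# Toy stand-in for a state dict loaded from model.pt (illustrative key).
state_dict = {"transformer.wte.weight": torch.zeros(2, 2)}

# Same remapping as write_model above, with base_model_prefix == "model".
new_state_dict = {f"model.{key}": val for key, val in state_dict.items()}

assert set(new_state_dict) == {"model.transformer.wte.weight"}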
@@ -52,6 +58,16 @@ def write_tokenizer(checkpoint_dir: str):
     tokenizer.save_pretrained(checkpoint_dir)
 
 
+def convert_checkpoint(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+    write_config(checkpoint_dir)
+    write_model(checkpoint_dir, ignore_olmo_compatibility=ignore_olmo_compatibility)
+    write_tokenizer(checkpoint_dir)
+
+    # Cannot remove config.yaml before writing the tokenizer.
+    if ignore_olmo_compatibility:
+        os.remove(os.path.join(checkpoint_dir, "config.yaml"))
+
+
 def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
     from cached_path import cached_path
 
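A usage sketch for the new convert_checkpoint entry point (the checkpoint path is illustrative):

from hf_olmo.convert_olmo_to_hf import convert_checkpoint

# Illustrative path. Keep model.pt and config.yaml alongside the HF-format files:
convert_checkpoint("/path/to/olmo-checkpoint")

# Or drop the olmo-only files once the HF copies are written:
convert_checkpoint("/path/to/olmo-checkpoint", ignore_olmo_compatibility=True)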
@@ -71,9 +87,7 @@ def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
     else:
         logger.info(f"File already present at {final_location}")
 
-    write_config(local_model_path)
-    write_model(local_model_path, soft_link=False)
-    write_tokenizer(local_model_path)
+    convert_checkpoint(local_model_path)
     return local_model_path


@@ -91,13 +105,11 @@ def main():
         "--ignore-olmo-compatibility",
         action="store_true",
         help="Ignore compatibility with the olmo codebase. "
-        "This will rename model.pt --> pytorch_model.bin instead of creating a symlink.",
+        "This will remove files that are needed only by the olmo codebase, e.g. config.yaml.",
     )
 
     args = parser.parse_args()
-    write_config(checkpoint_dir=args.checkpoint_dir)
-    write_model(checkpoint_dir=args.checkpoint_dir, soft_link=not args.ignore_olmo_compatibility)
-    write_tokenizer(checkpoint_dir=args.checkpoint_dir)
+    convert_checkpoint(args.checkpoint_dir, args.ignore_olmo_compatibility)
 
 
 if __name__ == "__main__":
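With this change the CLI is a thin wrapper around convert_checkpoint. Assuming the parser's existing --checkpoint-dir flag (defined above this hunk), a run would look like: python hf_olmo/convert_olmo_to_hf.py --checkpoint-dir /path/to/checkpoint --ignore-olmo-compatibility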
hf_olmo/modeling_olmo.py (3 additions, 0 deletions)
@@ -31,12 +31,15 @@ class OLMoForCausalLM(PreTrainedModel):
 
     config_class = OLMoConfig
     base_model_prefix = "model"
+    _no_split_modules = ["OLMoBlock"]
 
     def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
         super().__init__(config)
 
         if not model:
             model_config = create_model_config_from_pretrained_config(config)
+            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
+            model_config.init_device = "cpu"
             self.model = Olmo(model_config, init_params=True)
         else:
             self.model = model
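The new _no_split_modules hint tells accelerate's device-map planner that an OLMoBlock must live entirely on one device. A sketch of the planning step that from_pretrained(device_map="auto") runs internally, written against accelerate directly (the checkpoint path is illustrative, and this assumes the hf_olmo import registers OLMo with the Auto classes, as the tests below do):

from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401

# Illustrative checkpoint path.
config = AutoConfig.from_pretrained("/path/to/olmo-checkpoint")
with init_empty_weights():
    # Build the module tree without allocating real weights.
    model = AutoModelForCausalLM.from_config(config)

# No OLMoBlock is split across devices, per _no_split_modules.
device_map = infer_auto_device_map(model, no_split_module_classes=OLMoForCausalLM._no_split_modules)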
test_fixtures/test-olmo-model/pytorch_model.bin (binary)

Binary file replaced: the previous fixture was deleted and a new one was added (binary files not shown).
tests/hf_olmo/modeling_olmo_test.py (12 additions, 0 deletions)
@@ -1,5 +1,6 @@
 import tempfile
 
+import pytest
 import torch
 
 from olmo.model import Olmo
@@ -42,3 +43,14 @@ def test_save_pretrained(model_path: str):
     saved_hf_output = saved_hf_model(input_tensor)
 
     torch.testing.assert_allclose(saved_hf_output.logits, hf_output.logits)
+
+
+@pytest.mark.gpu
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices")
+def test_auto_device_map_load(model_path: str):
+    from transformers import AutoModelForCausalLM
+
+    from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401
+
+    hf_model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
+    assert hf_model.device.type == "cuda"
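Beyond the device assertion, a hedged end-to-end sketch of what this test exercises (the checkpoint path is illustrative; it needs at least one CUDA device and assumes the tokenizer is registered by the same hf_olmo import):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401

# Illustrative checkpoint path.
model = AutoModelForCausalLM.from_pretrained("/path/to/olmo-checkpoint", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("/path/to/olmo-checkpoint")

inputs = tokenizer("OLMo is", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits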
