diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py
index b4600794d..23d018022 100644
--- a/hf_olmo/convert_olmo_to_hf.py
+++ b/hf_olmo/convert_olmo_to_hf.py
@@ -3,7 +3,10 @@
 import os
 import shutil
 
+import torch
+
 from hf_olmo.configuration_olmo import OLMoConfig
+from hf_olmo.modeling_olmo import OLMoForCausalLM
 from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
 from olmo import ModelConfig, Tokenizer
 
@@ -12,11 +15,10 @@
 
 def write_config(checkpoint_dir: str):
     # save config as HF config
-    from cached_path import cached_path
 
     logger.info(f"Loading checkpoint from {checkpoint_dir}")
 
-    config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
+    config_path = os.path.join(checkpoint_dir, "config.yaml")
     model_config = ModelConfig.load(config_path, key="model")
     config_kwargs = model_config.asdict()
     config_kwargs["use_cache"] = True
@@ -26,15 +28,19 @@ def write_config(checkpoint_dir: str):
     config.save_pretrained(checkpoint_dir)
 
 
-def write_model(checkpoint_dir: str, soft_link: bool = True):
-    if soft_link:
-        try:
-            os.symlink("model.pt", os.path.join(checkpoint_dir, "pytorch_model.bin"))
-        except FileExistsError:
-            pass
-    else:
-        if not os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
-            os.rename(os.path.join(checkpoint_dir, "model.pt"), os.path.join(checkpoint_dir, "pytorch_model.bin"))
+def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+    # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
+    # So, we explicitly store the model with the expected prefix.
+
+    old_model_path = os.path.join(checkpoint_dir, "model.pt")
+    new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
+
+    state_dict = torch.load(old_model_path)
+    new_state_dict = {f"{OLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()}
+    torch.save(new_state_dict, new_model_path)
+
+    if ignore_olmo_compatibility:
+        os.remove(old_model_path)
 
 
 def write_tokenizer(checkpoint_dir: str):
@@ -52,6 +58,16 @@ def write_tokenizer(checkpoint_dir: str):
     tokenizer.save_pretrained(checkpoint_dir)
 
 
+def convert_checkpoint(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+    write_config(checkpoint_dir)
+    write_model(checkpoint_dir, ignore_olmo_compatibility=ignore_olmo_compatibility)
+    write_tokenizer(checkpoint_dir)
+
+    # Cannot remove it before writing the tokenizer
+    if ignore_olmo_compatibility:
+        os.remove(os.path.join(checkpoint_dir, "config.yaml"))
+
+
 def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
     from cached_path import cached_path
 
@@ -71,9 +87,7 @@ def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir:
         else:
             logger.info(f"File already present at {final_location}")
 
-    write_config(local_model_path)
-    write_model(local_model_path, soft_link=False)
-    write_tokenizer(local_model_path)
+    convert_checkpoint(local_model_path)
 
     return local_model_path
 
@@ -91,13 +105,11 @@ def main():
         "--ignore-olmo-compatibility",
         action="store_true",
         help="Ignore compatibility with the olmo codebase. "
-        "This will rename model.pt --> pytorch_model.bin instead of creating a symlink.",
+        "This will remove files that are needed specifically for olmo codebase, eg. config.yaml, etc.",
     )
 
     args = parser.parse_args()
-    write_config(checkpoint_dir=args.checkpoint_dir)
-    write_model(checkpoint_dir=args.checkpoint_dir, soft_link=not args.ignore_olmo_compatibility)
-    write_tokenizer(checkpoint_dir=args.checkpoint_dir)
+    convert_checkpoint(args.checkpoint_dir, args.ignore_olmo_compatibility)
 
 
 if __name__ == "__main__":
diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index 5abe50325..f011567a7 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -31,12 +31,15 @@ class OLMoForCausalLM(PreTrainedModel):
 
     config_class = OLMoConfig
     base_model_prefix = "model"
+    _no_split_modules = ["OLMoBlock"]
 
     def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
         super().__init__(config)
 
         if not model:
             model_config = create_model_config_from_pretrained_config(config)
+            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
+            model_config.init_device = "cpu"
             self.model = Olmo(model_config, init_params=True)
         else:
             self.model = model
diff --git a/test_fixtures/test-olmo-model/pytorch_model.bin b/test_fixtures/test-olmo-model/pytorch_model.bin
deleted file mode 120000
index 6702297cd..000000000
--- a/test_fixtures/test-olmo-model/pytorch_model.bin
+++ /dev/null
@@ -1 +0,0 @@
-model.pt
\ No newline at end of file
diff --git a/test_fixtures/test-olmo-model/pytorch_model.bin b/test_fixtures/test-olmo-model/pytorch_model.bin
new file mode 100644
index 000000000..fa31c85b2
Binary files /dev/null and b/test_fixtures/test-olmo-model/pytorch_model.bin differ
diff --git a/tests/hf_olmo/modeling_olmo_test.py b/tests/hf_olmo/modeling_olmo_test.py
index 5ae19f6a7..fda1bd715 100644
--- a/tests/hf_olmo/modeling_olmo_test.py
+++ b/tests/hf_olmo/modeling_olmo_test.py
@@ -1,5 +1,6 @@
 import tempfile
 
+import pytest
 import torch
 
 from olmo.model import Olmo
@@ -42,3 +43,14 @@ def test_save_pretrained(model_path: str):
         saved_hf_output = saved_hf_model(input_tensor)
 
     torch.testing.assert_allclose(saved_hf_output.logits, hf_output.logits)
+
+
+@pytest.mark.gpu
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices")
+def test_auto_device_map_load(model_path: str):
+    from transformers import AutoModelForCausalLM
+
+    from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401
+
+    hf_model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
+    assert hf_model.device.type == "cuda"
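
For reference, a minimal usage sketch (not part of the patch) showing how a checkpoint produced by the new convert_checkpoint() path is consumed, mirroring what the added test_auto_device_map_load exercises. The checkpoint path below is a placeholder, and at least one CUDA device is assumed.

# Sketch only: load a converted checkpoint with an automatic device map.
# Assumes convert_checkpoint() has already been run on the placeholder path below.
import torch
from transformers import AutoModelForCausalLM

# Importing hf_olmo (as the new test does) makes the OLMo classes resolvable
# through the transformers Auto* machinery.
from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401

checkpoint_dir = "./converted-olmo-checkpoint"  # placeholder path

# device_map="auto" works because write_model() stores the state dict with the
# "model." prefix and OLMoBlock is listed in _no_split_modules.
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, device_map="auto")
print(model.device)  # expected to be a CUDA device, as asserted in the new test

# Forward pass on a dummy batch, analogous to the existing test_save_pretrained test.
input_ids = torch.tensor([[1, 2, 3, 4]], device=model.device)
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)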