setup_model.py and setup_optimizer.py moved to separate files and adde… #264

Open · wants to merge 3 commits into base: main
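For context: after this refactor, main_ds.py presumably consumes the relocated helpers roughly as sketched below. The setup_model import matches the one visible in the diff; the setup_optimizer module path and the exact call site are assumptions based on the PR title, not part of this diff.

# Hypothetical shape of the call site after the move (sketch only, not part of this diff).
from instructlab.training.setup_model import setup_model  # import shown in the diff below
# setup_optimizer is presumably imported inside setup_model.py, mirroring the removed code:
# from instructlab.training.setup_optimizer import setup_optimizer  (assumed module path)

# args, tokenizer, train_loader and grad_accum are produced by the existing code in main_ds.py.
model, lr_scheduler, optimizer, accelerator = setup_model(args, tokenizer, train_loader, grad_accum)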
212 changes: 1 addition & 211 deletions src/instructlab/training/main_ds.py
@@ -1,24 +1,19 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from copy import deepcopy
from pathlib import Path
import argparse
import math
import os
import re
import subprocess
import time

# Third Party
from accelerate import Accelerator
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.runtime.zero.utils import ZeRORuntimeException

# pylint: disable=no-name-in-module
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
import torch
import torch.distributed

@@ -34,18 +29,12 @@
from instructlab.training.multipack_sampler import (
    find_packing_max_batch_len_and_grad_accum,
)
from instructlab.training.setup_accelerator import setup_accelerator
from instructlab.training.setup_model import setup_model
from instructlab.training.token_dataset import setup_dataloader, setup_dataset
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.utils import (
    StreamablePopen,
    add_noisy_embeddings,
    apply_gradient_checkpointing,
    convert_loss_to_reduce_sum,
    ensure_loadable_granite_checkpoint,
    get_projection_layer_names,
    load_latest_full_state,
    prepare_peft_model,
    prepare_universal_checkpoint_from_latest,
    retrieve_chat_template,
    save_checkpoint,
@@ -56,205 +45,6 @@
import instructlab.training.data_process as dp


def setup_optimizer(args, model):
    if args.distributed_training_framework == DistributedBackend.FSDP.value:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.95),
            weight_decay=0.0,
        )
    elif args.distributed_training_framework == DistributedBackend.DEEPSPEED.value:
        # need to use this only when the CPU offload optimizer is enabled
        if args.cpu_offload_optimizer:
            print(
                "\033[33m!!! CPU offload optimizer enabled, using DeepSpeedCPUAdam !!!\033[0m"
            )
            optimizer = DeepSpeedCPUAdam(
                model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
            )
        else:
            optimizer = FusedAdam(
                model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
            )
    else:
        raise ValueError(
            f"Sharding framework {args.distributed_training_framework} is not supported."
        )
    return optimizer


def setup_model(args, tokenizer, train_loader, grad_accum):
    bnb_config = None
    if args.lora_r > 0 and args.lora_quant_bits == 4:
        # Third Party
        from transformers import BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,  # if not set will throw a warning about slow speeds when training
        )

    base_model_args = {
        "pretrained_model_name_or_path": args.model_name_or_path,
        "torch_dtype": torch.bfloat16,
        "quantization_config": bnb_config,
    }
    if not args.disable_flash_attn:
        base_model_args["attn_implementation"] = "flash_attention_2"
    elif args.is_granite:
        raise RuntimeError(
            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
        )

    if args.is_granite:
        with ensure_loadable_granite_checkpoint(
            args.model_name_or_path, args.output_dir
        ) as path:
            base_model_args["pretrained_model_name_or_path"] = path
            model = GPTDolomiteForCausalLM.from_pretrained(
                **base_model_args,
                use_padding_free_transformer=True,
            )
    else:
        model = AutoModelForCausalLM.from_pretrained(**base_model_args)

    if len(tokenizer) > model.config.vocab_size:
        print(
            f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
        )
        model.resize_token_embeddings(
            int(8 * math.ceil(len(tokenizer) / 8.0))
        )  # make the vocab size multiple of 8 for sharding the embedding layer.

    # Fix any discrepancy between model and tokenizer
    if (
        model.config.pad_token_id is not None
        and tokenizer.pad_token_id is not None
        and model.config.pad_token_id != tokenizer.pad_token_id
    ):
        print(
            f"WARNING: There is a mismatch between pad token id of model ({model.config.pad_token_id}) and tokenizer({tokenizer.pad_token_id}). Fixing model pad token id to be same as tokenizer's pad token id"
        )
        model.config.pad_token_id = tokenizer.pad_token_id
    if (
        model.config.bos_token_id is not None
        and tokenizer.bos_token_id is not None
        and model.config.bos_token_id != tokenizer.bos_token_id
    ):
        print(
            f"WARNING: There is a mismatch between bos token id of model({model.config.bos_token_id}) and tokenizer({tokenizer.bos_token_id}). Fixing model bos token id to be same as tokenizer's bos token id"
        )
        model.config.bos_token_id = tokenizer.bos_token_id
    if (
        model.config.eos_token_id is not None
        and tokenizer.eos_token_id
        and model.config.eos_token_id != tokenizer.eos_token_id
    ):
        print(
            f"WARNING: There is a mismatch between eos token id of model({model.config.eos_token_id}) and tokenizer({tokenizer.eos_token_id}). Fixing model eos token id to be same as tokenizer's eos token id"
        )
        model.config.eos_token_id = tokenizer.eos_token_id

    assert model.__class__.__name__ in [
        "MistralForCausalLM",
        "GPTDolomiteForCausalLM",
        "LlamaForCausalLM",
        "Starcoder2ForCausalLM",
        "GemmaForCausalLM",
        "MixtralForCausalLM",
    ], f"Model class name: {model.__class__.__name__} is not supported."

    model = convert_loss_to_reduce_sum(model, is_granite=args.is_granite)
    model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)

    # handling of gradient checkpointing
    # it is handled differently for lora and full
    # - with the exception of granite, which handles it
    # in the later stanza
    if args.lora_r > 0:
        # if lora
        # Third Party
        from peft import LoraConfig

        # ensure we select only the modules that exist in the model
        proj_layers = get_projection_layer_names(model)
        if not args.lora_target_modules:
            print(
                f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
            )
            if not proj_layers:
                raise RuntimeError("could not find any projection layers in the model")
            args.__dict__["lora_target_modules"] = proj_layers
        else:
            # when the user specifies the module, we should verify that they align with what's in the model
            lora_target_modules_set = set(args.lora_target_modules)
            diff = lora_target_modules_set - set(proj_layers)
            layers_to_target = lora_target_modules_set - diff
            if len(diff) == len(args.lora_target_modules):
                raise ValueError(
                    f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiguration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
                )
            if diff:
                print(
                    f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
                )
            args.__dict__["lora_target_modules"] = list(layers_to_target)

        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=args.lora_target_modules,
        )
        model = prepare_peft_model(
            model, peft_config, gradient_checkpointing=not args.is_granite
        )

    elif not args.is_granite:
        model.gradient_checkpointing_enable()

    # granite gradient checkpointing is handled uniformly
    # for both lora and full here
    if args.is_granite:
        block_name = model._no_split_modules[0]
        apply_gradient_checkpointing(
            model,
            block_name=block_name,
            use_reentrant=True,  # this should be the HF default mode
        )

    if args.lora_r > 0:

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    accelerator = setup_accelerator(args, model, grad_accum)
    if args.distributed_training_framework == DistributedBackend.FSDP.value:
        model = accelerator.prepare(model)
    optimizer = setup_optimizer(args, model)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_epochs * len(train_loader) // grad_accum,
    )
    model, optimizer, _, lr_scheduler = accelerator.prepare(
        model,
        optimizer,
        deepcopy(train_loader),
        lr_scheduler,
    )
    return model, lr_scheduler, optimizer, accelerator
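The lora_target_modules filtering in the removed setup_model above is plain set arithmetic; a small self-contained illustration (the module names here are made up for the example):

# Illustration of the target-module filtering above; names are illustrative only.
requested = {"q_proj", "k_proj", "w_fake"}
proj_layers = {"q_proj", "k_proj", "v_proj", "o_proj"}

diff = requested - proj_layers        # {"w_fake"}: requested but not present in the model
layers_to_target = requested - diff   # {"q_proj", "k_proj"}: the modules LoRA is applied to

assert layers_to_target == {"q_proj", "k_proj"}
# If diff == requested, nothing matched and the removed code raises a ValueError instead.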


# this function is to check if the checkpoint provided can be resumed
def maybe_resume_training(args, model):
    local_rank = int(os.environ["LOCAL_RANK"])
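One more note on the removed code: resize_token_embeddings rounds the new vocabulary size up to a multiple of 8 so the embedding layer shards evenly. A quick check of the arithmetic (the tokenizer length here is an illustrative value, not taken from the diff):

import math

tokenizer_len = 32007  # illustrative tokenizer length
new_vocab_size = int(8 * math.ceil(tokenizer_len / 8.0))  # same expression as in the removed code
assert new_vocab_size == 32008  # rounded up to the next multiple of 8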