Support LLaMA-65b single A100 finetuning #146

Draft · wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion large_language_models/alpaca-qlora/README.md
@@ -23,7 +23,7 @@
- convert the weight dtype of quant backbone from torch.int32 to torch.int8: `python3 convert_pack32topack8.py /path/to/quant-backbone-pack32 /path/to/output-quant-backbone-pack8`

#### Training LLaMA-7b on single 2080ti
- `python3 finetune.py`
- `python3 finetune.py decapoda-research/llama-7b-hf --int4_backbone /path/to/llama7b-pack8`

#### Training LLaMA-65b on 8*2080ti with Pipeline Parallelism(PP)
- `python3 finetune_pp.py decapoda-research/llama-65b-hf /path/to/llama65b-pack8 --chunks 16 --pp_checkpoint except_last --micro_batch_size 32`
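Note: the visible README hunk only updates the LLaMA-7b command. Based on the PR title, a single-A100 LLaMA-65b run would presumably use the same entry point with the 65b weights and pack8 backbone, along the lines of the hypothetical command below (the paths are placeholders and the batch sizes are a guess, not taken from this diff):

- `python3 finetune.py decapoda-research/llama-65b-hf --int4_backbone /path/to/llama65b-pack8 --micro_batch_size 4 --batch_size 128`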
278 changes: 155 additions & 123 deletions large_language_models/alpaca-qlora/finetune.py
@@ -1,5 +1,5 @@
import os

import argparse
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
@@ -21,80 +21,72 @@
from qlora import get_peft_qmodel


# optimized for RTX 4090. for larger GPUs, increase some of these?
DEBUG = False
QUANT = True
if DEBUG:
MICRO_BATCH_SIZE = 2
BATCH_SIZE = 2
else:
MICRO_BATCH_SIZE = 16 # this could actually be 5 but i like powers of 2
BATCH_SIZE = 128
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3 # we don't need 3 tbh
LEARNING_RATE = 3e-4 # the Karpathy constant
CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TRAIN_SET_SIZE = None
VAL_SET_SIZE = 2000
model_arch = "llama-7b"
model_cachedir = "./caches/{}/".format(model_arch)
if DEBUG:
TRAIN_SET_SIZE = 2000
VAL_SET_SIZE = 100

if QUANT:
config = transformers.AutoConfig.from_pretrained(
os.path.join(model_cachedir, "config.json")
)
model = load_qllama(
config, os.path.join(model_cachedir, "{}_4w_pack8.pth.tar".format(model_arch))
)
model.is_loaded_in_8bit = True # hack for gradient-checkpoint
model = prepare_model_for_int8_training(model)
model.is_loaded_in_8bit = False
model.seq_len = 2048
peft_func = get_peft_qmodel
else:
model = LlamaForCausalLM.from_pretrained(
"decapoda-research/llama-7b-hf",
load_in_8bit=True,
device_map="auto",
def main(args):

# optimized for RTX 4090. for larger GPUs, increase some of these?
MICRO_BATCH_SIZE = args.micro_batch_size
BATCH_SIZE = args.batch_size
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = args.epochs # we don't need 3 tbh
LEARNING_RATE = args.lr # the Karpathy constant
CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TRAIN_SET_SIZE = None
VAL_SET_SIZE = 2000

if args.int4_backbone:
config = transformers.AutoConfig.from_pretrained(
args.model
)
model = load_qllama(
config, args.int4_backbone
)
model.is_loaded_in_8bit = True # hack for gradient-checkpoint
model = prepare_model_for_int8_training(model)
model.is_loaded_in_8bit = False
model.seq_len = 2048
peft_func = get_peft_qmodel
else:
model = LlamaForCausalLM.from_pretrained(
args.model,
load_in_8bit=True,
device_map="auto",
)
model = prepare_model_for_int8_training(model)
peft_func = get_peft_model

tokenizer = LlamaTokenizer.from_pretrained(args.model, add_eos_token=True)

config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=["q_proj", "v_proj"],
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="QUANT_CAUSAL_LM" if args.int4_backbone else "CAUSAL_LM",
)
model = prepare_model_for_int8_training(model)
peft_func = get_peft_model

# tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf", add_eos_token=True)
tokenizer = LlamaTokenizer.from_pretrained(
os.path.join(model_cachedir, "tokenizer"), add_eos_token=True
)
model = peft_func(model, config)

config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=["q_proj", "v_proj"],
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="QUANT_CAUSAL_LM" if QUANT else "CAUSAL_LM",
)
model = peft_func(model, config)
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference

tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
data = load_dataset("json", data_files="alpaca_data.json")
data = load_dataset("yahma/alpaca-cleaned")

train_val = data["train"].train_test_split(
train_size=TRAIN_SET_SIZE, test_size=VAL_SET_SIZE, shuffle=True, seed=42
)
train_data = train_val["train"]
val_data = train_val["test"]
train_val = data["train"].train_test_split(
train_size=TRAIN_SET_SIZE, test_size=VAL_SET_SIZE, shuffle=True, seed=42
)
train_data = train_val["train"]
val_data = train_val["test"]


def generate_prompt(data_point):
# sorry about the formatting disaster gotta move fast
if data_point["input"]:
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
def generate_prompt(data_point):
# sorry about the formatting disaster gotta move fast
if data_point["input"]:
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}
@@ -104,8 +96,8 @@ def generate_prompt(data_point):

### Response:
{data_point["output"]}"""
else:
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
else:
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}
@@ -114,61 +106,101 @@ def generate_prompt(data_point):
{data_point["output"]}"""


def tokenize(prompt):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=CUTOFF_LEN + 1,
padding="max_length",
def tokenize(prompt):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=CUTOFF_LEN + 1,
padding="max_length",
)
return {
"input_ids": result["input_ids"][:-1],
"attention_mask": result["attention_mask"][:-1],
}


train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))

trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=MICRO_BATCH_SIZE,
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
warmup_steps=args.warmup_steps,
num_train_epochs=EPOCHS,
learning_rate=LEARNING_RATE,
fp16=True,
logging_steps=10,
logging_dir=args.logging_dir,
logging_strategy="steps",
optim="adamw_torch",
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=200,
save_steps=200,
output_dir="lora-alpaca",
save_total_limit=10,
load_best_model_at_end=True,
weight_decay=args.weight_decay,
adam_beta1=args.adam_beta1,
adam_beta2=args.adam_beta2,
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
return {
"input_ids": result["input_ids"][:-1],
"attention_mask": result["attention_mask"][:-1],
}


train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))

trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=MICRO_BATCH_SIZE,
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
warmup_steps=100,
num_train_epochs=EPOCHS,
learning_rate=LEARNING_RATE,
fp16=True,
# logging_steps=20,
logging_steps=1,
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=200,
save_steps=200,
output_dir="lora-alpaca",
save_total_limit=3,
load_best_model_at_end=True,
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
model.config.use_cache = False

old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))

# IMPORTANT! model.eval() -> model.train() enable requant 4-bit weights
model.eval()
model.train()
# IMPORTANT! model.eval() -> model.train() enable requant 4-bit weights
model.eval()
model.train()

trainer.train()
# res = trainer.evaluate()
trainer.train()
# res = trainer.evaluate()

model.save_pretrained("lora-alpaca")
model.save_pretrained("lora-alpaca")

print("\n If there's a warning about missing keys above, please disregard :)")
print("\n If there's a warning about missing keys above, please disregard :)")

if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("model", type=str, help="model name")
parser.add_argument("--int4_backbone", type=str, default="", help="path to 4bit checkpoint, using int4 backbone if provided")
parser.add_argument(
"--batch_size", type=int, default=128, help="Batch size for training."
)
parser.add_argument(
"--micro_batch_size", type=int, default=16, help="Batch size for training."
)
parser.add_argument(
"--logging_dir", type=str, default="runs/logs", help="dir for logging."
)
parser.add_argument(
"--epochs", type=int, default=3, help="epochs for training."
)
parser.add_argument(
"--lr", type=float, default=3e-4, help="learning rate"
)
parser.add_argument(
"--weight_decay", type=float, default=0.0, help="weight_decay for training."
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9, help="adam_beta1 for training."
)
parser.add_argument(
"--adam_beta2", type=float, default=0.999, help="adam_beta2 for training."
)
parser.add_argument(
"--warmup_steps", type=int, default=100, help="adam_beta2 for training."
)
args = parser.parse_args()
main(args)
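For reference, a possible invocation of the updated `finetune.py` with the newly added flags; the values below simply restate the argparse defaults, and the model name and pack8 path are placeholders borrowed from the README:

```
python3 finetune.py decapoda-research/llama-7b-hf \
    --int4_backbone /path/to/llama7b-pack8 \
    --batch_size 128 --micro_batch_size 16 \
    --epochs 3 --lr 3e-4 --warmup_steps 100 \
    --weight_decay 0.0 --logging_dir runs/logs
```

Omitting `--int4_backbone` falls back to the 8-bit `load_in_8bit` path, which uses the standard peft `get_peft_model` instead of `get_peft_qmodel`.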