
Commit

it works
rawsh authored Nov 26, 2024
1 parent a93c95a commit 271ae4a
Showing 5 changed files with 1,153 additions and 72 deletions.
184 changes: 121 additions & 63 deletions mcts/train_policy_sft_metamath.py
@@ -1,75 +1,114 @@
from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
import wandb
from datasets import load_dataset
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
from datasets import load_dataset
from unsloth.chat_templates import get_chat_template

# Constants
SEED = 42
max_seq_length = 8192
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = False
load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False.

first = True
first_type1 = True
first_type2 = True

def format_answer(response):
global first
"""Extract answer from #### pattern and format response."""
global first_type1
global first_type2

# Split at #### and get everything before it
parts = response.split('####')
if len(parts) < 2:
return None


# combine the last two steps
steps = parts[0].strip().split("\n")
if len(steps) > 1:
steps[-2] = steps[-2] + f"\n{steps[-1]}"
steps = steps[:-1]
sol = "\n\n".join(steps)

if (first_type1):
print(response)
first_type1 = False

return sol
else:
return None

solution = "\n\n".join(parts[0].strip().split("\n"))
answer = parts[1].split('The answer is:')[0].strip()
answer = parts[1].split('The answer is:')
answer = answer[0].strip()
sol = f"{solution}\nThe answer is: {answer}"

if (first):
print(solution)
print(answer)
first = False

return f"{solution}\n\nThe final answer is: $\\boxed{{{answer}}}$"
if (first_type2):
print(response)
first_type2 = False

return sol
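# Note: first_type1 / first_type2 are one-shot debug flags -- the first response handled by
# each of the two formatting branches above is printed once so the resulting SFT target can
# be eyeballed before training starts. (Hypothetical example, for illustration only: a
# GSM-style response ending in "#### 752\nThe answer is: 752" would be reduced to its
# solution steps plus a trailing "The answer is: 752" line.)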

def train_sft():
# Load model and tokenizer
# Load base and instruct models
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "Qwen/Qwen2.5-0.5B",
model_name = "unsloth/Qwen2.5-0.5B",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)

# Set up chat template
tokenizer = get_chat_template(
tokenizer,
chat_template = "qwen-2.5",
model_instruct, tokenizer_instruct = FastLanguageModel.from_pretrained(
model_name = "unsloth/Qwen2.5-0.5B-Instruct",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)

# Configure PEFT
# Transfer chat token embeddings from instruct to base model
base_embeddings = model.get_input_embeddings()
instruct_embeddings = model_instruct.get_input_embeddings()
chat_tokens = ["<|im_start|>", "<|im_end|>", "system", "assistant", "user"]
with torch.no_grad():
for token in chat_tokens:
try:
instruct_id = tokenizer_instruct.convert_tokens_to_ids(token)
base_id = tokenizer.convert_tokens_to_ids(token)
if instruct_id != tokenizer_instruct.unk_token_id and base_id != tokenizer.unk_token_id:
base_embeddings.weight[base_id] = instruct_embeddings.weight[instruct_id].clone()
print(f"Transferred embedding for token: {token}")
else:
print(f"Warning: Token {token} not found in one of the vocabularies")
except Exception as e:
print(f"Error transferring token {token}: {str(e)}")

# Add LoRA adapters
model = FastLanguageModel.get_peft_model(
model,
r = 128,
r = 128, # Choose any number > 0! Suggested 8, 16, 32, 64, 128
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
"embed_tokens", "lm_head"],
"embed_tokens", "lm_head",], # Add for continual pretraining
lora_alpha = 32,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = True,
loftq_config = None,
use_rslora = True, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
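# With use_rslora=True the adapter scaling is lora_alpha / sqrt(r) = 32 / sqrt(128) ~= 2.83
# rather than the standard lora_alpha / r = 32 / 128 = 0.25, which keeps the update magnitude
# from shrinking as the rank grows.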

# Set up tokenizer with chat template
tokenizer = get_chat_template(
tokenizer,
chat_template = "qwen-2.5",
)
tokenizer.eos_token = "<|im_end|>"
print(tokenizer.eos_token)
print(tokenizer.pad_token)

# Load dataset
ds = load_dataset("meta-math/MetaMathQA")
train_ds = ds['train']
# Load and process dataset
dataset = load_dataset("meta-math/MetaMathQA", split="train")

# Format prompts
def formatting_prompts_func(examples):
conversations = []
for query, response in zip(examples['query'], examples['response']):
@@ -82,52 +121,58 @@ def formatting_prompts_func(examples):
{"role": "assistant", "content": formatted_response}
]
conversations.append(conversation)
texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)

texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in conversations]
return {"text": texts}

# <|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the total cost of purchasing equipment for all sixteen players on the football team, considering that each player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80?<|im_end|>\n<|im_start|>assistant\nEach player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80.\n\nSo the total cost for each player is $25 + $15.20 + $6.80 = $47.\n\nSince there are sixteen players on the football team, the total cost for all of them is 16 * $47 = $752.\n\nThe final answer is: $\\boxed{752}$<|im_end|>\n'

# Process dataset
formatted_dataset = train_ds.map(
formatting_prompts_func,
batched=True,
remove_columns=train_ds.column_names
)
print(len(formatted_dataset))
print(formatted_dataset[0])
dataset = dataset.map(formatting_prompts_func, batched=True, remove_columns=dataset.column_names)

# Debug tokenizer output - show examples
print("Example of tokenized output:")
print(dataset[5]["text"])
print("\nAnother example:")
print(dataset[100]["text"])

# Configure trainer
trainer = UnslothTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = formatted_dataset,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
dataset_num_proc = 8,
packing = True,

args = UnslothTrainingArguments(
per_device_train_batch_size = 8,
learning_rate = 5e-5,
embedding_learning_rate = 5e-6,
per_device_train_batch_size = 8, # With gradient_accumulation_steps=8 this gives effective batch size 64
gradient_accumulation_steps = 8,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
num_train_epochs = 3,
# learning_rate = 5e-6,
# embedding_learning_rate = 5e-7,
learning_rate = 8e-6,
embedding_learning_rate = 1e-6,
warmup_ratio = 0.1,
max_seq_length = 2048,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_torch_fused",
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "cosine",
logging_steps = 1,
seed = 3407,
output_dir = "outputs",
report_to = "wandb",
run_name = "metamath",
hub_strategy = "every_save",
save_strategy = "steps",
save_steps = 100,
hub_model_id = "rawsh/MetaMath-Qwen2.5-0.5b"
),
)
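# embedding_learning_rate applies a separate, smaller learning rate (1e-6, ~8x below the
# 8e-6 used for the LoRA weights) to the embed_tokens / lm_head matrices; Unsloth's
# continued-pretraining examples suggest keeping the embedding LR well below the adapter LR.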

# Print GPU stats
# Set up wandb
# wandb.login(key="YOUR_WANDB_KEY") # Replace with your key
# wandb.init(project='metamath')

# Print initial GPU stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
@@ -137,11 +182,24 @@ def formatting_prompts_func(examples):
# Train
trainer_stats = trainer.train()

# Save model
# Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory/max_memory*100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)

print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

# Save model to HuggingFace Hub
model.push_to_hub_merged(
"rawsh/MetaMath-Qwen2.5-0.5b",
"rawsh/MetaMath-Qwen2.5-0.5b", # Replace with your username
tokenizer,
save_method = "merged_16bit"
save_method="merged_16bit",
)
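# save_method="merged_16bit" folds the LoRA adapters back into the base weights and uploads
# full 16-bit model weights (not an adapter-only repo), so the result loads with plain transformers.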

if __name__ == "__main__":
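
Once training completes, the merged checkpoint at rawsh/MetaMath-Qwen2.5-0.5b can be used directly with transformers. A minimal inference sketch (the prompt and generation settings below are illustrative, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the pushed merged model and its qwen-2.5 chat-template tokenizer
tokenizer = AutoTokenizer.from_pretrained("rawsh/MetaMath-Qwen2.5-0.5b")
model = AutoModelForCausalLM.from_pretrained("rawsh/MetaMath-Qwen2.5-0.5b", device_map="auto")

# Build a chat-formatted prompt and generate an answer
messages = [{"role": "user", "content": "What is the total cost of sixteen $47 equipment sets?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))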
