Showing 9 changed files with 464 additions and 87 deletions.
@@ -1,42 +1,30 @@
########################
# This script is modified from the TRL package: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py
# It is designed for reward modeling with the Gemma model, but it can also be applied to any model with a chat template and an official pad token.
# If you have any questions, feel free to send me an email via [email protected]
########################
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

# import evaluate
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
# from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.utils import PaddingStrategy

import pdb

import random
from collections import Counter


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
    """
    local_rank: Optional[int] = field(
        default=-1, metadata={"help": "Used for multi-GPU training"}
    )
    deepspeed: Optional[str] = field(
        # default="dp3.json",
        default=None,
        metadata={
            "help": "Path to the DeepSpeed config if using DeepSpeed. You may need this if the model you want to train doesn't fit on a single GPU."
@@ -45,12 +33,12 @@ class ScriptArguments:
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=4)
    gradient_accumulation_steps: Optional[int] = field(default=32)
    # learning_rate: Optional[float] = field(default=2e-6)
    # embedding_learning_rate: Optional[float] = field(default=1e-6)
    learning_rate: Optional[float] = field(default=1e-5)
    weight_decay: Optional[float] = field(default=0.001)
    model_name: Optional[str] = field(
        # default="google/gemma-2-9b",
        default="google/gemma-2-2b",
        # default="Qwen/Qwen2.5-1.5B",
        metadata={
            "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
        },
@@ -63,6 +51,7 @@ class ScriptArguments:
    )
    num_train_epochs: Optional[int] = field(
        default=1,
        # default=3,
        metadata={"help": "The number of training epochs for the reward model."},
    )
    train_set_path: Optional[str] = field(
@@ -75,26 +64,21 @@ class ScriptArguments:
    )
    output_path: Optional[str] = field(
        default="./mirrorgemma-2-2b-prm-base",
        # default="./gemma-2-9b",
        metadata={"help": "The directory for the output model"},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    optim: Optional[str] = field(
        # default="adamw_hf",
        # default="paged_adamw_32bit",
        default="adamw_torch_fused",
        # default="adamw_bnb_8bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: Optional[str] = field(
        default="cosine",
        metadata={"help": "The LR scheduler"},
    )
    max_length: Optional[int] = field(default=8192)

    save_every_steps: Optional[int] = field(
        default=999999,
        metadata={"help": "Save the model every x steps"},
@@ -105,7 +89,6 @@ class ScriptArguments:
    )


def build_dataset(tokenizer, train_path, eval_path):

    def tokenize(sample):
        question = sample['question']
        steps = sample['steps']
@@ -114,24 +97,103 @@ def tokenize(sample):
        formatted_steps = "\n\n".join(steps)
        full_text = f"{question}\n\n{formatted_steps}"

        tokenized = tokenizer(
            full_text,
            truncation=True,
            max_length=tokenizer.model_max_length,
        )

        sample["input_ids"] = tokenized["input_ids"]
        sample["attention_mask"] = tokenized["attention_mask"]
        sample["reward"] = final_step_reward
        return sample

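    # Illustrative (hypothetical) sample, showing what the tokenizer sees:
    #   question = "What is 2 + 2?"
    #   steps    = ["Step 1: 2 + 2 = 4.", "Step 2: The answer is 4."]
    # yields
    #   full_text = "What is 2 + 2?\n\nStep 1: 2 + 2 = 4.\n\nStep 2: The answer is 4."
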
    # Load and shuffle the training dataset
    ds_train = load_dataset(train_path, split="train").shuffle(seed=42)
    ds_train = ds_train.map(tokenize, num_proc=24)

    # Step 2: Assign a bin number to each sample in the training data
    def assign_bin(example):
        final_step_reward = example['final_step_reward']
        # Calculate the bin number (bins: 0.0-0.1 => bin 0, ..., 0.9-1.0 => bin 9)
        bin_number = int(final_step_reward * 10)
        if bin_number == 10:
            bin_number = 9  # Handle the edge case where final_step_reward == 1.0
        example['bin'] = bin_number
        return example

    ds_train = ds_train.map(assign_bin, num_proc=24)

    # Step 3: Get the count of samples in each bin for the training data
    bin_counts_train = Counter(ds_train['bin'])
    print("Training bin counts before undersampling:", bin_counts_train)

    # Determine the minimum count across all bins in the training data
    min_count_train = min(bin_counts_train.values())
    print("Training minimum count per bin:", min_count_train)

    # Step 4: Create a mapping from bin to indices for the training data
    bin_to_indices_train = {i: [] for i in range(10)}  # Bins 0 to 9
    for idx, bin_number in enumerate(ds_train['bin']):
        bin_to_indices_train[bin_number].append(idx)

    # Randomly sample min_count_train indices per bin for the training data
    random.seed(42)
    selected_indices_train = []
    for bin_number, indices in bin_to_indices_train.items():
        if len(indices) >= min_count_train:
            sampled_indices = random.sample(indices, min_count_train)
        else:
            sampled_indices = indices  # Keep all samples if fewer than min_count_train
        selected_indices_train.extend(sampled_indices)

    # Shuffle the selected indices to mix the data
    random.shuffle(selected_indices_train)

    # Step 5: Create the balanced training dataset
    train_dataset = ds_train.select(selected_indices_train)
    print("Total training samples after undersampling:", len(train_dataset))

    # Now, build the evaluation dataset.
    # Load and shuffle the evaluation dataset
    ds_eval = load_dataset(eval_path, split="train").shuffle(seed=42)
    ds_eval = ds_eval.map(tokenize, num_proc=24)

    # Assign bins to the evaluation data
    ds_eval = ds_eval.map(assign_bin, num_proc=24)

    # Get the count of samples in each bin for the evaluation data
    bin_counts_eval = Counter(ds_eval['bin'])
    print("Evaluation bin counts before undersampling:", bin_counts_eval)

    # Determine the minimum count per bin for the evaluation data:
    # set it to 10% of min_count_train, and at least 1
    eval_min_count_per_bin = max(1, int(min_count_train * 0.1))
    print("Evaluation minimum count per bin:", eval_min_count_per_bin)

    # Create a mapping from bin to indices for the evaluation data
    bin_to_indices_eval = {i: [] for i in range(10)}  # Bins 0 to 9
    for idx, bin_number in enumerate(ds_eval['bin']):
        bin_to_indices_eval[bin_number].append(idx)

    # Randomly sample eval_min_count_per_bin indices per bin for the evaluation data
    selected_indices_eval = []
    for bin_number, indices in bin_to_indices_eval.items():
        if len(indices) >= eval_min_count_per_bin:
            sampled_indices = random.sample(indices, eval_min_count_per_bin)
        else:
            sampled_indices = indices  # Keep all samples if fewer than eval_min_count_per_bin
        selected_indices_eval.extend(sampled_indices)

    # Shuffle the selected indices to mix the data
    random.shuffle(selected_indices_eval)

    # Create the balanced evaluation dataset
    eval_dataset = ds_eval.select(selected_indices_eval)
    print("Total evaluation samples after undersampling:", len(eval_dataset))

    return train_dataset, eval_dataset

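# Quick, illustrative sanity check for the binning rule used in assign_bin.
# This helper is not called anywhere in the training flow, and the reward
# values below are made up:
def _binning_example():
    for reward in (0.0, 0.05, 0.37, 0.95, 1.0):
        bin_number = min(int(reward * 10), 9)  # same clamping rule as assign_bin
        print(f"final_step_reward={reward} -> bin {bin_number}")
    # Expected output: bins 0, 0, 3, 9, 9
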
# We need to define a special data collator that pads the inputs and batches
# the scalar reward labels.
@dataclass
class RewardDataCollatorWithPadding:
    tokenizer: AutoTokenizer
@@ -160,15 +222,12 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        }
        return batch

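# The body of __call__ above is mostly elided by the diff. For reference, a
# minimal padding collator for this regression setup might look roughly like
# the sketch below (an assumption, not the original implementation), where
# each feature carries "input_ids", "attention_mask", and a scalar "reward":
#
#     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
#         batch = self.tokenizer.pad(
#             [{"input_ids": f["input_ids"],
#               "attention_mask": f["attention_mask"]} for f in features],
#             padding=True,
#             max_length=self.max_length,
#             return_tensors="pt",
#         )
#         batch["labels"] = torch.tensor(
#             [f["reward"] for f in features], dtype=torch.float
#         )
#         return batch
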
# Define the evaluation metric
def compute_metrics(eval_pred):
    predictions = eval_pred.predictions.squeeze()
    labels = eval_pred.label_ids
    mse = np.mean((predictions - labels) ** 2)
    return {"mse": mse}

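# Worked example for compute_metrics, with made-up numbers: the regression
# head emits shape (batch_size, 1), so squeeze() collapses it to (batch_size,)
# before comparing against label_ids.
#   predictions = np.array([0.8, 0.2]), labels = np.array([1.0, 0.0])
#   mse = mean((0.8 - 1.0)**2, (0.2 - 0.0)**2) = mean(0.04, 0.04) = 0.04
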
class RewardTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        rewards = model(
@@ -180,36 +239,32 @@ def compute_loss(self, model, inputs, return_outputs=False):
            return loss, {"rewards": rewards}
        return loss

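# Note: the middle of compute_loss is elided by the diff above. Given the MSE
# eval metric, the elided part presumably looks something like the following
# sketch (an assumption, not the original code):
#
#     rewards = model(
#         input_ids=inputs["input_ids"],
#         attention_mask=inputs["attention_mask"],
#     )[0].squeeze(-1)
#     loss = nn.functional.mse_loss(rewards, inputs["labels"])
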
def train_reward_model():
    # parser = HfArgumentParser(ScriptArguments)
    # script_args = parser.parse_args_into_dataclasses()[0]

    # Hardcode args (or parse them with HfArgumentParser, as commented above)
    script_args = ScriptArguments()

    # Load the model and tokenizer
    tokenizer_name = script_args.model_name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)

    # Adjust according to the base model; needed for models that don't have an
    # official pad token.
    tokenizer.truncation_side = "left"
    tokenizer.model_max_length = script_args.max_length

    # Get the datasets
    train_path = script_args.train_set_path
    eval_path = script_args.eval_set_path
    output_name = script_args.output_path

    train_dataset, eval_dataset = build_dataset(tokenizer, train_path, eval_path)
    print("Training set size:", len(train_dataset))
    print("Evaluation set size:", len(eval_dataset))

    # Define the training arguments
    training_args = TrainingArguments(
        output_dir=output_name,
        learning_rate=script_args.learning_rate,
        # embedding_learning_rate=script_args.embedding_learning_rate,
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        num_train_epochs=script_args.num_train_epochs,
@@ -231,48 +286,40 @@ def train_reward_model():
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_ratio=0.03,
        report_to='wandb',
        torch_compile=True,
    )

    # Enable if you want to train with LoRA
    # peft_config = LoraConfig(
    #     task_type=TaskType.SEQ_CLS,
    #     inference_mode=False,
    #     r=8,
    #     lora_alpha=32,
    #     lora_dropout=0.1,
    # )

    model = AutoModelForSequenceClassification.from_pretrained(
        script_args.model_name,
        num_labels=1,
        torch_dtype=torch.bfloat16,
        use_flash_attention_2=True,
    )
    # model = get_peft_model(model, peft_config)
    # model.print_trainable_parameters()

    model.config.use_cache = not script_args.gradient_checkpointing
    num_proc = 24  # Can be adjusted higher if you have more processors.
    original_columns = train_dataset.column_names

    # Initialize the trainer
    trainer = RewardTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        data_collator=RewardDataCollatorWithPadding(
            tokenizer=tokenizer, max_length=script_args.max_length
        ),
    )

    # Start training
    trainer.train()

    print("Saving last checkpoint of the model")
    # model.save_pretrained(output_name + "/last_checkpoint")
    trainer.save_model(output_name + "/last_checkpoint")
    tokenizer.save_pretrained(output_name + "/last_checkpoint")

    # Push the model to the Hugging Face Hub.
    # Ensure you have the necessary permissions and authentication.
    # Note that Trainer.push_to_hub takes a commit message as its first
    # positional argument; the target repo is set via hub_model_id in
    # TrainingArguments.
    trainer.push_to_hub("rawsh/mirrorgemma-2-2b-PRM-base")


if __name__ == "__main__":
    train_reward_model()
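
# How you might launch this script (illustrative; the file name and the
# DeepSpeed config path are placeholders, not part of this commit):
#
#   python reward_modeling.py
#
# or, on multiple GPUs with the ScriptArguments.deepspeed field pointing at a
# DeepSpeed config file:
#
#   deepspeed reward_modeling.py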