[Tutorial] LLM integration #2832
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import ArgumentParser

import torch
from datasets import load_dataset
from grpo_utils import (
    HF2vLLMLocalWeightUpdater,
    PrepareQuestion,
    ShapedCorrectnessReward,
)
from tensordict import TensorDict
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyStackStorage, RayReplayBuffer, ReplayBuffer, SamplerWithoutReplacement
from torchrl.envs import DataLoadingPrimer, KLRewardTransform, LLMEnv, StepCounter
from torchrl.modules import from_hf_transformers, from_vllm
from torchrl.objectives import ClipPPOLoss
from transformers import AutoTokenizer, GPT2LMHeadModel
from vllm import LLM

parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)


def compute_mc_advantage(trajectories):
    # Get the ground-truth answer of each trajectory (it identifies the originating question)
    answer = trajectories["answer"]
    # Identify indices where the answers match
    answer_ids = tree_map(lambda string: hash(string), answer)
    answer_ids = torch.tensor(answer_ids)
    unique_qs = answer_ids.view(-1).unique()
    trajectories["advantage"] = trajectories["next", "reward"] * 0
    for u in unique_qs:
        idx = answer_ids == u
        rewards = trajectories[idx]["next", "reward"]
        # Monte Carlo advantage: normalize rewards within the group of responses to the same question
        rewards = (rewards - rewards.mean()) / rewards.std().clamp(min=1e-4)
        trajectories.set_at_("advantage", rewards, idx)
    return trajectories


if __name__ == "__main__":
    args = parser.parse_args()
    # Create env instance:
    # - Load the gsm8k dataset
    dataset = load_dataset(args.dataset, "main")
    train_dataset = dataset["train"]

    def collate_fn(batch):
        # Stack the samples into a single TensorDict and expose the question
        # under the "text" key expected by the environment
        batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
        batch.rename_key_("question", "text")
        return batch

    # LLM
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # inference_model = GPT2LMHeadModel(GPT2Config())
    inference_model = LLM("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Env
    dataloader = DataLoader(  # noqa: TOR401
        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        tokenizer=tokenizer,
        str2str=True,
        batch_size=(args.batch_size * args.repeats,),
        repeats=args.repeats,
    )
    for i, trsf in enumerate(env.transform):
        if isinstance(trsf, DataLoadingPrimer):
            env.insert_transform(i, PrepareQuestion())
            break

    # Finally, we want the env to stop after the first step
    env.append_transform(StepCounter(max_steps=1))

||
policy = from_vllm( | ||
inference_model, | ||
tokenizer=tokenizer, | ||
from_text=False, | ||
generate=True, | ||
return_log_probs=True, | ||
) | ||
|
||
# Reward transform | ||
env.append_transform(ShapedCorrectnessReward(tokenizer=tokenizer)) | ||
|
||
# Ref model | ||
ref_model = GPT2LMHeadModel.from_pretrained("gpt2").eval() | ||
TensorDict.from_module(ref_model).data.to_module(ref_model) | ||
ref_model = from_hf_transformers( | ||
ref_model, | ||
tokenizer=tokenizer, | ||
from_text=False, | ||
generate=False, | ||
return_log_probs=True, | ||
) | ||
env.append_transform( | ||
KLRewardTransform(actor=ref_model, coef=0.1, log_prob_key="log_probs") | ||
) | ||
|
||
    # Replay buffer
    rb = ReplayBuffer(
        storage=LazyStackStorage(args.steps_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=args.optim_batch_size,
    )

    # Collector
    train_model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
    collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=args.steps_per_batch,
        total_frames=1_000_000,
        local_weights_updater=HF2vLLMLocalWeightUpdater(
            hf_model=train_model, vllm_model=inference_model
        ),
    )

Review comment: I'm curious to see how this would look when train_model and inference_model are sharded 😛
Reply: Yeah, there is still some work to be done!

    # Loss module
    policy_training = from_hf_transformers(
        train_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=False,
        return_log_probs=True,
    )
    loss_fn = ClipPPOLoss(
        actor_network=policy_training,
        critic_network=None,
        critic_coef=0.0,
        functional=False,
    )
    loss_fn.set_keys(sample_log_prob="log_probs")
    loss_fn._set_in_keys()
    optim = torch.optim.Adam(loss_fn.parameters())

    # loss_fn = ReinforceLoss(
    #     actor_network=policy,
    #     critic_network=None,
    #     critic_coef=0.0,
    # )

    for trajs in collector:
        trajs = trajs.reshape(-1)
        # Compute the group-normalized Monte Carlo advantage for the collected batch
        trajs = compute_mc_advantage(trajs)
        rb.extend(trajs)
        # Several PPO epochs over the replay buffer
        for _ in range(args.epochs):
            for batch in rb:
                loss = loss_fn(batch)
                loss_val = loss.mean(reduce=True)
                loss_val.backward()
                optim.step()
                optim.zero_grad()
        # Push the updated HF weights to the vLLM inference engine
        collector.update_policy_weights_()
With the `.append_transform` API, is it possible to run `ShapedCorrectnessReward` and `KLRewardTransform` in parallel?
Not currently, but we can think about it.
The difficulty is that in some cases you may have a transform that requires another one to do its thing before.
We could imagine something along the following lines.
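As a rough sketch (`ParallelTransforms` is a hypothetical wrapper and `MyTransform0`/`MyTransform1` are placeholder transforms; none of this is an existing torchrl API):

```python
# Imagined API: the wrapper would run its children side by side and
# merge whatever each of them writes into the tensordict.
env.append_transform(
    ParallelTransforms(  # hypothetical, does not exist in torchrl today
        MyTransform0(),  # e.g. ShapedCorrectnessReward
        MyTransform1(),  # e.g. KLRewardTransform
    )
)
```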
wdyt?
With just the reward model and the ref_model, I think this API would work.
But I don't think this API would be encompassing (or at least it might be tricky) for the case where there's a more complex graph of dependencies between transforms.
n00b qn: with this API, who is responsible for consolidating the result tds from `MyTransform0` and `MyTransform1`? Would all the communications involved be wrapped in a transform?
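For illustration, one way to picture it is that the wrapper itself consolidates the result tds by updating the parent tensordict, with any communication living inside it. A toy, self-contained sketch using plain callables rather than real torchrl transforms (all names are placeholders, not torchrl code):

```python
import torch
from concurrent.futures import ThreadPoolExecutor

from tensordict import TensorDict


class ParallelTransforms:
    """Toy stand-in for the imagined wrapper: runs independent transforms
    on the same input and merges the keys each of them writes."""

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, td: TensorDict) -> TensorDict:
        with ThreadPoolExecutor() as pool:
            # Each child works on its own clone so they don't race on writes
            outs = list(pool.map(lambda t: t(td.clone()), self.transforms))
        # The wrapper is the one consolidating the result tds
        for out in outs:
            td.update(out)
        return td


def my_transform_0(td):  # stand-in for a correctness reward
    return TensorDict({"reward_correct": torch.ones(td.batch_size)}, td.batch_size)


def my_transform_1(td):  # stand-in for a KL penalty
    return TensorDict({"reward_kl": -0.1 * torch.ones(td.batch_size)}, td.batch_size)


td = TensorDict({"tokens": torch.zeros(4, 8, dtype=torch.long)}, batch_size=[4])
td = ParallelTransforms(my_transform_0, my_transform_1)(td)
print(td["reward_correct"].shape, td["reward_kl"].shape)  # torch.Size([4]) torch.Size([4])
```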