cannot use hf models #1147

Open
LYMDLUT opened this issue Jan 6, 2025 · 0 comments
LYMDLUT commented Jan 6, 2025

$ torchrun --nproc-per-node 4 pippy_llama.py

import os
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.pipelining import ScheduleGPipe, PipelineStage

# Grab the model

whole_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct", device_map="meta"
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
tokenizer.pad_token = tokenizer.eos_token

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.distributed.init_process_group(rank=rank, world_size=world_size)

# Cut model by equal number of layers per rank

layers_per_rank = whole_model.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")

stage_idx = rank
num_stages = world_size

def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
    model = copy.deepcopy(whole_model)
    model = whole_model
    if not is_first:
        model.model.embed_tokens = None

    drop_layers = stop_layer is not None
    num_layers = len(model.model.layers) - 1
    for idx in range(num_layers, -1, -1):
        if f"layers.{idx}" == stop_layer:
            drop_layers = False
        if f"layers.{idx}" == start_layer:
            drop_layers = True
        if drop_layers:
            del model.model.layers[idx]
    # drop_layers = start_layer is not None
    # for name in list(model.model.layers.keys()):
    #     # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
    #     if f"layers.{name}" == start_layer:
    #         drop_layers = False
    #     if f"layers.{name}" == stop_layer:
    #         drop_layers = True
    #     if drop_layers:
    #         del model.model.layers[name]

    if not is_last:
        model.model.norm = None
        model.lm_head = None

    stage = PipelineStage(
        model,
        stage_idx,
        num_stages,
        device,
        # group=pp_mesh.get_group("pp"),
    )
    return stage, model

base_interval = whole_model.config.num_hidden_layers // num_stages
extra_layers = whole_model.config.num_hidden_layers % num_stages

splits = []
current_layer = 0
for i in range(num_stages - 1):
    if i == 0:
        current_layer += base_interval
    else:
        # Middle stages get an extra layer if there are any remaining
        if extra_layers > 0:
            current_layer += base_interval + 1
            extra_layers -= 1
        else:
            current_layer += base_interval
    splits.append("layers." + str(current_layer))

start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
    stage_idx,
    start_layer,
    stop_layer,
    is_first=stage_idx == 0,
    is_last=stage_idx == num_stages - 1,
)
model_chunk.to_empty(device=device)

# Run time inputs

full_batch_prompts = (
"How do you", "I like to", "Can I help", "You need to",
"The weather is", "I found a", "What is your", "You are so",
) # full batch size = 8
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True)["input_ids"].to(device)

schedule = ScheduleGPipe(stage, num_stages)

# Run

if rank == 0:
    schedule.step(inputs)
elif rank == world_size - 1:
    output = schedule.step()
    if output is not None:
        next_token_logits = output[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1)
        print(tokenizer.batch_decode(next_token))
else:
    schedule.step()

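For reference, a quick single-process sanity check of the split points the loop above produces, assuming Llama-3.1-8B's 32 hidden layers and 4 ranks (values hard-coded below, no GPU or checkpoint needed):

# Standalone sketch: reproduces only the split computation from the script above.
num_hidden_layers = 32  # Llama-3.1-8B config
num_stages = 4          # world_size with --nproc-per-node 4

base_interval = num_hidden_layers // num_stages   # 8
extra_layers = num_hidden_layers % num_stages     # 0

splits = []
current_layer = 0
for i in range(num_stages - 1):
    if i == 0:
        current_layer += base_interval
    else:
        if extra_layers > 0:
            current_layer += base_interval + 1
            extra_layers -= 1
        else:
            current_layer += base_interval
    splits.append("layers." + str(current_layer))

print(splits)  # ['layers.8', 'layers.16', 'layers.24']

# (start_layer, stop_layer) passed to _build_stage on each rank:
for stage_idx in range(num_stages):
    start = splits[stage_idx - 1] if stage_idx > 0 else None
    stop = splits[stage_idx] if stage_idx < num_stages - 1 else None
    print(stage_idx, start, stop)
# 0 None layers.8
# 1 layers.8 layers.16
# 2 layers.16 layers.24
# 3 layers.24 None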