# $ torchrun --nproc-per-node 4 pippy_llama.py
import os
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.pipelining import ScheduleGPipe, PipelineStage
# Grab the model on the meta device (no weight tensors are materialized yet)
whole_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct", device_map="meta"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
tokenizer.pad_token = tokenizer.eos_token
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.distributed.init_process_group(rank=rank, world_size=world_size)
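# Bind this process to its GPU before any collectives run (usual practice with
# the NCCL backend; assumes one rank per visible GPU, as in the torchrun
# command above).
torch.cuda.set_device(device)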
# Cut model by equal number of layers per rank
layers_per_rank = whole_model.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")
stage_idx = rank
num_stages = world_size
def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
    # Work on a copy so each rank prunes its own chunk without mutating the
    # shared meta model.
    model = copy.deepcopy(whole_model)
    if not is_first:
        model.model.embed_tokens = None
    # Walk the decoder layers top-down and delete everything outside this
    # stage's block; iterating in reverse keeps the remaining indices stable.
    # With the flag flips below, each stage keeps the contiguous range
    # (start_layer, stop_layer], i.e. start is exclusive and stop is inclusive.
    drop_layers = stop_layer is not None
    last_idx = len(model.model.layers) - 1
    for idx in range(last_idx, -1, -1):
        if f"layers.{idx}" == stop_layer:
            drop_layers = False
        if f"layers.{idx}" == start_layer:
            drop_layers = True
        if drop_layers:
            del model.model.layers[idx]
    if not is_last:
        model.model.norm = None
        model.lm_head = None
    stage = PipelineStage(
        model,
        stage_idx,
        num_stages,
        device,
    )
    return stage, model
base_interval = whole_model.config.num_hidden_layers // num_stages
extra_layers = whole_model.config.num_hidden_layers % num_stages
splits = []
current_layer = 0
for i in range(num_stages - 1):
    if i == 0:
        current_layer += base_interval
    else:
        # Middle stages get an extra layer if there are any remaining
        if extra_layers > 0:
            current_layer += base_interval + 1
            extra_layers -= 1
        else:
            current_layer += base_interval
    splits.append("layers." + str(current_layer))
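# Illustrative check: Llama-3.1-8B has 32 decoder layers, so with 4 stages
# base_interval = 8, extra_layers = 0, and the split points come out as
# ["layers.8", "layers.16", "layers.24"].
if rank == 0:
    print(f"splits = {splits}")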
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
    stage_idx,
    start_layer,
    stop_layer,
    is_first=stage_idx == 0,
    is_last=stage_idx == num_stages - 1,
)
# Materialize the meta-device parameters as (uninitialized) GPU storage
model_chunk.to_empty(device=device)
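# NOTE: to_empty() only allocates storage; the weights stay uninitialized, so
# outputs are meaningless until each rank loads its shard of a real
# checkpoint. A hypothetical sketch (the shard path and splitting scheme are
# assumptions, not part of this script):
#
#   state = torch.load(f"llama_shard_{rank}.pt", map_location=device)
#   model_chunk.load_state_dict(state, strict=False)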
# Run time inputs
full_batch_prompts = (
"How do you", "I like to", "Can I help", "You need to",
"The weather is", "I found a", "What is your", "You are so",
) # full batch size = 8
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True)["input_ids"].to(device)
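# ScheduleGPipe splits the full batch along dim 0 into n_microbatches; with
# 8 prompts and n_microbatches = num_stages = 4, each microbatch carries 2
# sequences, so the batch size must divide evenly.
assert inputs.size(0) % num_stages == 0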
schedule = ScheduleGPipe(stage, num_stages)
# Run
if rank == 0:
    schedule.step(inputs)
elif rank == world_size - 1:
    # Last stage receives the pipeline output
    output = schedule.step()
else:
    schedule.step()
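# Sketch of decoding the result on the last stage, assuming the stage output
# is the lm_head logits with shape [batch, seq, vocab] (an HF model chunk may
# instead return a ModelOutput, in which case index its logits field first):
if rank == world_size - 1:
    next_token = torch.argmax(output[:, -1, :], dim=-1)
    print(tokenizer.batch_decode(next_token))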