[Pipeline] Register tie weights (#15)

comaniac authored Jan 25, 2023
1 parent f6b8608 commit 2bc223c
Showing 17 changed files with 434 additions and 142 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_unit_test.yml
@@ -112,7 +112,7 @@ jobs:
--source-ref ${{ needs.check_status.outputs.ref }} \
--repo ${{ needs.check_status.outputs.repo }} \
--wait \
--command "bash ./ci/task_unit_test.sh"
--command "bash ./ci/install_test_pkgs.sh; bash ./ci/task_unit_test.sh"
update_ci_badge:
needs: [unit_test]
# Run this job whether the unit tests succeeded or not.
2 changes: 1 addition & 1 deletion ci/install_test_pkgs.sh
@@ -4,4 +4,4 @@

python3 -m pip install black==22.10.0
python3 -m pip install transformers==4.25.1 --no-deps
python3 -m pip install pylint==2.14.0 astroid==2.11.6
python3 -m pip install pylint==2.14.0 astroid==2.11.6 mock==4.0.3
4 changes: 2 additions & 2 deletions ci/submit_job.py
@@ -327,8 +327,8 @@ def main():
response = aws_batch.describe_jobs(jobs=[job_id])
status = response["jobs"][0]["status"]
if status in {"SUCCEEDED", "FAILED"}:
if status == "SUCCEEDED" and log_stream_name is None:
# If the job is succeeded within a print period so that
if log_stream_name is None:
# If the job ended within a print period so that
# we have not got the log stream name, we need to get it here.
log_stream_name = response["jobs"][0]["container"]["logStreamName"]
if log_stream_name:
8 changes: 5 additions & 3 deletions conftest.py
@@ -22,10 +22,12 @@ def init_dist(request):
try:
dist.init_process_group(backend="nccl")
except Exception as err:
print(f"Skip === {str(err)}")
pytest.skip(f"Skip {__file__} because torch.distributed is not initialized")
print(f"Skip initializing dist group: {str(err)}")

def destory_dist():
dist.destroy_process_group()
try:
dist.destroy_process_group()
except Exception:
pass

request.addfinalizer(destory_dist)
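
For reference, a hypothetical test that consumes the init_dist fixture above. It assumes the test is launched with a distributed runner (e.g. torchrun) on a GPU machine so that the NCCL process group actually initializes; otherwise the fixture only prints the skip message and the call below would fail.

import torch.distributed as dist

def test_world_size(init_dist):
    # init_dist has already called dist.init_process_group(backend="nccl").
    assert dist.get_world_size() >= 1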
2 changes: 1 addition & 1 deletion docker/push.sh
@@ -23,7 +23,7 @@ shift 1
PASSWORD="$1"
shift 1

LOCAL_IMAGE_NAME=slapo:latest
LOCAL_IMAGE_NAME=slapo-ci:latest
REMOTE_IMAGE_NAME_VER=${DOCKER_HUB_ACCOUNT}/slapo:ci-${VERSION}
REMOTE_IMAGE_NAME_LST=${DOCKER_HUB_ACCOUNT}/slapo:ci-latest

3 changes: 2 additions & 1 deletion examples/gpt/model.py
@@ -64,7 +64,8 @@ def schedule_model(
# Shard other parameters if MP group > 1.
if sch.world_size > 1:
replace_and_shard_mlp(sch[prefix], config, delay_init=delay_init)
shard_word_embedding(sch[prefix], config.vocab_size)
head_sch = sch["lm_head"] if "lm_head" in sch else None
shard_word_embedding(sch[prefix], head_sch, config.vocab_size)

# Broadcast input to all devices within the MP group.
# This is not required when running on Megatron.
5 changes: 4 additions & 1 deletion examples/gpt/schedule.py
@@ -203,7 +203,7 @@ def remove_cast(sch, config, attn_path="h.N.attn.attention"):
return cnt


def shard_word_embedding(sch, vocab_size, word_embed_name="wte"):
def shard_word_embedding(sch, head_sch, vocab_size, word_embed_name="wte"):
if sch.world_size == 1:
return

@@ -232,6 +232,9 @@ def fwd_post_hook(_module, _input, output):

sch[word_embed_name].sync(mode="fwd_post", sync_op_or_fn=fwd_post_hook)

# Shard output embedding.
head_sch.shard("weight", axis=0)


def shard_qkv(
sch,
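
For intuition, a minimal plain-PyTorch sketch (not slapo's API) of why the tied LM head is sharded along the same vocabulary axis (axis 0) as the word embedding: each rank keeps a contiguous 1/world_size slice of the vocabulary rows, so the tied tensors stay shape-compatible. The sizes and names below are illustrative only.

import torch.nn as nn

vocab_size, hidden, world_size, rank = 8, 4, 2, 0

wte = nn.Embedding(vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size, bias=False)
lm_head.weight = wte.weight  # tie input and output embeddings

# Keep only this rank's slice of the vocabulary rows, mirroring
# shard("weight", axis=0) on both the embedding and the head.
rows = vocab_size // world_size
local_weight = wte.weight.detach()[rank * rows : (rank + 1) * rows]
print(local_weight.shape)  # torch.Size([4, 4])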
3 changes: 2 additions & 1 deletion examples/opt/model.py
@@ -65,7 +65,8 @@ def schedule_model(
# Shard other parameters if MP group > 1.
if sch.world_size > 1:
replace_and_shard_mlp(sch[prefix], config, delay_init=delay_init)
shard_word_embedding(sch[prefix], config.vocab_size)
head_sch = sch["lm_head"] if "lm_head" in sch else None
shard_word_embedding(sch[prefix], head_sch, config.vocab_size)

# Broadcast input to all devices within the MP group.
# This is not required when running on Megatron.
7 changes: 6 additions & 1 deletion examples/opt/schedule.py
@@ -230,7 +230,9 @@ def remove_cast(sch, config, attn_path="h.N.attn.attention"):
return cnt


def shard_word_embedding(sch, vocab_size, word_embed_name="decoder.embed_tokens"):
def shard_word_embedding(
sch, head_sch, vocab_size, word_embed_name="decoder.embed_tokens"
):
if sch.world_size == 1:
return

@@ -259,6 +261,9 @@ def fwd_post_hook(_module, _input, output):

sch[word_embed_name].sync(mode="fwd_post", sync_op_or_fn=fwd_post_hook)

# Shard output embedding.
head_sch.shard("weight", axis=0)


def shard_qkv(
sch,
1 change: 1 addition & 0 deletions slapo/model_dialect/__init__.py
@@ -5,4 +5,5 @@
from .megatron.utils import MegatronLogParser
from .deepspeed.utils import DeepSpeedLogParser
from .deepspeed.pipeline import DeepSpeedPipeStageWrapper
from .deepspeed.engine import init_ds_engine
from .registry import get_all_dialects, get_dialect_cls
25 changes: 25 additions & 0 deletions slapo/model_dialect/deepspeed/engine.py
@@ -0,0 +1,25 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from ..registry import register_model_dialect
from ...logger import get_logger, INFO

logger = get_logger("DS-Engine", INFO)


@register_model_dialect("deepspeed", "runtime_engine")
def init_ds_engine(model, **kwargs):
"""Initialize the DeepSpeed engine."""
import deepspeed

if "config" not in kwargs:
raise ValueError("DeepSpeed config not provided.")

# pylint: disable=unbalanced-tuple-unpacking
model, optimizer, _, _ = deepspeed.initialize(
model=model,
config=kwargs["config"],
model_parameters=[p for p in model.parameters() if p.requires_grad],
)
return model, optimizer
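
A hypothetical invocation of the runtime_engine dialect registered above, assuming DeepSpeed is installed and the script is launched through the DeepSpeed or torchrun launcher so that deepspeed.initialize can set up its distributed state; the model and config values are illustrative only.

import torch.nn as nn
from slapo.model_dialect import get_dialect_cls

model = nn.Linear(16, 16)
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
}

# Look up init_ds_engine through the registry and wrap the model;
# the returned optimizer is built by DeepSpeed from the config.
init_engine = get_dialect_cls("runtime_engine", "deepspeed")
model, optimizer = init_engine(model, config=ds_config)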
144 changes: 139 additions & 5 deletions slapo/model_dialect/deepspeed/pipeline.py
@@ -4,6 +4,7 @@

from enum import Enum
import torch
from torch import distributed as dist
from torch import fx
import torch.nn as nn

@@ -169,6 +170,31 @@ def tupleize(data):
return tupleize(unordered_args)


def analyze_tie_ranks(tie_weight_groups, topology):
"""Analyze the ranks of the tied weights for DeepSpeed pipeline."""
tie_ranks = []
tie_stages = []
for tie_weight_set in tie_weight_groups:
tie_stage_ranks = []
for _, stage_id in tie_weight_set:
stage_ranks = topology.filter_match(pipe=stage_id)
tie_stage_ranks.append(stage_ranks)

num_ranks_same_stage = len(tie_stage_ranks[0])
num_stages = len(tie_stage_ranks)
group_ranks = []
for i in range(num_ranks_same_stage):
sub_group_ranks = []
for j in range(num_stages):
sub_group_ranks.append(tie_stage_ranks[j][i])
group_ranks.append(sorted(sub_group_ranks))
tie_ranks.append(group_ranks)

# Record the stage IDs of this tied weight.
tie_stages.append(sorted([stage_id for _, stage_id in tie_weight_set]))
return tie_ranks, tie_stages


@register_model_dialect("deepspeed", "pipeline_stage")
class DeepSpeedPipeStageWrapper(nn.Module):
def __init__(
@@ -298,21 +324,129 @@ def forward(self, *args, **kwargs):

@register_model_dialect("deepspeed", "pipeline_engine")
def deepspeed_pipe_engine(
sch_metadata,
stage_modules,
topology,
param_dtype,
**kwargs,
):
"""DeepSpeed pipeline engine.
Parameters
----------
sch_metadata : ScheduleMetadata
The schedule metadata.
stage_modules : List[nn.Module]
The list of pipeline stage modules.
**kwargs
The keyword arguments. Should include DeepSpeed related information,
such as "config", "loss_fn", "topology", "fp16".
Returns
-------
model : PipelineModule
The DeepSpeed pipeline module.
"""
from deepspeed import pipe

# Sanity check
assert "config" in kwargs
if "topology" not in kwargs:
raise ValueError("Must provide topology for deepspeed pipeline")
topology = kwargs["topology"]

if "loss_fn" not in kwargs:
raise ValueError("Must provide loss_fn for deepspeed pipeline")
if "fp16" in kwargs["config"] and kwargs["config"]["fp16"]["enabled"]:
param_dtype = torch.float16
elif "bf16" in kwargs["config"] and kwargs["config"]["bf16"]["enabled"]:
param_dtype = torch.bfloat16
else:
param_dtype = torch.float

model = pipe.PipelineModule(
stage_modules,
topology=topology,
partition_method="uniform",
loss_fn=kwargs.get("loss_fn", None),
param_dtype=param_dtype,
)
# TODO: tie weights
# tie_weight_groups=kwargs.get("tie_weight_groups", None)
# model.register_tie_weights()

tie_weights = list(sch_metadata.tie_weights.values())
if not tie_weights:
return model

# Tie weights if needed.
if not hasattr(pipe, "TiedWeight"):
logger.warning(
"DeepSpeed pipeline runtime does not support TiedWeight. "
"The tie weight will be ignored."
)
return model

# Tie ranks and self stage ID.
tie_ranks, tie_stages = analyze_tie_ranks(tie_weights, topology)
global_rank = dist.get_rank()

assert len(tie_ranks) == len(tie_weights)
for tie_rank, tie_stage, tie_weight in zip(tie_ranks, tie_stages, tie_weights):
# The group key for this tie weight set. Since this key is used
# in PyTorch ModuleDict, it cannot contain ".".
group_key = list(tie_weight)[0][0].replace(".", "_")
logger.info(
"Tie weights of %s",
",".join([f"{name} in stage {sid}" for name, sid in tie_weight]),
ranks=0,
)
my_stage_id = -1

# Identify the stage ID of this device.
# Ranks is a list of global ranks that includes one device per stage.
# Suppose we have 8 GPUs with TP=2 and PP=4, the device topology is
# Stage0: GPU0, GPU1
# Stage1: GPU2, GPU3
# Stage2: GPU4, GPU5
# Stage3: GPU6, GPU7
# Then when we tie weights in stage 0 and stage 3, the tie ranks would be
# [[0, 6], [1, 7]]. This means ranks 0 and 1 map to tie_stage[0] (stage 0),
# while ranks 6 and 7 map to tie_stage[1] (stage 3).
for ranks in tie_rank:
assert len(tie_stage) == len(ranks)
try:
stage_id_idx = ranks.index(global_rank)
my_stage_id = tie_stage[stage_id_idx]
break
except ValueError:
pass

# Identify which weight in the stage of this device to tie. Suppose
# we tie wte.weight in stage 0 and linear.weight in stage 3, then
# rank 0 should have (module, weight_name) = (model.stage0.wte, "weight");
# rank 3 should have (module, weight_name) = (model.stage3.linear, "weight");
# other ranks should have (module, weight_name) = (None, None).
module, weight_name = None, None
found = False
for full_name, stage_id in tie_weight:
if stage_id == my_stage_id:
if found:
raise RuntimeError(f"Cannot tie two weights in the same stage")
assert isinstance(stage_modules[stage_id], DeepSpeedPipeStageWrapper)
module = stage_modules[stage_id].mod
for token in full_name.split(".")[:-1]:
module = getattr(module, token)
weight_name = full_name.split(".")[-1]
found = True

if found:
# This device owns the stage that has this tie weight.
# Register the tie weight with the corresponding module and weight
# on this device.
assert module is not None and weight_name is not None
model.register_tied_weights(
pipe.TiedWeight(group_key, tie_rank, weight_name, module)
)
else:
# Even if this device does not own any stage with this tied weight, we still
# have to register a TiedWeight so that all devices join the dist group.
model.register_tied_weights(pipe.TiedWeight(group_key, tie_rank, "", None))
return model
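
To make the rank analysis concrete, a self-contained sketch that mirrors the logic of analyze_tie_ranks for the 8-GPU (TP=2, PP=4) layout described in the comments above. FakeTopology is a hypothetical stand-in for the DeepSpeed topology object and only provides the filter_match call used by the analysis.

class FakeTopology:
    # Stage i owns global ranks 2*i and 2*i+1 (TP=2, PP=4).
    def filter_match(self, pipe):
        return [2 * pipe, 2 * pipe + 1]

def analyze_tie_ranks_sketch(tie_weight_groups, topology):
    # One rank group per tensor-parallel index, as in analyze_tie_ranks above.
    tie_ranks, tie_stages = [], []
    for tie_weight_set in tie_weight_groups:
        stage_ranks = [topology.filter_match(pipe=sid) for _, sid in tie_weight_set]
        groups = [sorted(r[i] for r in stage_ranks) for i in range(len(stage_ranks[0]))]
        tie_ranks.append(groups)
        tie_stages.append(sorted(sid for _, sid in tie_weight_set))
    return tie_ranks, tie_stages

# Tie the word embedding in stage 0 with the LM head in stage 3.
tie_weights = [[("wte.weight", 0), ("lm_head.weight", 3)]]
print(analyze_tie_ranks_sketch(tie_weights, FakeTopology()))
# ([[[0, 6], [1, 7]]], [[0, 3]])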
23 changes: 18 additions & 5 deletions slapo/model_dialect/registry.py
@@ -2,7 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
"""Framework model dialect registration."""

DIALECTS = {"pipeline_stage": {}, "pipeline_engine": {}, "log_parser": {}}
DIALECTS = {
"pipeline_stage": {},
"pipeline_engine": {},
"runtime_engine": {None: lambda model, **kwargs: (model, None)},
"log_parser": {},
}


def register_model_dialect(target, cls_type):
@@ -28,12 +33,20 @@ def get_all_dialects(cls_type):
return DIALECTS[cls_type]


def get_dialect_cls(cls_type, target):
def get_dialect_cls(cls_type, target, allow_none=False):
"""Get the framework model dialect class."""
if cls_type not in DIALECTS:
raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")
if target not in DIALECTS[cls_type]:
raise ValueError(
f"Target {target} not registered for {cls_type} model dialects"
)
if allow_none:
if None in DIALECTS[cls_type]:
target = None
else:
raise ValueError(
f"Target {target} does not register default dialect for {cls_type}"
)
else:
raise ValueError(
f"Target {target} not registered for {cls_type} model dialects"
)
return DIALECTS[cls_type][target]
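
A brief sketch of how the new allow_none path is intended to be used, assuming slapo is importable; the "torch" target name is illustrative (any target without a registered runtime engine behaves the same way).

import torch.nn as nn
from slapo.model_dialect import get_dialect_cls

model = nn.Linear(4, 4)

# An unregistered target falls back to the default no-op entry (the None key)
# when allow_none=True: the model is returned unchanged with no optimizer.
noop_engine = get_dialect_cls("runtime_engine", "torch", allow_none=True)
model, optimizer = noop_engine(model)
assert optimizer is None

# With the default allow_none=False, the same lookup raises ValueError.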