From 2bc223cdd90de845ad10ba8f76e2841761be7da3 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 24 Jan 2023 16:15:56 -0800 Subject: [PATCH] [Pipeline] Register tie weights (#15) --- .github/workflows/ci_unit_test.yml | 2 +- ci/install_test_pkgs.sh | 2 +- ci/submit_job.py | 4 +- conftest.py | 8 +- docker/push.sh | 2 +- examples/gpt/model.py | 3 +- examples/gpt/schedule.py | 5 +- examples/opt/model.py | 3 +- examples/opt/schedule.py | 7 +- slapo/model_dialect/__init__.py | 1 + slapo/model_dialect/deepspeed/engine.py | 25 ++++ slapo/model_dialect/deepspeed/pipeline.py | 144 ++++++++++++++++++++- slapo/model_dialect/registry.py | 23 +++- slapo/pipeline.py | 88 ++++++------- slapo/schedule.py | 145 ++++++++++++++-------- tests/test_pipeline_partition.py | 61 ++++++--- tests/test_shard.py | 53 ++++++++ 17 files changed, 434 insertions(+), 142 deletions(-) create mode 100644 slapo/model_dialect/deepspeed/engine.py diff --git a/.github/workflows/ci_unit_test.yml b/.github/workflows/ci_unit_test.yml index 080c9cdc..52573d40 100644 --- a/.github/workflows/ci_unit_test.yml +++ b/.github/workflows/ci_unit_test.yml @@ -112,7 +112,7 @@ jobs: --source-ref ${{ needs.check_status.outputs.ref }} \ --repo ${{ needs.check_status.outputs.repo }} \ --wait \ - --command "bash ./ci/task_unit_test.sh" + --command "bash ./ci/install_test_pkgs.sh; bash ./ci/task_unit_test.sh" update_ci_badge: needs: [unit_test] # Run this job whatever the unit tests were success or not. diff --git a/ci/install_test_pkgs.sh b/ci/install_test_pkgs.sh index 78d8b7da..8fbbf0c2 100644 --- a/ci/install_test_pkgs.sh +++ b/ci/install_test_pkgs.sh @@ -4,4 +4,4 @@ python3 -m pip install black==22.10.0 python3 -m pip install transformers==4.25.1 --no-deps -python3 -m pip install pylint==2.14.0 astroid==2.11.6 +python3 -m pip install pylint==2.14.0 astroid==2.11.6 mock==4.0.3 diff --git a/ci/submit_job.py b/ci/submit_job.py index cba8def9..658937bc 100644 --- a/ci/submit_job.py +++ b/ci/submit_job.py @@ -327,8 +327,8 @@ def main(): response = aws_batch.describe_jobs(jobs=[job_id]) status = response["jobs"][0]["status"] if status in {"SUCCEEDED", "FAILED"}: - if status == "SUCCEEDED" and log_stream_name is None: - # If the job is succeeded within a print period so that + if log_stream_name is None: + # If the job is ended within a print period so that # we have not got the log stream name, we need to get it here. log_stream_name = response["jobs"][0]["container"]["logStreamName"] if log_stream_name: diff --git a/conftest.py b/conftest.py index b82ec7e8..f22b9c08 100644 --- a/conftest.py +++ b/conftest.py @@ -22,10 +22,12 @@ def init_dist(request): try: dist.init_process_group(backend="nccl") except Exception as err: - print(f"Skip === {str(err)}") - pytest.skip(f"Skip {__file__} because torch.distributed is not initialized") + print(f"Skip initializing dist group: {str(err)}") def destory_dist(): - dist.destroy_process_group() + try: + dist.destroy_process_group() + except Exception: + pass request.addfinalizer(destory_dist) diff --git a/docker/push.sh b/docker/push.sh index f72b5b12..73439c28 100644 --- a/docker/push.sh +++ b/docker/push.sh @@ -23,7 +23,7 @@ shift 1 PASSWORD="$1" shift 1 -LOCAL_IMAGE_NAME=slapo:latest +LOCAL_IMAGE_NAME=slapo-ci:latest REMOTE_IMAGE_NAME_VER=${DOCKER_HUB_ACCOUNT}/slapo:ci-${VERSION} REMOTE_IMAGE_NAME_LST=${DOCKER_HUB_ACCOUNT}/slapo:ci-latest diff --git a/examples/gpt/model.py b/examples/gpt/model.py index d6173b45..6f5cd115 100644 --- a/examples/gpt/model.py +++ b/examples/gpt/model.py @@ -64,7 +64,8 @@ def schedule_model( # Shard other parameters if MP group > 1. if sch.world_size > 1: replace_and_shard_mlp(sch[prefix], config, delay_init=delay_init) - shard_word_embedding(sch[prefix], config.vocab_size) + head_sch = sch["lm_head"] if "lm_head" in sch else None + shard_word_embedding(sch[prefix], head_sch, config.vocab_size) # Broadcast input to all devices within the MP group. # This is not required when running on Megatron. diff --git a/examples/gpt/schedule.py b/examples/gpt/schedule.py index 835696bb..b399f876 100644 --- a/examples/gpt/schedule.py +++ b/examples/gpt/schedule.py @@ -203,7 +203,7 @@ def remove_cast(sch, config, attn_path="h.N.attn.attention"): return cnt -def shard_word_embedding(sch, vocab_size, word_embed_name="wte"): +def shard_word_embedding(sch, head_sch, vocab_size, word_embed_name="wte"): if sch.world_size == 1: return @@ -232,6 +232,9 @@ def fwd_post_hook(_module, _input, output): sch[word_embed_name].sync(mode="fwd_post", sync_op_or_fn=fwd_post_hook) + # Shard output embedding. + head_sch.shard("weight", axis=0) + def shard_qkv( sch, diff --git a/examples/opt/model.py b/examples/opt/model.py index 54dde1d2..5aeb5332 100644 --- a/examples/opt/model.py +++ b/examples/opt/model.py @@ -65,7 +65,8 @@ def schedule_model( # Shard other parameters if MP group > 1. if sch.world_size > 1: replace_and_shard_mlp(sch[prefix], config, delay_init=delay_init) - shard_word_embedding(sch[prefix], config.vocab_size) + head_sch = sch["lm_head"] if "lm_head" in sch else None + shard_word_embedding(sch[prefix], head_sch, config.vocab_size) # Broadcast input to all devices within the MP group. # This is not required when running on Megatron. diff --git a/examples/opt/schedule.py b/examples/opt/schedule.py index 7b9c4701..1e2d417b 100644 --- a/examples/opt/schedule.py +++ b/examples/opt/schedule.py @@ -230,7 +230,9 @@ def remove_cast(sch, config, attn_path="h.N.attn.attention"): return cnt -def shard_word_embedding(sch, vocab_size, word_embed_name="decoder.embed_tokens"): +def shard_word_embedding( + sch, head_sch, vocab_size, word_embed_name="decoder.embed_tokens" +): if sch.world_size == 1: return @@ -259,6 +261,9 @@ def fwd_post_hook(_module, _input, output): sch[word_embed_name].sync(mode="fwd_post", sync_op_or_fn=fwd_post_hook) + # Shard output embedding. + head_sch.shard("weight", axis=0) + def shard_qkv( sch, diff --git a/slapo/model_dialect/__init__.py b/slapo/model_dialect/__init__.py index 30fdbeb8..ca2fc2be 100644 --- a/slapo/model_dialect/__init__.py +++ b/slapo/model_dialect/__init__.py @@ -5,4 +5,5 @@ from .megatron.utils import MegatronLogParser from .deepspeed.utils import DeepSpeedLogParser from .deepspeed.pipeline import DeepSpeedPipeStageWrapper +from .deepspeed.engine import init_ds_engine from .registry import get_all_dialects, get_dialect_cls diff --git a/slapo/model_dialect/deepspeed/engine.py b/slapo/model_dialect/deepspeed/engine.py new file mode 100644 index 00000000..c71f6b05 --- /dev/null +++ b/slapo/model_dialect/deepspeed/engine.py @@ -0,0 +1,25 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from ..registry import register_model_dialect +from ...logger import get_logger, INFO + +logger = get_logger("DS-Engine", INFO) + + +@register_model_dialect("deepspeed", "runtime_engine") +def init_ds_engine(model, **kwargs): + """Initialize the DeepSpeed engine.""" + import deepspeed + + if "config" not in kwargs: + raise ValueError("DeepSpeed config not provided.") + + # pylint: disable=unbalanced-tuple-unpacking + model, optimizer, _, _ = deepspeed.initialize( + model=model, + config=kwargs["config"], + model_parameters=[p for p in model.parameters() if p.requires_grad], + ) + return model, optimizer diff --git a/slapo/model_dialect/deepspeed/pipeline.py b/slapo/model_dialect/deepspeed/pipeline.py index 0d58c656..7e844abf 100644 --- a/slapo/model_dialect/deepspeed/pipeline.py +++ b/slapo/model_dialect/deepspeed/pipeline.py @@ -4,6 +4,7 @@ from enum import Enum import torch +from torch import distributed as dist from torch import fx import torch.nn as nn @@ -169,6 +170,31 @@ def tupleize(data): return tupleize(unordered_args) +def analyze_tie_ranks(tie_weight_groups, topology): + """Analyze the ranks of the tied weights for DeepSpeed pipeline.""" + tie_ranks = [] + tie_stages = [] + for tie_weight_set in tie_weight_groups: + tie_stage_ranks = [] + for _, stage_id in tie_weight_set: + stage_ranks = topology.filter_match(pipe=stage_id) + tie_stage_ranks.append(stage_ranks) + + num_ranks_same_stage = len(tie_stage_ranks[0]) + num_stages = len(tie_stage_ranks) + group_ranks = [] + for i in range(num_ranks_same_stage): + sub_group_ranks = [] + for j in range(num_stages): + sub_group_ranks.append(tie_stage_ranks[j][i]) + group_ranks.append(sorted(sub_group_ranks)) + tie_ranks.append(group_ranks) + + # Record the stage IDs of this tied weight. + tie_stages.append(sorted([stage_id for _, stage_id in tie_weight_set])) + return tie_ranks, tie_stages + + @register_model_dialect("deepspeed", "pipeline_stage") class DeepSpeedPipeStageWrapper(nn.Module): def __init__( @@ -298,13 +324,46 @@ def forward(self, *args, **kwargs): @register_model_dialect("deepspeed", "pipeline_engine") def deepspeed_pipe_engine( + sch_metadata, stage_modules, - topology, - param_dtype, **kwargs, ): + """DeepSpeed pipeline engine. + + Parameters + ---------- + sch_metadata : ScheduleMetadata + The schedule metadata. + + stage_modules : List[nn.Module] + The list of pipeline stage modules. + + **kwargs + The keyword arguments. Should include DeepSpeed related information, + such as "config", "loss_fn", "topology", "fp16". + + Returns + ------- + model : PipelineModule + The DeepSpeed pipeline module. + """ from deepspeed import pipe + # Sanity check + assert "config" in kwargs + if "topology" not in kwargs: + raise ValueError("Must provide topology for deepspeed pipeline") + topology = kwargs["topology"] + + if "loss_fn" not in kwargs: + raise ValueError("Must provide loss_fn for deepspeed pipeline") + if "fp16" in kwargs["config"] and kwargs["config"]["fp16"]["enabled"]: + param_dtype = torch.float16 + elif "bf16" in kwargs["config"] and kwargs["config"]["bf16"]["enabled"]: + param_dtype = torch.bfloat16 + else: + param_dtype = torch.float + model = pipe.PipelineModule( stage_modules, topology=topology, @@ -312,7 +371,82 @@ def deepspeed_pipe_engine( loss_fn=kwargs.get("loss_fn", None), param_dtype=param_dtype, ) - # TODO: tie weights - # tie_weight_groups=kwargs.get("tie_weight_groups", None) - # model.register_tie_weights() + + tie_weights = list(sch_metadata.tie_weights.values()) + if not tie_weights: + return model + + # Tie weights if needed. + if not hasattr(pipe, "TiedWeight"): + logger.warning( + "DeepSpeed pipeline runtime does not support TiedWeight. " + "The tie weight will be ignored." + ) + return model + + # Tie ranks and self stage ID. + tie_ranks, tie_stages = analyze_tie_ranks(tie_weights, topology) + global_rank = dist.get_rank() + + assert len(tie_ranks) == len(tie_weights) + for tie_rank, tie_stage, tie_weight in zip(tie_ranks, tie_stages, tie_weights): + # The group key for this tie weight set. Since this key is used + # in PyTorch ModuleDict, it cannot contain ".". + group_key = list(tie_weight)[0][0].replace(".", "_") + logger.info( + "Tie weights of %s", + ",".join([f"{name} in stage {sid}" for name, sid in tie_weight]), + ranks=0, + ) + my_stage_id = -1 + + # Identify the stage ID of this device. + # Ranks is a list of global ranks that includes one device per stage. + # Suppose we have 8 GPUs with TP=2 and PP=4, the device topology is + # Stage0: GPU0, GPU1 + # Stage1: GPU2, GPU3 + # Stage2: GPU4, GPU5 + # Stage3: GPU6, GPU7 + # Then when we tie weights in stage 0 and stage 3, the tie ranks would be + # [[0, 6], [1, 7]]. This means the rank 0, 1 are in the tie_stage[0]; + # while the rank 6, 7 are in the tie_stage[1]. + for ranks in tie_rank: + assert len(tie_stage) == len(ranks) + try: + stage_id_idx = ranks.index(global_rank) + my_stage_id = tie_stage[stage_id_idx] + break + except ValueError: + pass + + # Identify which weight in the stage of this device to tie. Suppose + # we tie wte.weight in stage 0 and linear.weight in stage 3, then + # rank 0 should have (module, weight_name) = (model.stage0.wte, "weight"); + # rank 3 should have (module, weight_name) = (model.stage3.linear, "weight"); + # other ranks should have (module, weight_name) = (None, None). + module, weight_name = None, None + found = False + for full_name, stage_id in tie_weight: + if stage_id == my_stage_id: + if found: + raise RuntimeError(f"Cannot tie two weights in the same stage") + assert isinstance(stage_modules[stage_id], DeepSpeedPipeStageWrapper) + module = stage_modules[stage_id].mod + for token in full_name.split(".")[:-1]: + module = getattr(module, token) + weight_name = full_name.split(".")[-1] + found = True + + if found: + # This device owns the stage that has this tie weight. + # Register the tie weight with the corresponding module and weight + # on this device. + assert module is not None and weight_name is not None + model.register_tied_weights( + pipe.TiedWeight(group_key, tie_rank, weight_name, module) + ) + else: + # Even this device is not in any stage, we have to register a tie + # weight to make sure all devices join the dist group. + model.register_tied_weights(pipe.TiedWeight(group_key, tie_rank, "", None)) return model diff --git a/slapo/model_dialect/registry.py b/slapo/model_dialect/registry.py index 933d45bc..1bbdf9aa 100644 --- a/slapo/model_dialect/registry.py +++ b/slapo/model_dialect/registry.py @@ -2,7 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 """Framework model dialect registration.""" -DIALECTS = {"pipeline_stage": {}, "pipeline_engine": {}, "log_parser": {}} +DIALECTS = { + "pipeline_stage": {}, + "pipeline_engine": {}, + "runtime_engine": {None: lambda model, **kwargs: (model, None)}, + "log_parser": {}, +} def register_model_dialect(target, cls_type): @@ -28,12 +33,20 @@ def get_all_dialects(cls_type): return DIALECTS[cls_type] -def get_dialect_cls(cls_type, target): +def get_dialect_cls(cls_type, target, allow_none=False): """Get the framework model dialect class.""" if cls_type not in DIALECTS: raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}") if target not in DIALECTS[cls_type]: - raise ValueError( - f"Target {target} not registered for {cls_type} model dialects" - ) + if allow_none: + if None in DIALECTS[cls_type]: + target = None + else: + raise ValueError( + f"Target {target} does not register default dialect for {cls_type}" + ) + else: + raise ValueError( + f"Target {target} not registered for {cls_type} model dialects" + ) return DIALECTS[cls_type][target] diff --git a/slapo/pipeline.py b/slapo/pipeline.py index 7e676b78..0b14cf7b 100644 --- a/slapo/pipeline.py +++ b/slapo/pipeline.py @@ -4,7 +4,6 @@ import operator from collections import OrderedDict -import torch from torch import fx from torch.fx.passes.split_module import split_module @@ -275,7 +274,7 @@ def get_itemized_name(node, suffix=""): return stage_id_2_arg_names, stage_id_2_name, liveness -def analyze_tie_weights(top_mod): +def analyze_tie_weights(top_mod, is_pipeline_partitioned): """Analyze if there is any tie weights (two weights in different module share the same memory) partitioned into different pipeline stages. @@ -286,17 +285,22 @@ def analyze_tie_weights(top_mod): 1) it should already be traced and partitioned, and 2) it should have a number of submodules that matches pipeline stages. + is_pipeline_partitioned : bool + Whether the module is partitioned for pipeline or not. If not, + then all tie weights will have stage ID 0. + Returns ------- - tie_groups : Dict[str, Set[Tuple[str, int]]] + tie_groups : Dict[int, Set[Tuple[str, int]]] + Mapping from the nn.Parameter object to the set of parameter names + that are tied to it. The set of parameter names is a tuple of + (parameter name, stage ID). The stage ID is 0 if the module is not + partitioned for pipeline. """ - # Mapping from parameter name to (the pipeline stage ID, the parameter object ID). + # Mapping from parameter name to (the pipeline stage ID, the parameter object). params = {} - # Mapping from the primary key (i.e., the first parameter name) - # of a tie group to the tie group. + # The result tie groups. tie_groups = {} - # Mapping from paramter name to the primary key of the tie group. - param_name_2_group_key = {} def _traverse_children(stage_id, prefix, curr_mod_name, curr_mod): full_prefix = f"{prefix}.{curr_mod_name}" if prefix else curr_mod_name @@ -306,69 +310,58 @@ def _traverse_children(stage_id, prefix, curr_mod_name, curr_mod): # only show up once in named_parameters. In other words, we cannot # identify tie weights with named_parameters(recurse=True). for curr_param_name, curr_param in curr_mod.named_parameters(): - curr_param_full_name = f"{full_prefix}.{curr_param_name}" - curr_param_id = id(curr_param) + curr_param_full_name = ( + f"{full_prefix}.{curr_param_name}" if full_prefix else curr_param_name + ) if curr_param_full_name in params: continue # Check if this parameter is tie to another one in a different stage. for target_param_full_name, ( target_stage, - target_param_id, + target_param, ) in params.items(): - if stage_id != target_stage and curr_param_id == target_param_id: - if target_param_full_name in param_name_2_group_key: + if is_pipeline_partitioned and stage_id == target_stage: + continue + if id(curr_param) == id(target_param): + if curr_param in tie_groups: # Get the tie group of the target parameter. - tie_group_key = param_name_2_group_key[target_param_full_name] - tie_group = tie_groups[tie_group_key] + tie_group = tie_groups[curr_param] else: # Create a new tie group, and use the target parameter name # as the primary key. - param_name_2_group_key[ - target_param_full_name - ] = target_param_full_name tie_group = set([(target_param_full_name, target_stage)]) - tie_groups[target_param_full_name] = tie_group + tie_groups[curr_param] = tie_group # Add the current parameter to the tie group. - param_name_2_group_key[ - curr_param_full_name - ] = target_param_full_name tie_group.add((curr_param_full_name, stage_id)) # Add this parameter for the rest analysis. - params[curr_param_full_name] = (stage_id, id(curr_param)) + params[curr_param_full_name] = (stage_id, curr_param) # Traverse children. for name, mod in curr_mod.named_children(): _traverse_children(stage_id, f"{full_prefix}", name, mod) - for stage_id, (mod_name, stage_mod) in enumerate(top_mod.named_children()): - _traverse_children(stage_id, "", mod_name, stage_mod) + if is_pipeline_partitioned: + for stage_id, (mod_name, stage_mod) in enumerate(top_mod.named_children()): + _traverse_children(stage_id, "", mod_name, stage_mod) + else: + _traverse_children(0, "", "", top_mod) # Explicitly delete the reference to parameters for safety. del params - return list(tie_groups.values()) + # Remove module name. + if is_pipeline_partitioned: + ret = {} + for param, tie_group_set in tie_groups.items(): + group = [(name[name.find(".") + 1 :], sid) for name, sid in tie_group_set] + ret[param] = group + else: + ret = tie_groups -def analyze_tie_ranks(tie_weight_groups, topology): - tie_ranks = [] - for tie_weight_set in tie_weight_groups: - tie_stage_ranks = [] - for _, stage_id in tie_weight_set: - stage_ranks = topology.filter_match(pipe=stage_id) - tie_stage_ranks.append(stage_ranks) - - num_ranks_same_stage = len(tie_stage_ranks[0]) - num_stages = len(tie_stage_ranks) - group_ranks = [] - for i in range(num_ranks_same_stage): - sub_group_ranks = [] - for j in range(num_stages): - sub_group_ranks.append(tie_stage_ranks[j][i]) - group_ranks.append(sorted(sub_group_ranks)) - tie_ranks.append(group_ranks) - return tie_ranks + return ret def generate_pipeline_partition(sch): @@ -405,9 +398,7 @@ def generate_pipeline_partition(sch): return sch -def generate_pipeline_modules( - sch, target, topology=None, param_dtype=torch.float16, **kwargs -): +def build_pipeline_model(sch, target, **kwargs): # Analyze pipelien module for liveness and arguments. partitioned_mod = sch.mod ( @@ -436,8 +427,7 @@ def generate_pipeline_modules( pipe_engine_fn = get_dialect_cls("pipeline_engine", target) return pipe_engine_fn( + sch.metadata, res_partition, - topology, - param_dtype, **kwargs, ) diff --git a/slapo/schedule.py b/slapo/schedule.py index ce03e87d..be586b6e 100644 --- a/slapo/schedule.py +++ b/slapo/schedule.py @@ -20,10 +20,10 @@ from torch.utils import checkpoint from .logger import get_logger +from .model_dialect import get_dialect_cls from .pipeline import ( analyze_tie_weights, - analyze_tie_ranks, - generate_pipeline_modules, + build_pipeline_model, generate_pipeline_partition, ) from .sharding import ( @@ -72,6 +72,14 @@ class ScheduleMetadata: # 2) Let each primitive derive metadata class. shard: dict[str, Any] = field(default_factory=lambda: DictWithValidation()) + # Tie weight analysis only at the top level module. + # tie_weights is a mapping from parameter object to the same + # parameter object. Note that the value may be changed during + # scheduling (e.g., sharding). + tie_weights: dict[nn.Parameter, nn.Parameter] = field( + default_factory=lambda: OrderedDict() + ) + # A set of paths to the modules that includes pipeline cutting annotations. # Note that we use ordered set to keep the order of the modules. pipeline_cutting_paths: dict[str, Any] = field( @@ -140,10 +148,22 @@ def __init__( self.parent = parent self.child = {} self.metadata = ScheduleMetadata() - # record original shape + + # Record original shapes. for param_name, param in mod.named_parameters(): self.metadata.base_params[param_name] = param.shape + if parent is None: + # Tie weight analysis only at the top level module. + # tie_weights is a mapping from parameter object to the same + # parameter object. Note that the value may be changed during + # scheduling (e.g., sharding). + for param in analyze_tie_weights(mod, False): + self.metadata.tie_weights[param] = param + else: + # Inherit tie_weights from parent. + self.metadata.tie_weights = parent.metadata.tie_weights + self.finalized = False @staticmethod @@ -211,8 +231,27 @@ def _shard(name, tensor): try: param = self.mod.get_parameter(tensor_name) - new_param, sharded_size = _shard(tensor_name, param) - self.mod.register_parameter(tensor_name, nn.Parameter(new_param)) + new_tensor, sharded_size = _shard(tensor_name, param) + if param in self.metadata.tie_weights: + if id(self.metadata.tie_weights[param]) != id(param): + # This parameter is tied to another parameter, and the other + # parameter is already sharded. In this case we directly + # register the sharded parameter to the module to keep them tied. + if new_tensor.shape != self.metadata.tie_weights[param].shape: + raise RuntimeError( + f"Parameter {tensor_name} in {self.path} is tied, " + "but they have different sharded shapes: " + f"{new_tensor.shape} vs " + f"{self.metadata.tie_weights[param].shape}" + ) + new_param = self.metadata.tie_weights[param] + else: + # The first parameter in this tie group is sharded. + new_param = nn.Parameter(new_tensor) + self.metadata.tie_weights[param] = new_param + else: + new_param = nn.Parameter(new_tensor) + self.mod.register_parameter(tensor_name, new_param) except AttributeError: buffer = self.mod.get_buffer(tensor_name) new_buffer, sharded_size = _shard(tensor_name, buffer) @@ -828,13 +867,27 @@ def is_module_list(module): def consolidate_model( sch: Schedule, - topology=None, + target: str, param_init_fn: Optional[Callable[[nn.Module], None]] = None, + **kwargs, ): + """Consolidate the model weights. + FIXME: When pipeline is enabled, this function only supports DeepSpeed + runtime because it relies on DeepSpeed topology. We should use dialects + in this function to make it general applicable. + """ + topology = kwargs.get("topology", None) if dist.is_initialized() and dist.get_world_size() > sch.world_size: - assert ( - topology is not None - ), f"topology={topology} must be given when there are multiple tensor paralel groups or pipeline parallelism is used" + if topology is None: + raise ValueError( + "topology must be given when there are multiple " + "tensor paralel groups or pipeline parallelism is used" + ) + if target != "deepspeed": + raise ValueError( + "Only deepspeed runtime is supported for now when there are multiple " + "tensor paralel groups or pipeline parallelism is used" + ) cnt_meta, cnt_materialized = 0, 0 # Since some parameters are attached to non-leaf modules, we need to @@ -853,8 +906,9 @@ def consolidate_model( global_ranks = [None] if cnt_meta != 0 or cnt_materialized != 0: if dist.is_initialized(): - # tackle with pipeline modules - # even the model does not use meta device, we still need to broadcast the weights to ensure consistency + # Tackle with pipeline modules. + # Even the model does not use meta device, we still need to broadcast + # the weights to ensure consistency global_rank = dist.get_rank() if topology is not None: # 1st DP: devices in the same bracket are in the same TP group @@ -899,7 +953,9 @@ def _init_module(sch: Schedule): sch.mod.reset_parameters() else: raise RuntimeError( - f"Module {sch.name} should have `reset_parameters` or `_init_weights` method or param_init_fn={param_init_fn} needs to be provided in order to support delay initialization" + f"Module {sch.name} should have `reset_parameters` or " + "`_init_weights` method or param_init_fn={param_init_fn} needs " + "to be provided in order to support delay initialization" ) def _consolidate_and_broadcast(sch: Schedule): @@ -975,64 +1031,43 @@ def _consolidate_and_broadcast(sch: Schedule): return sch +def init_target_engine(sch, target, **kwargs): + """Initialize the runtime engine for a specific target framework.""" + init_engine_fn = get_dialect_cls("runtime_engine", target, allow_none=True) + return init_engine_fn( + sch, + **kwargs, + ) + + def build( sch: Schedule, - topology=None, target=None, init_weights: Optional[Union[bool, Callable]] = True, **kwargs, ): - optimizer = None if sch.metadata.pipeline_cutting_paths: # pipeline stages will be wrapped into PipeStageWrapper sch = generate_pipeline_partition(sch) - # Analyzie tie weights before consolidation. - tie_weight_groups = analyze_tie_weights(sch.mod) - tie_ranks = analyze_tie_ranks(tie_weight_groups, topology) - for _, _ in tie_ranks: - pass + # Re-analyzie tie weights before consolidation. + sch.metadata.tie_weights = analyze_tie_weights( + sch.mod, is_pipeline_partitioned=True + ) + print(f"tie_weight_groups: {sch.metadata.tie_weights}") # delay initialization if init_weights: init_weight_fn = init_weights if isinstance(init_weights, Callable) else None - sch = consolidate_model(sch, topology, init_weight_fn) - - if target == "deepspeed": - assert "config" in kwargs - import deepspeed - - if sch.metadata.pipeline_cutting_paths: - # Sanity check - if topology is None: - raise ValueError("Must provide topology for deepspeed pipeline") - if "loss_fn" not in kwargs: - raise ValueError("Must provide loss_fn for deepspeed pipeline") - if ( - "fp16" not in kwargs["config"] - or not kwargs["config"]["fp16"]["enabled"] - ): - param_dtype = torch.float - else: - param_dtype = torch.float16 - - model = generate_pipeline_modules( - sch, - target, - topology, - param_dtype, - tie_weight_groups=tie_weight_groups, - **kwargs, - ) - else: - model = sch.mod + sch = consolidate_model(sch, target, init_weight_fn, **kwargs) - # pylint: disable=unbalanced-tuple-unpacking - model, optimizer, _, _ = deepspeed.initialize( - model=model, - config=kwargs["config"], - model_parameters=[p for p in model.parameters() if p.requires_grad], + if sch.metadata.pipeline_cutting_paths: + # Generate pipeline modules for a particular target. + model = build_pipeline_model( + sch, + target, + **kwargs, ) else: model = sch.mod - return model, optimizer + return init_target_engine(model, target, **kwargs) diff --git a/tests/test_pipeline_partition.py b/tests/test_pipeline_partition.py index bdaa7b03..464af90c 100644 --- a/tests/test_pipeline_partition.py +++ b/tests/test_pipeline_partition.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 """Test pipeline partition related logic.""" +# pylint: disable=duplicate-code import pytest +from mock import MagicMock from torch import nn -from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology import slapo @@ -43,17 +44,33 @@ def forward(self, x): model.stage1.linear.weight = model.stage0.wte.weight model.stage2.linear.weight = model.stage0.wte.weight - tie_weights = slapo.pipeline.analyze_tie_weights(model) + # Analyze tie weights for a normal model. + tie_weights = slapo.pipeline.analyze_tie_weights(model, False) assert len(tie_weights) == 1 - assert len(tie_weights[0]) == 3 - assert ("stage0.wte.weight", 0) in tie_weights[0] - assert ("stage1.linear.weight", 1) in tie_weights[0] - assert ("stage2.linear.weight", 2) in tie_weights[0] + val = list(tie_weights.values())[0] + assert len(val) == 3 + assert ("stage0.wte.weight", 0) in val + assert ("stage1.linear.weight", 0) in val + assert ("stage2.linear.weight", 0) in val + # Analyze tie weights for a pipeline model. In this case, + # the forward in top module only runs each pipeline stage sequentially. + tie_weights = slapo.pipeline.analyze_tie_weights(model, True) -def test_analyze_tie_ranks(): - topology = PipeModelDataParallelTopology(num_pp=2, num_mp=1, num_dp=1) + assert len(tie_weights) == 1 + val = list(tie_weights.values())[0] + assert len(val) == 3 + assert ("wte.weight", 0) in val + assert ("linear.weight", 1) in val + assert ("linear.weight", 2) in val + + +def test_deepspeed_analyze_tie_ranks(): + # Mock deepspeed.runtime.pipe.topology.PipeModelDataParallelTopology + # This mocked topology assumes pp=4, tp=2, dp=1. + topology = MagicMock() + topology.filter_match = lambda pipe: [pipe * 2, pipe * 2 + 1] class Stage0(nn.Module): def __init__(self): @@ -77,23 +94,35 @@ def __init__(self): super().__init__() self.stage0 = Stage0() self.stage1 = StageN() + self.stage2 = StageN() + self.stage3 = StageN() def forward(self, x): - return self.stage1(self.stage0(x)) + return self.stage3(self.stage2(self.stage1(self.stage0(x)))) with slapo.init_empty_weights(): model = Model() # Tie weights - model.stage1.linear.weight = model.stage0.wte.weight - - tie_weights = slapo.pipeline.analyze_tie_weights(model) + model.stage3.linear.weight = model.stage0.wte.weight + model.stage2.linear.weight = model.stage1.linear.weight - tie_ranks = slapo.pipeline.analyze_tie_ranks(tie_weights, topology) + tie_weights = list(slapo.pipeline.analyze_tie_weights(model, True).values()) + tie_ranks, tie_stages = slapo.model_dialect.deepspeed.pipeline.analyze_tie_ranks( + tie_weights, topology + ) - assert len(tie_ranks) == 1 - assert len(tie_ranks[0]) == 1 + # Expected tie_ranks (order may vary): [[[0, 6], [1, 7]], [[2, 4], [3, 5]]] + assert len(tie_ranks) == 2 + assert len(tie_ranks[0]) == 2 assert len(tie_ranks[0][0]) == 2 - assert tie_ranks[0][0] == [0, 1] + assert [[0, 6], [1, 7]] in tie_ranks + assert [[2, 4], [3, 5]] in tie_ranks + + # Expected tie_stages (order should be the same as tie_ranks): [[0, 3], [1, 2]] + assert len(tie_stages) == 2 + assert len(tie_stages[0]) == 2 + assert tie_stages[tie_ranks.index([[0, 6], [1, 7]])] == [0, 3] + assert tie_stages[tie_ranks.index([[2, 4], [3, 5]])] == [1, 2] if __name__ == "__main__": diff --git a/tests/test_shard.py b/tests/test_shard.py index 9eb0c54c..cba611b5 100644 --- a/tests/test_shard.py +++ b/tests/test_shard.py @@ -12,6 +12,7 @@ import torch import torch.distributed as dist +from torch import nn from torch.nn import functional as F from torch.autograd import Variable import slapo @@ -313,5 +314,57 @@ def forward(self, data): verify_grads(model, path_and_grads) +def test_tie_weights(init_dist): + """Test whether the tie weights are preserved after sharding.""" + + class Stage0(nn.Module): + def __init__(self): + super().__init__() + self.wte = nn.Embedding(10, 10) + self.linear = nn.Linear(10, 10, bias=False) + + def forward(self, x): + return self.linear(self.wte(x)) + + class StageN(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10, bias=False) + + def forward(self, x): + return self.linear(x) + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.stage0 = Stage0() + self.stage1 = StageN() + self.stage2 = StageN() + + def forward(self, x): + return self.stage2(self.stage1(self.stage0(x))) + + with slapo.init_empty_weights(): + model = Model() + # Tie weights + model.stage1.linear.weight = model.stage0.wte.weight + model.stage2.linear.weight = model.stage0.wte.weight + + sch = slapo.create_schedule(model) + print(sch.metadata.tie_weights) + + assert id(sch.mod.stage0.wte.weight) == id(sch.mod.stage1.linear.weight) + assert id(sch.mod.stage0.wte.weight) == id(sch.mod.stage2.linear.weight) + + sch["stage0.wte"].shard("weight", axis=0) + sch["stage0.wte"].sync(mode="fwd_post", sync_op_or_fn="all_gather", axis=0) + sch["stage0.wte"].sync(mode="bwd_post", sync_op_or_fn="all_reduce") + sch["stage1.linear"].shard("weight", axis=0) + sch["stage2.linear"].shard("weight", axis=0) + + assert id(sch.mod.stage0.wte.weight) == id(sch.mod.stage1.linear.weight) + assert id(sch.mod.stage0.wte.weight) == id(sch.mod.stage2.linear.weight) + + if __name__ == "__main__": pytest.main([__file__])