diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index 4754212c1914..5d44530e7edd 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -1,11 +1,12 @@ from .p2p import PipelineP2PCommunication -from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule +from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler from .stage_manager import PipelineStageManager __all__ = [ "PipelineSchedule", "OneForwardOneBackwardSchedule", "InterleavedSchedule", + "ZeroBubbleVPipeScheduler", "PipelineP2PCommunication", "PipelineStageManager", ] diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py index 6845dc23753b..05dd24e8169e 100644 --- a/colossalai/pipeline/schedule/__init__.py +++ b/colossalai/pipeline/schedule/__init__.py @@ -1,9 +1,11 @@ from .base import PipelineSchedule from .interleaved_pp import InterleavedSchedule from .one_f_one_b import OneForwardOneBackwardSchedule +from .zero_bubble_pp import ZeroBubbleVPipeScheduler __all__ = [ "PipelineSchedule", "OneForwardOneBackwardSchedule", "InterleavedSchedule", + "ZeroBubbleVPipeScheduler", ] diff --git a/colossalai/pipeline/schedule/v_schedule.py b/colossalai/pipeline/schedule/v_schedule.py new file mode 100644 index 000000000000..0d083c610ea4 --- /dev/null +++ b/colossalai/pipeline/schedule/v_schedule.py @@ -0,0 +1,468 @@ +# Refer from Zero Bubble Pipeline Parallelism. +# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism +# Paper: https://arxiv.org/abs/2401.10241 + +from collections import deque +from dataclasses import dataclass + + +@dataclass(eq=True, frozen=True) +class ScheduledNode: + type: str + chunk: int + stage: int + minibatch: int + start_time: int + completion_time: int + rollback: bool = False + + +class PipelineGraph(object): + """PipelineGraph""" + + def __init__( + self, + n_stage, + n_micro, + f_cost, + b_cost, + w_cost, + c_cost, + f_mem, + b_mem, + w_mem, + max_mem=None, + ): + self.n_node = 6 * n_stage * n_micro + self.n_stage = n_stage + self.n_micro = n_micro + self.f_cost = f_cost + self.b_cost = b_cost + self.w_cost = w_cost + self.c_cost = c_cost + self.f_mem = f_mem + self.b_mem = b_mem + self.w_mem = w_mem + self.fbw_cost = [f_cost, b_cost, w_cost] + self.fbw_mem = [f_mem, b_mem, w_mem] + self.max_mem = max_mem or f_mem * self.n_stage * 2 + + def get_id(self, cat, chunk, stage, micro): + return ( + cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro + ) + + def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None): + count = [] + for i in range(self.n_stage): + count.append([0] * 6) + + end_time = [-1] * self.n_node + cur_time = [0] * self.n_stage + mem = [0] * self.n_stage + stage_bubble = [0] * self.n_stage + pending_w = [deque() for _ in range(self.n_stage)] + schedule = [[] for _ in range(self.n_stage)] + stage_str = [" " * i for i in range(self.n_stage)] + + if approved_bubble is None: + approved_bubble = [-1] * self.n_stage + max_approved_bubble = max(approved_bubble) + + def get_max_stage_bubble(stage=-1): + max_stage_bubble = 0 + for bb in stage_bubble: + max_stage_bubble = max(max_stage_bubble, bb) + if stage >= 0: + max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage]) + return max_stage_bubble + + def put_w(stage): + assert len(pending_w[stage]) > 0 + _, chunk_, _ = pending_w[stage].popleft() + put(2, chunk_, stage) + + def put(cat, chunk, stage, assert_cnt=True): + _tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat] + _cnt = count[stage][cat * 2 + chunk] + # assert _cnt < self.n_micro + if _cnt >= self.n_micro: + if not assert_cnt: + stage_str[stage] += " " + cur_time[stage] = _tmp # TODO + return + assert False + assert mem[stage] + self.fbw_mem[cat] <= self.max_mem + stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1))) + if cat > 0 or chunk > 0: + last_id = cat * 2 + chunk - 1 + if cat < 2: + # if end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] < 0: + # print(cat, chunk, stage, _cnt) + # self.print_details(end_time) + assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0 + else: + assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0 + if chunk == 1 and cat < 2: + if stage < self.n_stage - 1: + _fa_id = self.get_id(cat, chunk, stage + 1, _cnt) + assert end_time[_fa_id] >= 0 + _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat]) + if chunk == 0 and cat < 2: + if stage > 0: + _fa_id = self.get_id(cat, chunk, stage - 1, _cnt) + # if end_time[_fa_id] < 0: + # print(cat, chunk, stage, _cnt) + # self.print_details(end_time) + assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}" + _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat]) + _id = self.get_id(cat, chunk, stage, _cnt) + if count[stage][0] > 0: + stage_bubble[stage] += _tmp - _no_bubble + end_time[_id] = _tmp + cur_time[stage] = _tmp + mem[stage] += self.fbw_mem[cat] + # noinspection PyTypeChecker + schedule[stage].append((cat, chunk, _cnt)) + if cat == 1: + pending_w[stage].append((2, chunk, _cnt)) + count[stage][cat * 2 + chunk] += 1 + + # for _ in range(2 * self.n_stage): + # for i in range(self.n_stage): + # if count[i][1] >= count[i][0]: + # put(0, 0, i, assert_cnt=False) + # continue + # if i == self.n_stage - 1: + # put(0, 1, i, assert_cnt=False) + # continue + # fa_id = self.get_id(0, 1, i + 1, count[i][1]) + # if 0 <= end_time[fa_id] < cur_time[i + 1]: # TODO + # put(0, 1, i, assert_cnt=False) + # else: + # put(0, 0, i, assert_cnt=False) + + for i in range(self.n_stage): + put(0, 0, i) + for i in range(self.n_stage - 1, -1, -1): + if i == self.n_stage - 1: + put(0, 1, i) + continue + tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost + while ( + mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem + and cur_time[i] + self.fbw_cost[0] <= tmp + and count[i][0] < self.n_micro + ): + for j in range(i + 1): + put(0, 0, j) + put(0, 1, i) + iter_chunk_ = 0 + end_tmp = 0 + for i in range(self.n_stage): + if i == 0: + end_tmp = cur_time[0] + self.fbw_cost[1] + continue + tmp = end_tmp + self.c_cost + while ( + count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1] + or count[i][1] <= count[i - 1][1] < self.n_micro + ): + for j in range(self.n_stage - 1, i - 1, -1): + if count[j][iter_chunk_] < self.n_micro: + put(0, iter_chunk_, j) + iter_chunk_ = 1 - iter_chunk_ + # while mem[i] + self.fbw_mem[0] <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp: + # if iter_chunk_ == 0 and count[i][0] >= count[i - 1][0]: + # break + # for j in range(self.n_stage - 1, i - 1, -1): + # if count[j][iter_chunk_] < self.n_micro: + # put(0, iter_chunk_, j) + # iter_chunk_ = 1 - iter_chunk_ + # end_tmp = max(tmp, cur_time[i]) + self.fbw_cost[1] + + # init_bubble = get_max_stage_bubble() + # print(stage_bubble) + for _ in range(2 * self.n_micro): + # check mem before putting b + for i in range(self.n_stage): + while mem[i] + self.fbw_mem[1] > self.max_mem: + assert len(pending_w[i]) > 0 + put_w(i) + b0_ranks, b1_ranks = [], [] + for i in range(self.n_stage): + if count[i][3] >= count[i][2]: + b0_ranks.append(i) + elif i == self.n_stage - 1: + b1_ranks.append(i) + else: + fa_id = self.get_id(1, 1, i + 1, count[i][3]) + if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro: + b1_ranks.append(i) + else: + b0_ranks.append(i) + b_ranks = [] + # put b1 + for i in reversed(b1_ranks): + b_ranks.append((i, 1)) + # put b0 + for i in b0_ranks: + b_ranks.append((i, 0)) + for i, _chunk_ in b_ranks: + fa_id = -1 + if _chunk_ == 1 and i < self.n_stage - 1: + fa_id = self.get_id(1, 1, i + 1, count[i][3]) + if _chunk_ == 0 and i > 0: + fa_id = self.get_id(1, 0, i - 1, count[i][2]) + while ( + len(pending_w[i]) > 0 + and fa_id >= 0 + and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2] + ): + # fill the bubble + put_w(i) + if ( + len(pending_w[i]) > 0 + and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i] + ): + if _chunk_ == 1: + put_w(i) + elif fill_b: + put_w(i) + put(1, _chunk_, i) + + # put f + for i in range(self.n_stage): + if count[i][1] >= self.n_micro: + continue + put_item = None + if count[i][1] >= count[i][0]: + put_item = 0 + elif i == self.n_stage - 1: + put_item = 1 + else: + if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0: + put_item = 1 + elif count[i][0] < self.n_micro: + if i == 0: + put_item = 0 + elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0: + put_item = 0 + if put_item is None: + continue + # check mem before putting f + while mem[i] + self.fbw_mem[0] > self.max_mem: + assert len(pending_w[i]) > 0 + put_w(i) + fa_id = -1 + if put_item == 0 and i > 0: + fa_id = self.get_id(0, 0, i - 1, count[i][0]) + if put_item == 1 and i < self.n_stage - 1: + fa_id = self.get_id(0, 1, i + 1, count[i][1]) + while ( + len(pending_w[i]) > 0 + and fa_id >= 0 + and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2] + ): + # fill the bubble + put_w(i) + if ( + len(pending_w[i]) > 0 + and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i] + ): + if fill_f: + put_w(i) + put(0, put_item, i) + + for i in range(self.n_stage): + while len(pending_w[i]) > 0: + put_w(i) + + # for i in range(self.n_stage): + # print(stage_str[i]) + + max_bubble = get_max_stage_bubble() + expected_time = sum(self.fbw_cost) * self.n_micro * 2 + max_bubble / expected_time + # print("%6.4f" % bubble_rate, "->", stage_bubble) + if max_approved_bubble < 0 or max_bubble < max_approved_bubble: + _schedule, _end_time, _max_bubble = self.try_v_schedule( + fill_f=fill_f, + fill_b=fill_b, + approved_bubble=stage_bubble, + ) + if _max_bubble < max_bubble: + return _schedule, _end_time, _max_bubble + # print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f %6.4f" % \ + # (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, init_bubble / expected_time, bubble_rate), max_bubble) + return schedule, end_time, max_bubble + + def print_details(self, end_time, print_scaling=1): + for stage in range(self.n_stage): + stage_str = ["."] * int(max(end_time) / print_scaling) + for _cat in range(3): + for _chunk in range(2): + for _micro in range(self.n_micro): + _id = self.get_id(_cat, _chunk, stage, _micro) + if end_time[_id] < 0: + continue + end = int(end_time[_id] / print_scaling) + start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling) + for j in range(start, end): + if j == start or j == end - 1: + stage_str[j] = "FfBbWw"[_cat * 2 + _chunk] + elif j == start + 1: + if _micro >= 10: + stage_str[j] = str(_micro // 10) + else: + stage_str[j] = str(_micro) + elif j == start + 2 and _micro >= 10: + stage_str[j] = str(_micro % 10) + else: + stage_str[j] = "-" + _str = "" + for _c in stage_str: + _str += _c + print(_str) + + def get_v_schedule(self, only_run_time=False): + schedule, end_time, max_bubble = None, None, None + expected_time = sum(self.fbw_cost) * self.n_micro * 2 + for fill_b in [True, False]: + for fill_f in [True, False]: + _schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f) + # print("") + if max_bubble is None or _max_bubble < max_bubble: + max_bubble = _max_bubble + schedule = _schedule + end_time = _end_time + if only_run_time: + return max_bubble + expected_time + # self.print_details(end_time, print_scaling=1) + max_bubble / (expected_time + max_bubble) + # print("%2d %3d, [%5d %5d %5d %5d], %6d -> %6.4f" % \ + # (self.n_stage, self.n_micro, *self.fbw_cost, self.c_cost, self.max_mem // self.f_mem, bubble_rate)) + local_order = [[] for _ in range(self.n_stage)] + comm_id = {} + comm_id_counter = 0 + post_validation_time = 0 + for i in range(self.n_stage - 1, -1, -1): + pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1) + post_validation_time = max( + post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost + ) + # post_validation_time = 0 + # print(i, pv_id, post_validation_time) + for it in ["RECV_", "SEND_", ""]: + if i == 0 and it == "SEND_": + continue + if i == self.n_stage - 1 and it == "RECV_": + continue + # stage_ = i - 1 if it == "RECV_" else i + stage_ = i + local_order[stage_].append( + ScheduledNode( + type=it + "POST_VALIDATION", + chunk=0, + stage=stage_, + minibatch=0, + start_time=post_validation_time, + completion_time=post_validation_time, + ) + ) + comm_id[local_order[stage_][-1]] = comm_id_counter + comm_id_counter += 1 + for i in range(self.n_stage): + for _cat_, _chunk_, _micro_ in schedule[i]: + complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)] + local_order[i].append( + ScheduledNode( + type="FBW"[_cat_], + chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_, + stage=i, + minibatch=_micro_, + start_time=complete_time - self.fbw_cost[_cat_], + completion_time=complete_time, + ) + ) + if _cat_ == 2: # no communication for W + continue + cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD" + + def communicate(send_recv, stage_): + # noinspection PyTypeChecker + local_order[stage_].append( + ScheduledNode( + type=send_recv + cat_str, + chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_, + stage=stage_, + minibatch=_micro_, + start_time=complete_time, + completion_time=complete_time, + ) + ) + comm_id[local_order[stage_][-1]] = comm_id_counter + + if _chunk_ == 1 and i > 0: + communicate("SEND_", i) + communicate("RECV_", i - 1) + if _chunk_ == 0 and i < self.n_stage - 1: + communicate("SEND_", i) + communicate("RECV_", i + 1) + comm_id_counter += 1 + for rank in range(self.n_stage): + # For nodes with the same timestamp on the same stage, communication will be prioritized. + def even_breaker(x: ScheduledNode): + # Compute nodes are always delayed. + if x.type in ["F", "B", "W"]: + return comm_id_counter + # For comm nodes, order by their unique comm id + return comm_id[x] + + local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x)))) + # If a recv with intersects with previous computation, reorder them so that recv + # is executed before computation and hence can be overlapped. + for i in range(len(local_order[rank])): + if ( + i > 0 + and local_order[rank][i - 1].type in {"F", "B", "W"} + and local_order[rank][i].type.startswith("RECV") + and "POST_VALIDATION" not in local_order[rank][i].type + and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time + ): + local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i] + + local_order_with_rollback = [[] for _ in range(self.n_stage)] + for rank in range(self.n_stage): + rollback_comm = set() + if rank > 0: + for node in local_order[rank - 1]: + if node.type == "POST_VALIDATION": + break + if node.type == "SEND_FORWARD": + assert node.chunk == 0 + rollback_comm.add(node.minibatch) + for node in local_order[rank]: + if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm: + rollback = True + rollback_comm.remove(node.minibatch) + else: + rollback = False + local_order_with_rollback[rank].append( + ScheduledNode( + type=node.type, + chunk=node.chunk, + stage=node.stage, + minibatch=node.minibatch, + start_time=node.start_time, + completion_time=node.completion_time, + rollback=rollback, + ) + ) + assert len(rollback_comm) == 0 + for node in local_order_with_rollback[rank]: + print(f"Rank {rank} Node info {node}") + print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ") + print() + + return local_order_with_rollback diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py new file mode 100644 index 000000000000..0cf9bf67a0a8 --- /dev/null +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -0,0 +1,615 @@ +from functools import partial +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union + +import torch +import torch.cuda +import torch.distributed +from torch.nn import Module, ModuleList +from torch.utils._pytree import tree_map + +from colossalai.accelerator import get_accelerator +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.schedule.v_schedule import ScheduledNode +from colossalai.pipeline.stage_manager import PipelineStageManager + +from ._utils import detach, get_batch_size, get_micro_batch, retain_grad, to_device +from .base import PipelineSchedule + +AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} + + +def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: + if wait_handles is not None: + for req in wait_handles: + req.wait() + + +class ZeroBubbleVPipeScheduler(PipelineSchedule): + def __init__( + self, + stage_manager: PipelineStageManager, + schedule: List[ScheduledNode], + num_model_chunks: int, + num_microbatch: Optional[int] = None, + microbatch_size: Optional[int] = None, + enable_metadata_cache: bool = True, + overlap_p2p: bool = True, + ): + super().__init__(stage_manager) + self.num_microbatch = num_microbatch + self.collect_non_loss_data = None + self.forward_only = None + + self.schedules = schedule + self.it = 0 # curr iteration + self.do_post_validation = False + self.is_first_run = True + self.optimizer = None + self.num_model_chunks = num_model_chunks + + # P2PMeta cache + # self.enable_metadata_cache = enable_metadata_cache + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None + + # P2P communication + self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) + + # init buffer + self._free_buffers() + + def _free_buffers(self): + # free local buffer + # two dim array, first dim is the model chunk, second dim is the microbatch queue + self.input_tensors = [[], []] + self.output_tensors = [[], []] + self.send_forward_buffer = [[], []] + self.recv_forward_buffer = [[], []] + self.send_backward_buffer = [[], []] + self.recv_backward_buffer = [[], []] + self.forward_data_store = [] + self.local_send_forward_buffer = [] + self.local_send_backward_buffer = [] + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + self.batch = batch + self.batch_size = get_batch_size(batch) + + if self.microbatch_size is None: + assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" + self.microbatch_size = self.batch_size // self.num_microbatch + if self.num_microbatch is None: + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" + self.num_microbatch = self.batch_size // self.microbatch_size + + if not self.forward_only: + assert self.last_batch_size is None or self.last_batch_size == self.batch_size + assert self.batch_size == self.microbatch_size * self.num_microbatch + + assert ( + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + + if self.forward_only: + self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 + # NOTE: disable metadata cache when batch size changes (not valid anymore) + # if self.batch_size != self.last_batch_size: + # self.enable_metadata_cache = False + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None + + self.last_batch_size = self.batch_size + + def load_micro_batch(self, model_chunk_id: int) -> Any: + """Load a micro batch from the current batch. + + Args: + microbatch_id (int): the current model chunk idx. + + Returns: + Any: Micro batch. + """ + assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" + micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) + + def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: + """Helper method to get the model chunk ID given the iteration number. + + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + assert ( + microbatch_id < self.num_microbatch * self.num_model_chunks + ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})" + microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not is_forward: + # Reverse order + model_chunk_id = self.num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + Any: The wait handles for the communication. + """ + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + ################ + # chunk = 0 & is_first_stage + # do nothing; cause u are chunk 0 in first rank, u have no prev rank; + ################# + if self.stage_manager.is_first_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 0 & not is_first_stage + # Recv y from PREV_rank as input + ################# + else: + prev_rank = self.stage_manager.get_prev_rank() + input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank) + # metadata_recv=self.tensor_metadata_recv + # if self.enable_metadata_cache and self.tensor_metadata_recv is None: + # self.tensor_metadata_recv = create_send_metadata(input_tensor) + return input_tensor, wait_handles + + else: + ################ + # chunk = 1 & is_last_stage + # get y from local_send_forward_buffer as input + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + input_tensor = self.local_send_forward_buffer.pop(0) + + # if self.enable_metadata_cache and self.tensor_metadata_recv is None: + # self.tensor_metadata_recv = create_send_metadata(input_tensor) + + return input_tensor, [] + + ################ + # chunk = 1 & not is_last_stage + # recv y from NEXT_rank as input + ################ + else: + next_rank = self.stage_manager.get_next_rank() + input_tensor, wait_handles = self.comm.recv_forward(next_rank) + + # metadata_recv=self.tensor_metadata_recv + # if self.enable_metadata_cache and self.tensor_metadata_recv is None: + # self.tensor_metadata_recv = create_send_metadata(input_tensor) + + return input_tensor, wait_handles + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + Any: The wait handles for the communication. + """ + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + # bwd chunk0 is right V; + ################ + # chunk = 0 & is_last_stage + # get dy from local recv_bwd_buffer + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + output_tensor_grad = self.local_send_backward_buffer.pop(0) + # if self.enable_metadata_cache and self.grad_metadata_recv is None: + # self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + return output_tensor_grad, [] + + ################ + # chunk = 0 & not is_last_stage + # Recv bwd from next stage; + ################ + else: + next_rank = self.stage_manager.get_next_rank() + output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank) + # metadata_recv=self.grad_metadata_recv + # if self.enable_metadata_cache and self.grad_metadata_recv is None: + # self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + return output_tensor_grad, wait_handles + + else: + # bwd chunk1 is left V; + ################ + # chunk = 1 & is_first_stage + # do nothing; get loss from local + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 1 & not is_first_stage + # self.comm.recv_backward recv bwd from prev stage; + ################ + else: + + prev_rank = self.stage_manager.get_prev_rank() + output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank) + + # metadata_recv=self.grad_metadata_recv + # if self.enable_metadata_cache and self.grad_metadata_recv is None: + # self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + return output_tensor_grad, wait_handles + + def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List: + """Sends the input tensor to the next stage in pipeline. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + + Returns: + Any: The wait handles for the communication. + """ + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + ################ + # chunk = 0 && is_last_stage + # hold y on local_send_forward_buffer + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_forward_buffer.append(output_tensor) + return [] + + ################ + # chunk = 0 && not is_last_stage + # self.comm.send_forward send y to NEXT stage + ################ + else: + next_rank = self.stage_manager.get_next_rank() + send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank) + # send_metadata=self.send_tensor_metadata + # self.send_tensor_metadata = not self.enable_metadata_cache + return send_handles + + else: + ################ + # chunk = 1 && is_first_stage + # do nothing; cause you are the last chunk on last stage; + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 1 && not is_first_stage + # self.comm.send_forward send y to PREV stage + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + send_handles = self.comm.send_forward(output_tensor, prev_rank) + # send_metadata=self.send_tensor_metadata + # self.send_tensor_metadata = not self.enable_metadata_cache + return send_handles + + def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List: + """Sends the gradient tensor to the previous stage in pipeline. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + + Returns: + Any: The wait handles for the communication. + """ + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + # bwd chunk0 is right V; + ################ + # chunk = 0 && is_first_stage + # do nothing; cause u are the first chunk in first stage; bwd end + # send input_tensor_grad to local buffer; + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 0 && not is_first_stage + # Send dx to PREV stage; + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + send_handles = self.comm.send_backward(input_tensor_grad, prev_rank) + # send_metadata=self.send_grad_metadata + return send_handles + + # bwd chunk1 is left V; + else: + ################ + # chunk = 1 && is_last_stage + # hold dy to local_send_bwd_buffer; + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_backward_buffer.append(input_tensor_grad) + return [] + + ################ + # chunk = 1 && not is_last_stage + # Send dx to NEXT stage; + ################ + else: + next_rank = self.stage_manager.get_next_rank() + # print(f"send bwd input_tensor_grad {input_tensor_grad}") + send_handles = self.comm.send_backward(input_tensor_grad, next_rank) + # send_metadata=self.send_grad_metadata + return send_handles + + def forward_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + Args: + model (ModuleList or Module): Model Chunk to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + # Load input ids, attention mask and labels + # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + + # for the first stage, input_obj is None + # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc. + # Only attention_mask from micro_batch is used + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + output_obj = model_chunk[model_chunk_id](input_obj) + # last layer in model + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + loss = criterion(output_obj) / self.num_microbatch + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_b_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + # optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: + """Backward one step of the pipeline + + Args: + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None. + output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor). + output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None. + + Returns: + Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None. + """ + # calculate bwd b step ; only dx = w*dy; + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + + if model_chunk_id == 0: + # bwd step + torch.autograd.backward( + tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True + ) + else: + if self.stage_manager.is_first_stage(ignore_chunk=True): + # loss backward; output_obj is loss + torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True) + else: + # commom bwd step + # print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}") + # BUG:output_obj_grad is None + torch.autograd.backward( + tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True + ) + + return input_obj.grad + + def backward_w_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + # optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ): + # calculate bwd w step ; only dw = x*dy; + if model_chunk_id == 0: + torch.autograd.backward( + tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()) + ) + + else: + if self.stage_manager.is_first_stage(ignore_chunk=True): + torch.autograd.backward(output_obj_grad, inputs=list(model=model_chunk[model_chunk_id].parameters())) + + else: + torch.autograd.backward( + tensors=output_obj, + grad_tensors=output_obj_grad, + inputs=list(model_chunk[model_chunk_id].parameters()), + ) + + def schedule_f( + self, + scheduled_node, + model_chunk: torch.nn.ModuleList, + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ): + # Step1: recv fwd + if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): + # first layer + input_obj = input_obj + else: + # other layer + input_obj, wait_handles = self.recv_forward(model_chunk_id) + # print(f"recv input_obj {input_obj}") + _wait_p2p(wait_handles) + # Step2: fwd step + output_obj = self.forward_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + input_obj=input_obj, + criterion=criterion, + accum_loss=accum_loss, + outputs=outputs, + ) + # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}") + + # add input and output object for backward + self.input_tensors[model_chunk_id].append(input_obj) + self.output_tensors[model_chunk_id].append(output_obj) + + # Step3: send fwd + send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj) + + def schedule_b( + self, + scheduled_node, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + # optimizer: OptimizerWrapper, + # input_obj: Optional[dict], + # output_obj: Union[dict, torch.Tensor], + # output_obj_grad: Optional[dict], + ): + # Step1: recv bwd + # not first stage and chunk 1 + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + output_tensor_grad, recv_bwd_handles = None, [] + # print(f"recv output_tensor_grad {output_tensor_grad}") + else: + output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id) + # print(f"recv output_tensor_grad {output_tensor_grad}") + + # get input and output object from buffer + input_obj = self.input_tensors[model_chunk_id].pop() + output_obj = self.output_tensors[model_chunk_id].pop() + + _wait_p2p(recv_bwd_handles) + # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}") + # Step2: bwd step + input_object_grad = self.backward_b_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + # optimizer: OptimizerWrapper, + input_obj=input_obj, + output_obj=output_obj, + output_obj_grad=output_tensor_grad, + ) + print(f"input_object_grad {input_object_grad}") + + # Step3: send bwd + send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad) + + def schedule_w( + self, + scheduled_node, + non_w_pending, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + # optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ): + self.backward_w_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + # optimizer: OptimizerWrapper, + input_obj=input_obj, + output_obj=output_obj, + output_obj_grad=output_obj_grad, + ) + + def run_forward_backward( + self, + model_chunk: Union[ModuleList, Module], + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ): + it = self.it + # while we still have schedules_node in self.schedules + while it < len(self.schedules): + scheduled_node = self.schedules[it] + if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: + # communication + if scheduled_node.type == "RECV_FORWARD": + self.recv_forward() + elif scheduled_node.type == "RECV_BACKWARD": + self.recv_backward() + elif scheduled_node.type == "SEND_FORWARD": + self.send_forward() + elif scheduled_node.type == "SEND_BACKWARD": + self.send_backward() + elif scheduled_node.type == "F": + self.schedule_f() + elif scheduled_node.type == "B": + self.schedule_b() + elif scheduled_node.type == "W": + self.schedule_w() diff --git a/tests/test_pipeline/test_schedule/test_dx_dw.py b/tests/test_pipeline/test_schedule/test_dx_dw.py new file mode 100644 index 000000000000..6da1434d83e6 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_dx_dw.py @@ -0,0 +1,1200 @@ +import gc +from copy import deepcopy +from typing import Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + +IN_DIM = 8192 +OUT_DIM = 8192 +NUM_LAYER = 3 + + +class MlpModel(nn.Module): + def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER): + super().__init__() + self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +# Step1: dx = w*dy +def backward_b(loss, x, model): + print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB") + # print(f"Before x grad {x.grad}") + # for name, param in model.named_parameters(): + # print(f"Before bwd b \n param {param}\n param gard {param.grad}\n") + + torch.autograd.backward(loss, inputs=x, retain_graph=True) + + # for name, param in model.named_parameters(): + # print(f"After bwd b \n param {param}\n param gard {param.grad}\n") + + # print(f"After x grad {x.grad}") + print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + +# Step1: dx = w*dy; for layer not last +def backward_b_not_last(tensors, grad, x, model): + print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB") + torch.autograd.backward(tensors=tensors, grad_tensors=grad, inputs=x, retain_graph=True) + print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + +def backward_w(loss, model): + print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + # for name, param in model.named_parameters(): + # print(f"Before bwd w \n param {param}\n param gard {param.grad}\n") + + torch.autograd.backward(loss, inputs=list(model.parameters())) + + # for name, param in model.named_parameters(): + # print(f"After bwd w \n param {param}\n param gard {param.grad}\n") + + print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + +# Step2: dummy dw = x*dy +def backward_w_not_last(tensors, grad, model): + print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + torch.autograd.backward(tensors=tensors, grad_tensors=grad, inputs=list(model.parameters())) + print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + +def test_dx_dw_split(): + device = "cuda:0" + model = nn.Linear(8, 8, bias=None).to(device=device) + print(f"model numel {get_model_numel(model)}") # 4GB + x = torch.rand(8, 8).to(device=device) + ref_model = deepcopy(model) + ref_x = x.clone() + + # first step + x.requires_grad_() + loss = model(x).sum() + backward_b(loss, x, model) + for p in model.parameters(): + assert p.grad is None + assert x.grad is not None + backward_w(loss, model) + for p in model.parameters(): + assert p.grad is not None + + # # second step + # loss = model(x).sum() + # backward_b(loss, x, model) + # backward_w(loss, model) + + ref_x.requires_grad_() + ref_loss = ref_model(ref_x).sum() + ref_loss.backward() + + assert torch.equal(x.grad, ref_x.grad) + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + assert torch.equal(p1.grad, p2.grad) + + +def test_double_dx_dw_split_nsync(): + device = "cuda:0" + model = nn.Linear(8, 8, bias=None).to(device=device) + # print(f"model numel {get_model_numel(model)}") # 4GB + x1 = torch.rand(8, 8).to(device=device) + x2 = torch.rand(8, 8).to(device=device) + ref_model = deepcopy(model) + ref_x1 = x1.clone() + ref_x2 = x2.clone() + + # first step + x1.requires_grad_() + x2.requires_grad_() + ref_x1.requires_grad_() + ref_x2.requires_grad_() + + # loss for dx_dw bwd + loss1 = model(x1).sum() + loss2 = model(x2).sum() + + # loss for common bwd + ref_loss1 = ref_model(ref_x1).sum() + ref_loss2 = ref_model(ref_x2).sum() + + # dx1 + backward_b(loss1, x1, model) + for p in model.parameters(): + assert p.grad is None + assert x1.grad is not None + + # dx2 + backward_b(loss2, x2, model) + + # dw1 + backward_w(loss1, model) + for p in model.parameters(): + assert p.grad is not None + + # common bwd 1 + ref_loss1.backward() + + # assert dx1 & dw1 == bwd 1 + assert_close(x1.grad, ref_x1.grad) + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + # dw2 + backward_w(loss2, model) + + # common bwd 2 + ref_loss2.backward() + + # assert dx2 & dw2 == bwd 2 + assert_close(x2.grad, ref_x2.grad) + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + +def test_double_dx_dw_split_sync(): + device = "cuda:0" + model = nn.Linear(8, 8, bias=None).to(device=device) + # print(f"model numel {get_model_numel(model)}") # 4GB + x1 = torch.rand(8, 8).to(device=device) + x2 = torch.rand(8, 8).to(device=device) + + # x1 = torch.ones(8, 8).to(device=device) + # x2 = torch.ones(8, 8).to(device=device) + + ref_model = deepcopy(model) + ref_x1 = x1.clone() + ref_x2 = x2.clone() + + x1.requires_grad_() + x2.requires_grad_() + ref_x1.requires_grad_() + ref_x2.requires_grad_() + + ############ + # step1: + ############ + print(f"Step1\n") + + # loss1 + loss1 = model(x1).sum() + + # ref_loss1 + ref_loss1 = ref_model(ref_x1).sum() + + # dx1 + backward_b(loss1, x1, model) + for p in model.parameters(): + assert p.grad is None + assert x1.grad is not None + + # dw1 + backward_w(loss1, model) + for p in model.parameters(): + assert p.grad is not None + + # common bwd 1 + ref_loss1.backward() + + # assert dx1 & dw1 == bwd 1 + assert_close(x1.grad, ref_x1.grad) + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + ############ + # step2: + ############ + print(f"Step2\n") + + # loss2 + loss2 = model(x2).sum() + + # ref_loss2 + ref_loss2 = ref_model(ref_x2).sum() + + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + # dx2 + backward_b(loss2, x2, model) + + # dw2 + backward_w(loss2, model) + + # common bwd 2 + ref_loss2.backward() + + # assert dx2 & dw2 == bwd 2 + assert_close(x2.grad, ref_x2.grad) + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + +def deallocate_output_tensor(out): + """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. + + This method should be called right after the output tensor has been + sent to the next pipeline stage. At this point, the output tensor is + only useful for its '.grad_fn' field, and not its '.data'. + """ + assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ + assert out._base is None, "counter-productive to free a view of another tensor." + out.data = torch.empty( + (1,), + device=out.device, + dtype=out.dtype, + ) + + +# del loss and x +def mem_dx_dw(): + device = "cuda:0" + # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device) + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + model = MlpModel().to(device=device) + print(f"model numel {get_model_numel(model)}") # 4GB + print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + + x1.requires_grad_() + x2.requires_grad_() + x3.requires_grad_() + print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ############ + # step1: + ############ + print(f"\nStep1") + + # loss1 + print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + loss1 = model(x1).sum() + print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + # dx1 + backward_b(loss1, x1, model) + + # dw1 + backward_w(loss1, model) + + # deallocate_output_tensor(x1) + # deallocate_output_tensor(loss1) + del loss1, x1 + # del x1 + # del y1 + print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ############ + # step2: + ############ + print(f"\nStep2") + + # loss2 + loss2 = model(x2).sum() + + # dx2 + backward_b(loss2, x2, model) + + # dw2 + backward_w(loss2, model) + + # deallocate_output_tensor(x2) + # deallocate_output_tensor(loss2) + del x2, loss2 + # del x2 + # del y2 + print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ############ + # step3: + ############ + print(f"\nStep3") + + # loss3 + loss3 = model(x3).sum() + + # dx2 + backward_b(loss3, x3, model) + + # dw2 + backward_w(loss3, model) + + # deallocate_output_tensor(x3) + # deallocate_output_tensor(loss3) + # del x3 + # del y3 + del x3, loss3 + + print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + param_ids = [id(p) for p in model.parameters()] + for obj in gc.get_objects(): + if torch.is_tensor(obj) and id(obj) not in param_ids: + print(obj) + + +# del activation +def activation_dx_dw(): + device = "cuda:0" + # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device) + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + model = MlpModel().to(device=device) + x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + + x1.requires_grad_() + x2.requires_grad_() + x3.requires_grad_() + print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + # activations = {} + # def register_hooks(module): + # def activation_hook(module, input, output): + # activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach() + # def bwd_hook(module, grad_input, grad_output): + # del activations[f"{module.__class__.__name__}_{id(module)}"] + # module.register_forward_hook(activation_hook) + # module.register_backward_hook(bwd_hook) + + # model.apply(register_hooks) + + ############ + # step1: + ############ + print(f"\nStep1") + + # loss1 + output1 = model(x1) + loss1 = output1.sum() + + # dx1 + backward_b(loss1, x1, model) + + # for name, p in model.named_parameters(): + # print(f"p grad {p.grad}") + + # dw1 + backward_w(loss1, model) + + # for name, p in model.named_parameters(): + # del p.grad + + # del loss1, x1 + del loss1, x1, output1 + print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ############ + # step2: + ############ + print(f"\nStep2") + + # loss2 + output2 = model(x2) + loss2 = output2.sum() + + # dx2 + backward_b(loss2, x2, model) + + # for name, p in model.named_parameters(): + # print(f"p grad {p.grad}") + + # dw2 + backward_w(loss2, model) + + # for name, p in model.named_parameters(): + # print(f"p grad {p.grad}") + + # del x2, loss2 + del x2, loss2, output2 + print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ############ + # step3: + ############ + print(f"\nStep3") + + # loss3 + output3 = model(x3) + loss3 = output3.sum() + + # dx2 + backward_b(loss3, x3, model) + + # dw2 + backward_w(loss3, model) + + # del x3, loss3 + del x3, loss3, output3 + + print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + +def model_chunk_dx_dw(): + device = "cuda:0" + num_layers = 4 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device) + input = torch.rand(4096, 4096, requires_grad=True).to(device=device) + + input_base = input.clone() + + model_base = deepcopy(model) + + ########################## + # Fwd bwd for dx dw + ########################## + + model_chunk_0 = torch.nn.Sequential() # for layer 1 & 2 + model_chunk_1 = torch.nn.Sequential() # for layer 3 & 4 + + for idx, sub_model in enumerate(model.layers): + if idx < 2: + model_chunk_0.append(sub_model) + else: + model_chunk_1.append(sub_model) + + print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Step1:chunk 0 fwd + ########################## + output1 = model_chunk_0(input) + + # detach output1; then output1 for chunk 0, output1_dt for chunk 1; + output1_dt = output1.detach() + output1_dt.requires_grad_() + print(f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Step2:chunk 1 fwd + ########################## + output2 = model_chunk_1(output1_dt) + + print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Step3:chunk 1 bwd b: dx=w*dy & bwd w:dw=x*dy + ########################## + loss = output2.mean() + backward_b(loss, output1_dt, model_chunk_1) + backward_w(loss, model_chunk_1) + + print(f"After chunk1 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Step4:chunk 0 bwd b: dx=w*dy & bwd w:dw=x*dy + ########################## + # dx = w*dy + backward_b_not_last(tensors=output1, grad=output1_dt.grad, x=input, model=model_chunk_0) + backward_w_not_last(tensors=output1, grad=output1_dt.grad, model=model_chunk_0) + + print(f"After chunk0 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Fwd bwd for base + ########################## + + # fwd & bwd + output_base = model_base(input_base) + + loss_base = output_base.mean() + + loss_base.backward() + print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Assert param + ########################## + + assert_close(output2, output_base) + assert_close(output2.grad, output_base.grad) + + for p1, p2 in zip(model.parameters(), model_base.parameters()): + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + del output1, output1_dt, output2, loss, loss_base, output_base + print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + +def model_chunk_dx_dw_communication( + rank: int, + world_size: int, + port: int, +): + # init dist + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + pg_mesh = ProcessGroupMesh(world_size) + stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=2) + rank = dist.get_rank() + comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False) + + print(f"{stage_manager.get_rank()}") + + # init model and input + num_layers = 4 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(rank) + input = torch.rand(4096, 4096, requires_grad=True).to(rank) + + input_base = input.clone() + model_base = deepcopy(model) + + if rank == 0: + model_chunk_0 = torch.nn.Sequential().to(rank) # for layer 1 & 2 on rank0 + for idx, sub_model in enumerate(model.layers): + if idx < 2: + model_chunk_0.append(sub_model) + else: + model_chunk_1 = torch.nn.Sequential().to(rank) # for layer 3 & 4 on rank1 + for idx, sub_model in enumerate(model.layers): + if idx >= 2: + model_chunk_1.append(sub_model) + + print( + f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + ########################## + # Step1:chunk 0 fwd + ########################## + if rank == 0: + output1 = model_chunk_0(input) + # detach output1; then output1 for chunk 0, output1_dt for chunk 1; + # output1_dt_rank0 = output1.detach() + # output1_dt_rank0.requires_grad_() + print( + f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + # send y(output1_dt) to next stage + comm.send_forward(output1, stage_manager.get_next_rank()) + + ########################## + # Step2:chunk 1 fwd + ########################## + if rank == 1: + # recv y(output1_dt) from prev stage + output1_dt_rank1, wait_handles = comm.recv_forward(stage_manager.get_prev_rank()) + output1_dt_rank1.requires_grad_() + output2 = model_chunk_1(output1_dt_rank1) + + print( + f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + ########################## + # Step3:chunk 1 on device_1 bwd b: dx=w*dy & bwd w:dw=x*dy + ########################## + if rank == 1: + loss = output2.mean() + backward_b(loss, output1_dt_rank1, model_chunk_1) + backward_w(loss, model_chunk_1) + + print(f"After chunk1 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + # send bwd output1_dt_rank1 from rank1 to rank 0 + comm.send_backward(output1_dt_rank1.grad, stage_manager.get_prev_rank()) + ########################## + # Step4:chunk 0 on device_0 bwd b: dx=w*dy & bwd w:dw=x*dy + ########################## + + if rank == 0: + # recv bwd output1_dt_rank1 from rank1 to rank 0 + output1_dt_rank0_grad, _ = comm.recv_backward(stage_manager.get_next_rank()) + + backward_b_not_last(tensors=output1, grad=output1_dt_rank0_grad, x=input, model=model_chunk_0) + backward_w_not_last(tensors=output1, grad=output1_dt_rank0_grad, model=model_chunk_0) + + print(f"After chunk0 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base) + loss_base = output_base.mean() + loss_base.backward() + print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Assert param + ########################## + # assert output + if rank == 1: + assert_close(output2, output_base) + assert_close(output2.grad, output_base.grad) + + # assert model param & grad + if rank == 0: + count = 0 + for (chunk_name, chunk_param), (base_name, base_param) in zip( + model_chunk_0.named_parameters(), model_base.named_parameters() + ): + if count < 2: + assert_close(chunk_param, base_param) + assert_close(chunk_param.grad, base_param.grad) + count += 1 + if rank == 1: + count = 0 + for (chunk_name, chunk_param), (base_name, base_param) in zip( + model_chunk_1.named_parameters(), model_base.named_parameters() + ): + if count >= 2: + assert_close(chunk_param, base_param) + assert_close(chunk_param.grad, base_param.grad) + count += 1 + # clean memory + if rank == 0: + del output1, output1_dt_rank0_grad + if rank == 1: + del output2, loss, output1_dt_rank1 + del loss_base, output_base + print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + + +# Return: output, loss +def schedule_f( + stage_manager: PipelineStageManager, + comm: PipelineP2PCommunication, + input: torch.Tensor, + model_chunk: torch.nn.ModuleList, + model_chunk_id: int, +): + # chunk_id == 0 + if model_chunk_id == 0: + # recv fwd from prev + if stage_manager.is_first_stage(ignore_chunk=True): + input = input # get local input + else: + prev_rank = stage_manager.get_prev_rank() + input, wait_handles = comm.recv_forward(prev_rank) + + # fwd step + output = model_chunk[model_chunk_id](input) + + # send fwd to next + if stage_manager.is_last_stage(ignore_chunk=True): + return input, output, None # return local output + else: + next_rank = stage_manager.get_next_rank() + comm.send_forward(output, next_rank) + + # chunk_id == 1 + if model_chunk_id == 1: + # recv fwd from next + if stage_manager.is_last_stage(ignore_chunk=True): + input = input # get local input + else: + next_rank = stage_manager.get_next_rank() + input, wait_handles = comm.recv_forward(next_rank) + + # fwd step + output = model_chunk[model_chunk_id](input) + + # send fwd to prev + if stage_manager.is_first_stage(ignore_chunk=True): + loss = output.mean() + return input, output, loss # return local output + else: + prev_rank = stage_manager.get_prev_rank() + comm.send_forward(output, prev_rank) + return input, output, None + + +def schedule_b( + stage_manager: PipelineStageManager, + comm: PipelineP2PCommunication, + input: torch.Tensor, # x + output: torch.Tensor, # y + output_grad: torch.Tensor, # dy + model_chunk: torch.nn.ModuleList, + model_chunk_id: int, +): + # chunk_id == 0 + if model_chunk_id == 0: + + # recv bwd from next + if stage_manager.is_last_stage(ignore_chunk=True): + output_grad = output_grad # get dy from local + else: + next_rank = stage_manager.get_next_rank() + output_grad, _ = comm.recv_backward(next_rank) + + # bwd step + backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id]) + + backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id]) + + # send bwd to prev + if stage_manager.is_first_stage(ignore_chunk=True): + return input.grad + else: + prev_rank = stage_manager.get_prev_rank() + comm.send_backward(input.grad, prev_rank) + + # chunk_id == 1 + if model_chunk_id == 1: + # recv bwd from prev + if stage_manager.is_first_stage(ignore_chunk=True): + output_grad = output_grad + else: + prev_rank = stage_manager.get_prev_rank() + # print(f"prev_rank {prev_rank} curr rank {stage_manager.get_rank()}") + output_grad, _ = comm.recv_backward(next_rank=prev_rank) + + # bwd step + # print(f"Before input grad {input.grad}") + # for name, param in model_chunk[model_chunk_id].named_parameters(): + # print(f"Before {name} grad {param.grad}") + + if stage_manager.is_first_stage(ignore_chunk=True): + backward_b(loss=output_grad, x=input, model=model_chunk[model_chunk_id]) + backward_w(loss=output_grad, model=model_chunk[model_chunk_id]) + else: + # commom bwd step + # print(f"output_grad {output_grad}") + backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id]) + backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id]) + + # print(f"After input grad {input.grad}") + # for name, param in model_chunk[model_chunk_id].named_parameters(): + # print(f"After {name} grad {param.grad}") + + # send bwd to next + if stage_manager.is_last_stage(ignore_chunk=True): + return input.grad + else: + next_rank = stage_manager.get_next_rank() + comm.send_backward(input.grad, next_rank) + + return input.grad + + +def schedule_w(): + pass + + +def model_chunk_dx_dw_comm_interleaved( + rank: int, + world_size: int, + port: int, +): + # init dist + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + pg_mesh = ProcessGroupMesh(world_size) + stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=world_size) + rank = dist.get_rank() + comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False) + + # init model and input + num_layers = 8 + in_dim = out_dim = 2048 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank) + + input_base = input0.clone() + model_base = deepcopy(model) + + if rank == 0: + # layer 0 & 7 to chunk 0 on rank0 + chunk_0 = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 0 or idx == 7: + chunk_0.append(sub_model) + elif rank == 1: + # layer 1 & 6 to chunk 1 on rank1 + chunk_1 = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 1 or idx == 6: + chunk_1.append(sub_model) + elif rank == 2: + # layer 2 & 5 to chunk 2 on rank2 + chunk_2 = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 2 or idx == 5: + chunk_2.append(sub_model) + else: + # layer 3 & 4 to chunk 3 on rank3 + chunk_3 = torch.nn.Sequential().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 3 or idx == 4: + chunk_3.append(sub_model) + + # # test checkpoint + # check_fn = lambda submodule: isinstance(submodule, (Linear)) + # non_reentrant_wrapper = partial( + # checkpoint_wrapper, + # # checkpoint_impl=CheckpointImpl.NO_REENTRANT, + # checkpoint_impl=CheckpointImpl.REENTRANT, + # ) + # apply_activation_checkpointing( + # model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + # ) + + print( + f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + # set_checkpoint_early_stop(False) + # buffer use to save input and output + + ########################## + # Step1: fwd + ########################## + ###### + # fwd 1->4 + ###### + # chunk 0 id 0 (layer 0) fwd + if rank == 0: + chunk_id = 0 + input0, output0, _ = schedule_f( + stage_manager=stage_manager, + comm=comm, + input=input0, + model_chunk=chunk_0, + model_chunk_id=chunk_id, + ) + print( + f"chunk 0 id 0 (layer 0)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + # chunk 1 id 0 (layer 1) fwd + if rank == 1: + chunk_id = 0 + input1, output1, _ = schedule_f( + stage_manager=stage_manager, + comm=comm, + input=None, + model_chunk=chunk_1, + model_chunk_id=chunk_id, + ) + print( + f"chunk 1 id 0 (layer 1)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + # chunk 2 id 0 (layer 2) fwd + if rank == 2: + chunk_id = 0 + input2, output2, _ = schedule_f( + stage_manager=stage_manager, + comm=comm, + input=None, + model_chunk=chunk_2, + model_chunk_id=chunk_id, + ) + print( + f"chunk 2 id 0 (layer 2)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + # chunk 3 id 0 (layer 3) fwd + if rank == 3: + chunk_id = 0 + input3, output3, _ = schedule_f( + stage_manager=stage_manager, + comm=comm, + input=None, + model_chunk=chunk_3, + model_chunk_id=chunk_id, + ) + print( + f"chunk 3 id 0 (layer 3)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + ###### + # fwd 4->1 + ###### + + if rank == 3: + chunk_id = 1 + input4, output4, _ = schedule_f( + stage_manager=stage_manager, + comm=comm, + input=output3, + model_chunk=chunk_3, + model_chunk_id=chunk_id, + ) + print( + f"chunk 3 id 1 (layer 4)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + if rank == 2: + chunk_id = 1 + input5, output5, _ = schedule_f( + stage_manager=stage_manager, + comm=comm, + input=None, + model_chunk=chunk_2, + model_chunk_id=chunk_id, + ) + print( + f"chunk 2 id 1 (layer 5)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + if rank == 1: + chunk_id = 1 + input6, output6, _ = schedule_f( + stage_manager=stage_manager, + comm=comm, + input=None, + model_chunk=chunk_1, + model_chunk_id=chunk_id, + ) + print( + f"chunk 1 id 1 (layer 6)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + if rank == 0: + chunk_id = 1 + input7, output7, loss = schedule_f( + stage_manager=stage_manager, + comm=comm, + input=None, + model_chunk=chunk_0, + model_chunk_id=chunk_id, + ) + # print(f"fwd output {output7}") + print( + f"chunk 0 id 1 (layer 7)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + ########################## + # Step2: bwd + ########################## + ###### + # bwd rank 4->1 + ###### + # chunk 0 id 1 (layer 7) bwd + if rank == 0: + chunk_id = 1 + input_grad7 = schedule_b( + stage_manager=stage_manager, + comm=comm, + input=input7, # x + output=output7, # y + output_grad=loss, # dy + model_chunk=chunk_0, + model_chunk_id=chunk_id, + ) + + # # chunk 1 id 1 (layer 6) bwd + if rank == 1: + chunk_id = 1 + input_grad6 = schedule_b( + stage_manager=stage_manager, + comm=comm, + input=input6, # x + output=output6, # y + output_grad=None, # dy + model_chunk=chunk_1, + model_chunk_id=chunk_id, + ) + + # chunk 2 id 1 (layer 5) bwd + if rank == 2: + chunk_id = 1 + input_grad5 = schedule_b( + stage_manager=stage_manager, + comm=comm, + input=input5, # x + output=output5, # y + output_grad=None, # dy + model_chunk=chunk_2, + model_chunk_id=chunk_id, + ) + + # chunk 3 id 1 (layer 4) bwd + if rank == 3: + chunk_id = 1 + input_grad4 = schedule_b( + stage_manager=stage_manager, + comm=comm, + input=input4, # x + output=output4, # y + output_grad=None, # dy + model_chunk=chunk_3, + model_chunk_id=chunk_id, + ) + # print(f"input_grad4 {input_grad4}") + + ###### + # bwd rank 1->4 + ###### + + # chunk 3 id 0 (layer 3) bwd + if rank == 3: + chunk_id = 0 + input_grad3 = schedule_b( + stage_manager=stage_manager, + comm=comm, + input=input3, # x + output=output3, # y + output_grad=input_grad4, # dy + model_chunk=chunk_3, + model_chunk_id=chunk_id, + ) + # print(f"input_grad3 {input_grad3}") + + # chunk 2 id 0 (layer 2) bwd + if rank == 2: + chunk_id = 0 + input_grad2 = schedule_b( + stage_manager=stage_manager, + comm=comm, + input=input2, # x + output=output2, # y + output_grad=None, # dy + model_chunk=chunk_2, + model_chunk_id=chunk_id, + ) + # print(f"input_grad2 {input_grad2}") + + # chunk 1 id 0 (layer 1) bwd + if rank == 1: + chunk_id = 0 + input_grad1 = schedule_b( + stage_manager=stage_manager, + comm=comm, + input=input1, # x + output=output1, # y + output_grad=None, # dy + model_chunk=chunk_1, + model_chunk_id=chunk_id, + ) + + # chunk 0 id 0 (layer 0) bwd + if rank == 0: + chunk_id = 0 + input_grad0 = schedule_b( + stage_manager=stage_manager, + comm=comm, + input=input0, # x + output=output0, # y + output_grad=None, # dy + model_chunk=chunk_0, + model_chunk_id=chunk_id, + ) + # print(f"input_grad0 {input_grad0}") + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base) + loss_base = output_base.mean() + loss_base.backward() + print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # Assert close + ########################## + # assert output + if rank == 0: + assert_close(output7, output_base) + + # assert weight + if rank == 0: + # layer 0 + assert_close(chunk_0[0].weight, model_base.layers[0].weight) + assert_close(chunk_0[0].weight.grad, model_base.layers[0].weight.grad) + # layer 7 + assert_close(chunk_0[1].weight, model_base.layers[7].weight) + assert_close(chunk_0[1].weight.grad, model_base.layers[7].weight.grad) + if rank == 1: + # layer 1 + assert_close(chunk_1[0].weight, model_base.layers[1].weight) + assert_close(chunk_1[0].weight.grad, model_base.layers[1].weight.grad) + # layer 6 + assert_close(chunk_1[1].weight, model_base.layers[6].weight) + assert_close(chunk_1[1].weight.grad, model_base.layers[6].weight.grad) + + if rank == 2: + # layer 2 + assert_close(chunk_2[0].weight, model_base.layers[2].weight) + assert_close(chunk_2[0].weight.grad, model_base.layers[2].weight.grad) + # layer 5 + assert_close(chunk_2[1].weight, model_base.layers[5].weight) + assert_close(chunk_2[1].weight.grad, model_base.layers[5].weight.grad) + + if rank == 3: + # layer 3 + assert_close(chunk_3[0].weight, model_base.layers[3].weight) + assert_close(chunk_3[0].weight.grad, model_base.layers[3].weight.grad) + # layer 4 + assert_close(chunk_3[1].weight, model_base.layers[4].weight) + assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad) + + # clean memory + if rank == 0: + del input0, output0, input_grad0, input7, output7, input_grad7, loss + if rank == 1: + del input1, output1, input_grad1, input6, output6, input_grad6 + if rank == 2: + del input2, output2, input_grad2, input5, output5, input_grad5 + if rank == 3: + del input3, output3, input_grad3, input4, output4, input_grad4 + # print(f"After del device: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + + del loss_base, output_base + + print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + + +@rerun_if_address_is_in_use() +def test_dx_dw_dist(): + # spawn( + # model_chunk_dx_dw_communication, + # nprocs=2, + # ) + + spawn( + model_chunk_dx_dw_comm_interleaved, + nprocs=4, + ) + + +if __name__ == "__main__": + # test_dx_dw_split() + # test_double_dx_dw_split_nsync() + # test_double_dx_dw_split_sync() + # mem_dx_dw() + # activation_dx_dw() + # model_chunk_dx_dw() + + test_dx_dw_dist() diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py new file mode 100644 index 000000000000..fbc4df3ac448 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -0,0 +1,341 @@ +from copy import deepcopy +from typing import Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +class MlpModel(nn.Module): + def __init__(self, in_dim, out_dim, num_layers): + super().__init__() + self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +def test_zerobubble_pipeline_base( + rank: int, + world_size: int, + port: int, +): + # init dist + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + pg_mesh = ProcessGroupMesh(world_size) + + stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=world_size) + + scheduler = ZeroBubbleVPipeScheduler( + schedule=[], + stage_manager=stage_manager, + num_model_chunks=world_size, + num_microbatch=1, + overlap_p2p=False, + ) + + rank = dist.get_rank() + + # init model and input + num_layers = 8 + in_dim = out_dim = 2048 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank) + + input0.clone() + deepcopy(model) + + if rank == 0: + # layer 0 & 7 to chunk 0 on rank0 + chunk_0 = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 0 or idx == 7: + chunk_0.append(sub_model) + elif rank == 1: + # layer 1 & 6 to chunk 1 on rank1 + chunk_1 = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 1 or idx == 6: + chunk_1.append(sub_model) + elif rank == 2: + # layer 2 & 5 to chunk 2 on rank2 + chunk_2 = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 2 or idx == 5: + chunk_2.append(sub_model) + else: + # layer 3 & 4 to chunk 3 on rank3 + chunk_3 = torch.nn.Sequential().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 3 or idx == 4: + chunk_3.append(sub_model) + print( + f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + def criterion(x, *args, **kwargs): + return (x * x).mean() + + ########################## + # Step1: fwd + ########################## + ###### + # fwd 1->4 + ###### + # chunk 0 id 0 (layer 0) fwd + if rank == 0: + chunk_id = 0 + scheduler.schedule_f( + scheduled_node=None, + model_chunk=chunk_0, + model_chunk_id=chunk_id, + input_obj=input0, + criterion=criterion, + accum_loss=None, + outputs=None, + ) + print( + f"chunk 0 id 0 (layer 0)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + # chunk 1 id 0 (layer 1) fwd + if rank == 1: + chunk_id = 0 + scheduler.schedule_f( + scheduled_node=None, + model_chunk=chunk_1, + model_chunk_id=chunk_id, + input_obj=None, + criterion=criterion, + accum_loss=None, + outputs=None, + ) + print( + f"chunk 1 id 0 (layer 1)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + # chunk 2 id 0 (layer 2) fwd + if rank == 2: + chunk_id = 0 + scheduler.schedule_f( + scheduled_node=None, + model_chunk=chunk_2, + model_chunk_id=chunk_id, + input_obj=None, + criterion=criterion, + accum_loss=None, + outputs=None, + ) + print( + f"chunk 2 id 0 (layer 2)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + # chunk 3 id 0 (layer 3) fwd + if rank == 3: + chunk_id = 0 + scheduler.schedule_f( + scheduled_node=None, + model_chunk=chunk_3, + model_chunk_id=chunk_id, + input_obj=None, + criterion=criterion, + accum_loss=None, + outputs=None, + ) + print( + f"chunk 3 id 0 (layer 3)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + ###### + # fwd 4->1 + ###### + + if rank == 3: + chunk_id = 1 + scheduler.schedule_f( + scheduled_node=None, + model_chunk=chunk_3, + model_chunk_id=chunk_id, + input_obj=None, + criterion=criterion, + accum_loss=None, + outputs=None, + ) + print( + f"chunk 3 id 1 (layer 4)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + if rank == 2: + chunk_id = 1 + scheduler.schedule_f( + scheduled_node=None, + model_chunk=chunk_2, + model_chunk_id=chunk_id, + input_obj=None, + criterion=criterion, + accum_loss=None, + outputs=None, + ) + print( + f"chunk 2 id 1 (layer 5)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + if rank == 1: + chunk_id = 1 + scheduler.schedule_f( + scheduled_node=None, + model_chunk=chunk_1, + model_chunk_id=chunk_id, + input_obj=None, + criterion=criterion, + accum_loss=None, + outputs=None, + ) + print( + f"chunk 1 id 1 (layer 6)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + if rank == 0: + chunk_id = 1 + scheduler.schedule_f( + scheduled_node=None, + model_chunk=chunk_0, + model_chunk_id=chunk_id, + input_obj=None, + criterion=criterion, + accum_loss=None, + outputs=None, + ) + # print(f"fwd output {output7}") + print( + f"chunk 0 id 1 (layer 7)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + ########################## + # Step2: bwd + ########################## + ###### + # bwd rank 4->1 + ###### + # chunk 0 id 1 (layer 7) bwd + if rank == 0: + chunk_id = 1 + scheduler.schedule_b( + scheduled_node=None, + model_chunk=chunk_0, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + + # # chunk 1 id 1 (layer 6) bwd + if rank == 1: + chunk_id = 1 + scheduler.schedule_b( + scheduled_node=None, + model_chunk=chunk_1, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + + # chunk 2 id 1 (layer 5) bwd + if rank == 2: + chunk_id = 1 + scheduler.schedule_b( + scheduled_node=None, + model_chunk=chunk_2, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + + # chunk 3 id 1 (layer 4) bwd + if rank == 3: + chunk_id = 1 + scheduler.schedule_b( + scheduled_node=None, + model_chunk=chunk_3, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + + # ###### + # # bwd rank 1->4 + # ###### + + # chunk 3 id 0 (layer 3) bwd + if rank == 3: + chunk_id = 0 + scheduler.schedule_b( + scheduled_node=None, + model_chunk=chunk_3, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + # print(f"input_grad3 {input_grad3}") + + # chunk 2 id 0 (layer 2) bwd + if rank == 2: + chunk_id = 0 + scheduler.schedule_b( + scheduled_node=None, + model_chunk=chunk_2, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + # print(f"input_grad2 {input_grad2}") + + # chunk 1 id 0 (layer 1) bwd + if rank == 1: + chunk_id = 0 + scheduler.schedule_b( + scheduled_node=None, + model_chunk=chunk_1, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + + # chunk 0 id 0 (layer 0) bwd + if rank == 0: + chunk_id = 0 + scheduler.schedule_b( + scheduled_node=None, + model_chunk=chunk_0, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + # print(f"input_grad0 {input_grad0}") + + +# @pytest.mark.dist +# @pytest.mark.parametrize("num_microbatch", [4]) +# @pytest.mark.parametrize("batch_size", [4]) +# @pytest.mark.parametrize("num_model_chunk", [2]) +@rerun_if_address_is_in_use() +def test_pp(): + spawn( + test_zerobubble_pipeline_base, + nprocs=4, + ) + + +if __name__ == "__main__": + + test_pp()