From 2f483baa095484cec256c09c3cb4a1720c5f1a23 Mon Sep 17 00:00:00 2001 From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:56:10 -0400 Subject: [PATCH] refactor dtensors (#54) Signed-off-by: Mayank Mishra --- dolomite_engine/distributed/__init__.py | 6 + dolomite_engine/distributed/dtensors.py | 83 +++++++++ dolomite_engine/finetune.py | 5 +- .../hf_models/mixins/dense_TP/main.py | 3 +- .../hf_models/mixins/moe_TP/main.py | 2 +- .../hf_models/modeling_utils_TP/TP.py | 164 ------------------ .../hf_models/modeling_utils_TP/__init__.py | 12 +- .../hf_models/modeling_utils_TP/dropout.py | 3 +- .../modeling_utils_TP/dtensor_module.py | 2 +- .../hf_models/modeling_utils_TP/embedding.py | 8 +- .../hf_models/modeling_utils_TP/linear.py | 66 +++---- .../hf_models/modeling_utils_TP/lm_head.py | 3 +- .../normalization/layernorm/base.py | 14 +- .../normalization/rmsnorm/base.py | 8 +- .../models/moe_dolomite_TP/moe_TP/scatter.py | 23 +-- dolomite_engine/model_wrapper/pretraining.py | 2 +- dolomite_engine/pretrain.py | 5 +- dolomite_engine/train_utils.py | 14 +- tests/hf_models/multi_gpu/dcp/train.yml | 2 +- .../multi_gpu/unsharding/unsharding.py | 5 +- 20 files changed, 159 insertions(+), 271 deletions(-) create mode 100644 dolomite_engine/distributed/dtensors.py diff --git a/dolomite_engine/distributed/__init__.py b/dolomite_engine/distributed/__init__.py index 433d556f..e0b81612 100644 --- a/dolomite_engine/distributed/__init__.py +++ b/dolomite_engine/distributed/__init__.py @@ -25,6 +25,12 @@ from ..enums import FP8Backend from ..gradient_checkpointing import apply_gradient_checkpointing from ..utils import ProcessGroupManager, get_module_class_from_name, log_rank_0, string_to_torch_dtype +from .dtensors import ( + dtensor_to_tensor, + modify_state_dict_to_dtensor_dict, + tensor_to_dtensor, + use_async_tensor_parallel, +) from .fp8 import convert_model_to_transformer_engine diff --git a/dolomite_engine/distributed/dtensors.py b/dolomite_engine/distributed/dtensors.py new file mode 100644 index 00000000..eb1eae4f --- /dev/null +++ b/dolomite_engine/distributed/dtensors.py @@ -0,0 +1,83 @@ +import torch +import torch.distributed +import torch.nn as nn +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.placement_types import Placement +from torch.distributed.device_mesh import DeviceMesh + + +def tensor_to_dtensor( + tensor: torch.Tensor, + device_mesh: DeviceMesh, + current_placement: Placement | list[Placement], + desired_placement: Placement | list[Placement] | None = None, + run_check: bool = False, +) -> DTensor: + if isinstance(tensor, DTensor): + return tensor + + if isinstance(current_placement, Placement): + current_placement = [current_placement] + + dtensor = DTensor.from_local(tensor, device_mesh=device_mesh, run_check=run_check, placements=current_placement) + + if desired_placement is not None: + if isinstance(desired_placement, Placement): + desired_placement = [desired_placement] + + dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False) + + return dtensor + + +def dtensor_to_tensor( + dtensor: DTensor, + device_mesh: DeviceMesh | None = None, + desired_placement: Placement | list[Placement] | None = None, + grad_placement: Placement | list[Placement] | None = None, +) -> torch.Tensor: + if not isinstance(dtensor, DTensor): + return dtensor + + if desired_placement is not None: + if isinstance(desired_placement, Placement): + desired_placement = [desired_placement] + + assert device_mesh is not None + + dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False) + + if grad_placement is not None and isinstance(grad_placement, Placement): + grad_placement = [grad_placement] + + tensor = dtensor.to_local(grad_placements=grad_placement) + + return tensor + + +@torch.no_grad() +def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefix: str, strip_keys: bool) -> dict: + module_state_dict = module.state_dict() + + result = {} + for key, tensor in state_dict.items(): + if key.startswith(prefix): + stripped_key = key.split(prefix)[1] if strip_keys and prefix != "" else key + + param = module_state_dict[stripped_key] + device_mesh = param.device_mesh + placements = param.placements + + if isinstance(tensor, DTensor): + assert tensor.device_mesh == device_mesh + assert tensor.placements == placements + + result[key] = tensor + else: + result[key] = tensor_to_dtensor(tensor, device_mesh=device_mesh, current_placement=placements) + + return result + + +def use_async_tensor_parallel() -> bool: + return torch._inductor.config._micro_pipeline_tp diff --git a/dolomite_engine/finetune.py b/dolomite_engine/finetune.py index 58d879d6..bd6839b7 100644 --- a/dolomite_engine/finetune.py +++ b/dolomite_engine/finetune.py @@ -13,7 +13,7 @@ from .checkpointing import load_checkpoint_for_training, save_checkpoint from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer, log_model_optimizer_container from .data import ResumableDataLoader, custom_iterator, get_dataloader, get_next_batch -from .distributed import wrap_model_container_for_distributed_training +from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training from .enums import DatasetSplit, FP8Backend, Mode, TuningMethod from .model_wrapper import ModelWrapperForFinetuning, get_model_container from .optimization import get_optimizer_container, get_scheduler_container @@ -197,8 +197,7 @@ def evaluate( metrics_tracker = metrics_tracker / num_steps for key in metrics_tracker: - if isinstance(metrics_tracker[key], DTensor): - metrics_tracker[key] = metrics_tracker[key].to_local() + metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) diff --git a/dolomite_engine/hf_models/mixins/dense_TP/main.py b/dolomite_engine/hf_models/mixins/dense_TP/main.py index 2e05a70c..52d4599a 100644 --- a/dolomite_engine/hf_models/mixins/dense_TP/main.py +++ b/dolomite_engine/hf_models/mixins/dense_TP/main.py @@ -9,10 +9,11 @@ from transformers import DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ....distributed import dtensor_to_tensor, tensor_to_dtensor from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible from ...config import CommonConfig from ...enums import PositionEmbeddingType -from ...modeling_utils_TP import LMHead_TP, dtensor_to_tensor, tensor_to_dtensor +from ...modeling_utils_TP import LMHead_TP from ..dense import CausalLMModelMixin from .base import PreTrainedModelMixin_TP diff --git a/dolomite_engine/hf_models/mixins/moe_TP/main.py b/dolomite_engine/hf_models/mixins/moe_TP/main.py index 8f5de698..71d926ed 100644 --- a/dolomite_engine/hf_models/mixins/moe_TP/main.py +++ b/dolomite_engine/hf_models/mixins/moe_TP/main.py @@ -3,7 +3,7 @@ from transformers import DynamicCache from transformers.modeling_outputs import MoeCausalLMOutputWithPast -from ...modeling_utils_TP import dtensor_to_tensor, tensor_to_dtensor +from ....distributed import dtensor_to_tensor, tensor_to_dtensor from ..dense_TP import CausalLMModelMixin_TP from ..moe import CausalLMMoEModelMixin, MoeModelOutputWithPastAndAuxLoss diff --git a/dolomite_engine/hf_models/modeling_utils_TP/TP.py b/dolomite_engine/hf_models/modeling_utils_TP/TP.py index d575f098..f4d05b63 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/TP.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/TP.py @@ -1,10 +1,6 @@ import torch import torch.distributed -import torch.distributed._functional_collectives as funcol -import torch.nn as nn -from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Placement, Replicate, Shard -from torch.distributed.device_mesh import DeviceMesh from ...utils import ProcessGroupManager, divide_if_divisible @@ -41,79 +37,6 @@ def tensor_parallel_split_safetensor_slice(slice, dim: int, start_end: tuple[int return output -def tensor_to_dtensor( - tensor: torch.Tensor, - device_mesh: DeviceMesh, - current_placement: Placement | list[Placement], - desired_placement: Placement | list[Placement] | None = None, - run_check: bool = False, -) -> DTensor: - if isinstance(tensor, DTensor): - return tensor - - if isinstance(current_placement, Placement): - current_placement = [current_placement] - - dtensor = DTensor.from_local(tensor, device_mesh=device_mesh, run_check=run_check, placements=current_placement) - - if desired_placement is not None: - if isinstance(desired_placement, Placement): - desired_placement = [desired_placement] - - dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False) - - return dtensor - - -def dtensor_to_tensor( - dtensor: DTensor, - device_mesh: DeviceMesh | None = None, - desired_placement: Placement | list[Placement] | None = None, - grad_placement: Placement | list[Placement] | None = None, -) -> torch.Tensor: - if not isinstance(dtensor, DTensor): - return dtensor - - if desired_placement is not None: - if isinstance(desired_placement, Placement): - desired_placement = [desired_placement] - - assert device_mesh is not None - - dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False) - - if grad_placement is not None and isinstance(grad_placement, Placement): - grad_placement = [grad_placement] - - tensor = dtensor.to_local(grad_placements=grad_placement) - - return tensor - - -@torch.no_grad() -def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefix: str, strip_keys: bool) -> dict: - module_state_dict = module.state_dict() - - result = {} - for key, tensor in state_dict.items(): - if key.startswith(prefix): - stripped_key = key.split(prefix)[1] if strip_keys and prefix != "" else key - - param = module_state_dict[stripped_key] - device_mesh = param.device_mesh - placements = param.placements - - if isinstance(tensor, DTensor): - assert tensor.device_mesh == device_mesh - assert tensor.placements == placements - - result[key] = tensor - else: - result[key] = tensor_to_dtensor(tensor, device_mesh=device_mesh, current_placement=placements) - - return result - - def get_module_placements(use_padding_free_transformer: bool, sequence_parallel: bool) -> Placement: if sequence_parallel: if use_padding_free_transformer: @@ -124,90 +47,3 @@ def get_module_placements(use_padding_free_transformer: bool, sequence_parallel: placement = Replicate() return placement - - -def _tensor_parallel_all_reduce(x: torch.Tensor) -> torch.Tensor: - if ProcessGroupManager.get_tensor_parallel_world_size() == 1: - return x - - return funcol.all_reduce(x, reduceOp="sum", group=ProcessGroupManager.get_tensor_parallel_group()) - - -def _tensor_parallel_all_gather(x: torch.Tensor, dim: int) -> torch.Tensor: - if ProcessGroupManager.get_tensor_parallel_world_size() == 1: - return x - - return funcol.all_gather_tensor(x, gather_dim=dim, group=ProcessGroupManager.get_tensor_parallel_group()) - - -def _tensor_parallel_reduce_scatter(x: torch.Tensor, dim: int) -> torch.Tensor: - if ProcessGroupManager.get_tensor_parallel_world_size() == 1: - return x - - return funcol.reduce_scatter_tensor( - x, reduceOp="sum", scatter_dim=dim, group=ProcessGroupManager.get_tensor_parallel_group() - ) - - -class _CopyToTensorParallelRegion(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor) -> torch.Tensor: - return x - - @staticmethod - def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor: - return _tensor_parallel_all_reduce(x_grad) - - -class _ReduceFromTensorParallelRegion(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor) -> torch.Tensor: - return _tensor_parallel_all_reduce(x) - - @staticmethod - def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor: - return x_grad - - -class _AllGatherFromSequenceParallelRegion(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, dim: int) -> torch.Tensor: - ctx.dim = dim - return _tensor_parallel_all_gather(x, dim=dim) - - @staticmethod - def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor: - dim = ctx.dim - return _tensor_parallel_reduce_scatter(x_grad, dim=dim), None - - -class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, dim: int) -> torch.Tensor: - ctx.dim = dim - return _tensor_parallel_reduce_scatter(x, dim=dim) - - @staticmethod - def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor: - dim = ctx.dim - return _tensor_parallel_all_gather(x_grad, dim=dim), None - - -def copy_to_tensor_parallel_region(x: torch.Tensor) -> torch.Tensor: - return _CopyToTensorParallelRegion.apply(x) - - -def reduce_from_tensor_parallel_region(x: torch.Tensor) -> torch.Tensor: - return _ReduceFromTensorParallelRegion.apply(x) - - -def all_gather_from_sequence_parallel_region(x: torch.Tensor, dim: int) -> torch.Tensor: - return _AllGatherFromSequenceParallelRegion.apply(x, dim) - - -def reduce_scatter_to_sequence_parallel_region(x: torch.Tensor, dim: int) -> torch.Tensor: - return _ReduceScatterToSequenceParallelRegion.apply(x, dim) - - -def use_async_tensor_parallel() -> bool: - return torch._inductor.config._micro_pipeline_tp diff --git a/dolomite_engine/hf_models/modeling_utils_TP/__init__.py b/dolomite_engine/hf_models/modeling_utils_TP/__init__.py index 2c9cf876..ce605a3f 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/__init__.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/__init__.py @@ -6,14 +6,4 @@ from .lm_head import LMHead_TP from .normalization import get_normalization_function_TP from .position_embedding import Alibi_TP -from .TP import ( - all_gather_from_sequence_parallel_region, - copy_to_tensor_parallel_region, - dtensor_to_tensor, - get_module_placements, - modify_state_dict_to_dtensor_dict, - reduce_from_tensor_parallel_region, - reduce_scatter_to_sequence_parallel_region, - tensor_parallel_split_safetensor_slice, - tensor_to_dtensor, -) +from .TP import get_module_placements, tensor_parallel_split_safetensor_slice diff --git a/dolomite_engine/hf_models/modeling_utils_TP/dropout.py b/dolomite_engine/hf_models/modeling_utils_TP/dropout.py index 43048a2a..7ce3fb77 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/dropout.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/dropout.py @@ -1,8 +1,9 @@ import torch import torch.nn as nn +from ...distributed import dtensor_to_tensor, tensor_to_dtensor from ...utils import ProcessGroupManager -from .TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor +from .TP import get_module_placements class Dropout_TP(nn.Dropout): diff --git a/dolomite_engine/hf_models/modeling_utils_TP/dtensor_module.py b/dolomite_engine/hf_models/modeling_utils_TP/dtensor_module.py index 152b757b..4cfb40f3 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/dtensor_module.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/dtensor_module.py @@ -2,7 +2,7 @@ import torch.nn as nn -from .TP import modify_state_dict_to_dtensor_dict +from ...distributed import modify_state_dict_to_dtensor_dict class DTensorModule(nn.Module): diff --git a/dolomite_engine/hf_models/modeling_utils_TP/embedding.py b/dolomite_engine/hf_models/modeling_utils_TP/embedding.py index 169a3bd8..32da1d26 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/embedding.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/embedding.py @@ -2,13 +2,13 @@ import torch import torch.nn as nn -from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Replicate, Shard +from ...distributed import dtensor_to_tensor, tensor_to_dtensor from ...utils import ProcessGroupManager, divide_if_divisible from ..modeling_utils import ParameterizedEmbedding from .dtensor_module import DTensorModule -from .TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor +from .TP import get_module_placements class Embedding_TP(ParameterizedEmbedding, DTensorModule): @@ -43,8 +43,8 @@ def __init__( placement = Replicate() self.weight = nn.Parameter( - DTensor.from_local( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[placement] + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=placement ) ) diff --git a/dolomite_engine/hf_models/modeling_utils_TP/linear.py b/dolomite_engine/hf_models/modeling_utils_TP/linear.py index 4d3aaed9..258cf055 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/linear.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/linear.py @@ -1,21 +1,13 @@ import torch import torch.distributed import torch.nn as nn -import torch.nn.functional as F -from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Partial, Replicate, Shard +from ...distributed import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel from ...utils import ProcessGroupManager, divide_if_divisible from ..modeling_utils import ParameterizedLinear from .dtensor_module import DTensorModule -from .TP import ( - all_gather_from_sequence_parallel_region, - copy_to_tensor_parallel_region, - dtensor_to_tensor, - get_module_placements, - tensor_to_dtensor, - use_async_tensor_parallel, -) +from .TP import get_module_placements class ReplicatedLinear(ParameterizedLinear, DTensorModule): @@ -35,14 +27,16 @@ def __init__( self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() self.weight = nn.Parameter( - DTensor.from_local( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()] + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() ) ) if bias: self.bias = nn.Parameter( - DTensor.from_local( - self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()] + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), ) ) @@ -91,42 +85,28 @@ def __init__( ) self.weight = nn.Parameter( - DTensor.from_local( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Shard(0)] + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) ) ) if bias: self.bias = nn.Parameter( - DTensor.from_local( - self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Shard(0)] + tensor_to_dtensor( + self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) ) ) self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - self.use_padding_free_transformer = use_padding_free_transformer - self.sequence_parallel = sequence_parallel - self.use_async_tensor_parallel = use_async_tensor_parallel() - - if self.use_async_tensor_parallel: + if use_async_tensor_parallel(): self.compile() def forward(self, input: torch.Tensor) -> torch.Tensor: - # FIXME dtensor redistribute uses alltoall for large number of GPUs - if self.use_async_tensor_parallel: - if self.sequence_parallel: - input = all_gather_from_sequence_parallel_region( - input, dim=0 if self.use_padding_free_transformer else 1 - ) - else: - input = copy_to_tensor_parallel_region(input) - - input = F.linear(input, self.weight.to_local(), None if self.bias is None else self.bias.to_local()) - else: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=self.input_placement) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) - + input = tensor_to_dtensor( + input, device_mesh=self.tp_mesh, current_placement=self.input_placement, desired_placement=Replicate() + ) + input = super().forward(input) + input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) return input def extra_repr(self) -> str: @@ -166,14 +146,16 @@ def __init__( ) self.weight = nn.Parameter( - DTensor.from_local( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Shard(1)] + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) ) ) if bias: self.bias = nn.Parameter( - DTensor.from_local( - self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()] + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), ) ) diff --git a/dolomite_engine/hf_models/modeling_utils_TP/lm_head.py b/dolomite_engine/hf_models/modeling_utils_TP/lm_head.py index 93c0b6ed..d5ea50d4 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/lm_head.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/lm_head.py @@ -3,8 +3,9 @@ from torch.distributed._tensor.placement_types import Replicate, Shard from torch.distributed.device_mesh import DeviceMesh +from ...distributed import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel from .embedding import Embedding_TP -from .TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor, use_async_tensor_parallel +from .TP import get_module_placements class LMHead_TP(Embedding_TP): diff --git a/dolomite_engine/hf_models/modeling_utils_TP/normalization/layernorm/base.py b/dolomite_engine/hf_models/modeling_utils_TP/normalization/layernorm/base.py index eee40a68..3444d99c 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/normalization/layernorm/base.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/normalization/layernorm/base.py @@ -1,11 +1,11 @@ import torch import torch.nn as nn -from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Replicate +from .....distributed import dtensor_to_tensor, tensor_to_dtensor from .....utils import ProcessGroupManager from ...dtensor_module import DTensorModule -from ...TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor +from ...TP import get_module_placements class LayerNorm_TP(nn.LayerNorm, DTensorModule): @@ -21,13 +21,15 @@ def __init__( self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() self.weight = nn.Parameter( - DTensor.from_local( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()] + tensor_to_dtensor( + self.weight, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), ) ) self.bias = nn.Parameter( - DTensor.from_local( - self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()] + tensor_to_dtensor( + self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() ) ) diff --git a/dolomite_engine/hf_models/modeling_utils_TP/normalization/rmsnorm/base.py b/dolomite_engine/hf_models/modeling_utils_TP/normalization/rmsnorm/base.py index c3f43fac..3c90892f 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/normalization/rmsnorm/base.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/normalization/rmsnorm/base.py @@ -1,11 +1,11 @@ import torch import torch.nn as nn -from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Replicate +from .....distributed import dtensor_to_tensor, tensor_to_dtensor from .....utils import ProcessGroupManager from ...dtensor_module import DTensorModule -from ...TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor +from ...TP import get_module_placements class RMSNorm_TP(nn.RMSNorm, DTensorModule): @@ -21,8 +21,8 @@ def __init__( self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() self.weight = nn.Parameter( - DTensor.from_local( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()] + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() ) ) diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py index af80888b..352f2104 100644 --- a/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py @@ -7,10 +7,11 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Partial, Replicate, Shard +from .....distributed import dtensor_to_tensor, tensor_to_dtensor from .....utils import ProcessGroupManager, divide_if_divisible, is_kernel_hyperdrive_available from ....enums import InitMethod from ....modeling_utils import ParameterizedTransposedLinear, get_activation_function, is_glu -from ....modeling_utils_TP import Dropout_TP, DTensorModule, dtensor_to_tensor, tensor_to_dtensor +from ....modeling_utils_TP import Dropout_TP, DTensorModule from ...moe_dolomite import MoEDolomiteConfig from ...moe_dolomite.moe import ScatterMoE from ...moe_dolomite.moe.scatter import ParameterizedScatteredExperts @@ -35,8 +36,8 @@ def __init__( ) self.weight = nn.Parameter( - DTensor.from_local( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()] + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() ) ) @@ -71,11 +72,8 @@ def __init__( ) self.weight = nn.Parameter( - DTensor.from_local( - self.weight, - device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), - placements=[Shard(0)], - run_check=False, + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) ) ) @@ -92,7 +90,7 @@ def forward( ) -> torch.Tensor: return scattered_experts( inputs=input, - expert_weights=self.weight.to_local().permute(1, 2, 0), + expert_weights=dtensor_to_tensor(self.weight).permute(1, 2, 0), k=k, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, @@ -134,11 +132,8 @@ def __init__( ) self.weight = nn.Parameter( - DTensor.from_local( - self.weight, - device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), - placements=[Shard(-1)], - run_check=False, + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(-1) ) ) diff --git a/dolomite_engine/model_wrapper/pretraining.py b/dolomite_engine/model_wrapper/pretraining.py index fb997b75..5cd7a2b0 100644 --- a/dolomite_engine/model_wrapper/pretraining.py +++ b/dolomite_engine/model_wrapper/pretraining.py @@ -7,8 +7,8 @@ from torch.distributed.tensor.parallel import loss_parallel from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM +from ..distributed import tensor_to_dtensor from ..enums import AttentionImplementation, Mode, MoEImplementation -from ..hf_models.modeling_utils_TP import tensor_to_dtensor from ..utils import ProcessGroupManager from .base import ModelWrapper diff --git a/dolomite_engine/pretrain.py b/dolomite_engine/pretrain.py index 450b6b9b..cb275160 100644 --- a/dolomite_engine/pretrain.py +++ b/dolomite_engine/pretrain.py @@ -15,7 +15,7 @@ from .communication import Communication from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer, log_model_optimizer_container from .data import get_megatron_gpt_dataloaders, get_next_batch -from .distributed import wrap_model_container_for_distributed_training +from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training from .enums import FP8Backend, Mode, TuningMethod from .model_wrapper import get_model_container from .optimization import get_optimizer_container, get_scheduler_container @@ -296,8 +296,7 @@ def evaluate( metrics_tracker = metrics_tracker / eval_steps for key in metrics_tracker: - if isinstance(metrics_tracker[key], DTensor): - metrics_tracker[key] = metrics_tracker[key].to_local() + metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) diff --git a/dolomite_engine/train_utils.py b/dolomite_engine/train_utils.py index e78b2854..26ade6ad 100644 --- a/dolomite_engine/train_utils.py +++ b/dolomite_engine/train_utils.py @@ -10,6 +10,7 @@ from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer from .data import ResumableDataLoader, get_next_batch +from .distributed import dtensor_to_tensor from .enums import GradientCheckpointingMethod from .hf_models import is_custom_model from .hf_models.modeling_utils import is_glu @@ -153,12 +154,7 @@ def _train_step_with_pipeline_parallel( metrics_tracker = MetricsTrackingDict({}) with torch.inference_mode(): - grad_norm = sum(grad_norm) - if not isinstance(grad_norm, torch.Tensor): - grad_norm = torch.tensor(grad_norm, device=torch.cuda.current_device()) - elif isinstance(grad_norm, DTensor): - grad_norm = grad_norm.to_local() - + grad_norm = dtensor_to_tensor(sum(grad_norm)) torch.distributed.all_reduce(grad_norm, group=ProcessGroupManager.get_pipeline_parallel_group()) if is_last_pipeline_rank: @@ -170,8 +166,7 @@ def _train_step_with_pipeline_parallel( metrics_tracker["grad_norm"] = grad_norm for key in metrics_tracker: - if isinstance(metrics_tracker[key], DTensor): - metrics_tracker[key] = metrics_tracker[key].to_local() + metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) @@ -263,8 +258,7 @@ def _train_step_without_pipeline_parallel( ) for key in metrics_tracker: - if isinstance(metrics_tracker[key], DTensor): - metrics_tracker[key] = metrics_tracker[key].to_local() + metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) diff --git a/tests/hf_models/multi_gpu/dcp/train.yml b/tests/hf_models/multi_gpu/dcp/train.yml index bfa7c4d0..e6ea6021 100644 --- a/tests/hf_models/multi_gpu/dcp/train.yml +++ b/tests/hf_models/multi_gpu/dcp/train.yml @@ -79,5 +79,5 @@ distributed_args: stage: 3 tensor_parallel_world_size: 2 pipeline_parallel_world_size: 2 - num_pipeline_stages: 4 + num_pipeline_stages: 2 pipeline_parallel_schedule: 1F1B diff --git a/tests/hf_models/multi_gpu/unsharding/unsharding.py b/tests/hf_models/multi_gpu/unsharding/unsharding.py index 1898b60a..3a2af711 100644 --- a/tests/hf_models/multi_gpu/unsharding/unsharding.py +++ b/tests/hf_models/multi_gpu/unsharding/unsharding.py @@ -5,6 +5,7 @@ import torch.distributed from torch.distributed._tensor.api import DTensor +from dolomite_engine.distributed import dtensor_to_tensor from dolomite_engine.hf_models import ( AttentionHeadType, GPTDolomiteConfig, @@ -86,9 +87,7 @@ def run_check(fix: bool): config, tp_state_dict_unsharded, ProcessGroupManager.get_tensor_parallel_world_size() ) else: - cpu_state_dict = { - key: value.to_local() if isinstance(value, DTensor) else value for key, value in cpu_state_dict.items() - } + cpu_state_dict = {key: dtensor_to_tensor(value) for key, value in cpu_state_dict.items()} torch.save( cpu_state_dict, os.path.join(args.tmp_path, f"tp-{ProcessGroupManager.get_tensor_parallel_rank()}.pt") )