Skip to content

Commit

Permalink
refactor dtensors (#54)
Browse files Browse the repository at this point in the history
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 authored Oct 30, 2024
1 parent ee7ed18 commit 214f6f9
Show file tree
Hide file tree
Showing 20 changed files with 159 additions and 271 deletions.
6 changes: 6 additions & 0 deletions dolomite_engine/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
from ..enums import FP8Backend
from ..gradient_checkpointing import apply_gradient_checkpointing
from ..utils import ProcessGroupManager, get_module_class_from_name, log_rank_0, string_to_torch_dtype
from .dtensors import (
dtensor_to_tensor,
modify_state_dict_to_dtensor_dict,
tensor_to_dtensor,
use_async_tensor_parallel,
)
from .fp8 import convert_model_to_transformer_engine


Expand Down
83 changes: 83 additions & 0 deletions dolomite_engine/distributed/dtensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
import torch.distributed
import torch.nn as nn
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.device_mesh import DeviceMesh


def tensor_to_dtensor(
tensor: torch.Tensor,
device_mesh: DeviceMesh,
current_placement: Placement | list[Placement],
desired_placement: Placement | list[Placement] | None = None,
run_check: bool = False,
) -> DTensor:
if isinstance(tensor, DTensor):
return tensor

if isinstance(current_placement, Placement):
current_placement = [current_placement]

dtensor = DTensor.from_local(tensor, device_mesh=device_mesh, run_check=run_check, placements=current_placement)

if desired_placement is not None:
if isinstance(desired_placement, Placement):
desired_placement = [desired_placement]

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False)

return dtensor


def dtensor_to_tensor(
dtensor: DTensor,
device_mesh: DeviceMesh | None = None,
desired_placement: Placement | list[Placement] | None = None,
grad_placement: Placement | list[Placement] | None = None,
) -> torch.Tensor:
if not isinstance(dtensor, DTensor):
return dtensor

if desired_placement is not None:
if isinstance(desired_placement, Placement):
desired_placement = [desired_placement]

assert device_mesh is not None

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False)

if grad_placement is not None and isinstance(grad_placement, Placement):
grad_placement = [grad_placement]

tensor = dtensor.to_local(grad_placements=grad_placement)

return tensor


@torch.no_grad()
def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefix: str, strip_keys: bool) -> dict:
module_state_dict = module.state_dict()

result = {}
for key, tensor in state_dict.items():
if key.startswith(prefix):
stripped_key = key.split(prefix)[1] if strip_keys and prefix != "" else key

param = module_state_dict[stripped_key]
device_mesh = param.device_mesh
placements = param.placements

if isinstance(tensor, DTensor):
assert tensor.device_mesh == device_mesh
assert tensor.placements == placements

result[key] = tensor
else:
result[key] = tensor_to_dtensor(tensor, device_mesh=device_mesh, current_placement=placements)

return result


def use_async_tensor_parallel() -> bool:
return torch._inductor.config._micro_pipeline_tp
5 changes: 2 additions & 3 deletions dolomite_engine/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .checkpointing import load_checkpoint_for_training, save_checkpoint
from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer, log_model_optimizer_container
from .data import ResumableDataLoader, custom_iterator, get_dataloader, get_next_batch
from .distributed import wrap_model_container_for_distributed_training
from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training
from .enums import DatasetSplit, FP8Backend, Mode, TuningMethod
from .model_wrapper import ModelWrapperForFinetuning, get_model_container
from .optimization import get_optimizer_container, get_scheduler_container
Expand Down Expand Up @@ -197,8 +197,7 @@ def evaluate(
metrics_tracker = metrics_tracker / num_steps

for key in metrics_tracker:
if isinstance(metrics_tracker[key], DTensor):
metrics_tracker[key] = metrics_tracker[key].to_local()
metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key])

metrics_tracker = all_reduce_metrics_tracker(metrics_tracker)

Expand Down
3 changes: 2 additions & 1 deletion dolomite_engine/hf_models/mixins/dense_TP/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from transformers import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from ....distributed import dtensor_to_tensor, tensor_to_dtensor
from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible
from ...config import CommonConfig
from ...enums import PositionEmbeddingType
from ...modeling_utils_TP import LMHead_TP, dtensor_to_tensor, tensor_to_dtensor
from ...modeling_utils_TP import LMHead_TP
from ..dense import CausalLMModelMixin
from .base import PreTrainedModelMixin_TP

Expand Down
2 changes: 1 addition & 1 deletion dolomite_engine/hf_models/mixins/moe_TP/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from transformers import DynamicCache
from transformers.modeling_outputs import MoeCausalLMOutputWithPast

from ...modeling_utils_TP import dtensor_to_tensor, tensor_to_dtensor
from ....distributed import dtensor_to_tensor, tensor_to_dtensor
from ..dense_TP import CausalLMModelMixin_TP
from ..moe import CausalLMMoEModelMixin, MoeModelOutputWithPastAndAuxLoss

Expand Down
164 changes: 0 additions & 164 deletions dolomite_engine/hf_models/modeling_utils_TP/TP.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import torch
import torch.distributed
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh

from ...utils import ProcessGroupManager, divide_if_divisible

Expand Down Expand Up @@ -41,79 +37,6 @@ def tensor_parallel_split_safetensor_slice(slice, dim: int, start_end: tuple[int
return output


def tensor_to_dtensor(
tensor: torch.Tensor,
device_mesh: DeviceMesh,
current_placement: Placement | list[Placement],
desired_placement: Placement | list[Placement] | None = None,
run_check: bool = False,
) -> DTensor:
if isinstance(tensor, DTensor):
return tensor

if isinstance(current_placement, Placement):
current_placement = [current_placement]

dtensor = DTensor.from_local(tensor, device_mesh=device_mesh, run_check=run_check, placements=current_placement)

if desired_placement is not None:
if isinstance(desired_placement, Placement):
desired_placement = [desired_placement]

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False)

return dtensor


def dtensor_to_tensor(
dtensor: DTensor,
device_mesh: DeviceMesh | None = None,
desired_placement: Placement | list[Placement] | None = None,
grad_placement: Placement | list[Placement] | None = None,
) -> torch.Tensor:
if not isinstance(dtensor, DTensor):
return dtensor

if desired_placement is not None:
if isinstance(desired_placement, Placement):
desired_placement = [desired_placement]

assert device_mesh is not None

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False)

if grad_placement is not None and isinstance(grad_placement, Placement):
grad_placement = [grad_placement]

tensor = dtensor.to_local(grad_placements=grad_placement)

return tensor


@torch.no_grad()
def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefix: str, strip_keys: bool) -> dict:
module_state_dict = module.state_dict()

result = {}
for key, tensor in state_dict.items():
if key.startswith(prefix):
stripped_key = key.split(prefix)[1] if strip_keys and prefix != "" else key

param = module_state_dict[stripped_key]
device_mesh = param.device_mesh
placements = param.placements

if isinstance(tensor, DTensor):
assert tensor.device_mesh == device_mesh
assert tensor.placements == placements

result[key] = tensor
else:
result[key] = tensor_to_dtensor(tensor, device_mesh=device_mesh, current_placement=placements)

return result


def get_module_placements(use_padding_free_transformer: bool, sequence_parallel: bool) -> Placement:
if sequence_parallel:
if use_padding_free_transformer:
Expand All @@ -124,90 +47,3 @@ def get_module_placements(use_padding_free_transformer: bool, sequence_parallel:
placement = Replicate()

return placement


def _tensor_parallel_all_reduce(x: torch.Tensor) -> torch.Tensor:
if ProcessGroupManager.get_tensor_parallel_world_size() == 1:
return x

return funcol.all_reduce(x, reduceOp="sum", group=ProcessGroupManager.get_tensor_parallel_group())


def _tensor_parallel_all_gather(x: torch.Tensor, dim: int) -> torch.Tensor:
if ProcessGroupManager.get_tensor_parallel_world_size() == 1:
return x

return funcol.all_gather_tensor(x, gather_dim=dim, group=ProcessGroupManager.get_tensor_parallel_group())


def _tensor_parallel_reduce_scatter(x: torch.Tensor, dim: int) -> torch.Tensor:
if ProcessGroupManager.get_tensor_parallel_world_size() == 1:
return x

return funcol.reduce_scatter_tensor(
x, reduceOp="sum", scatter_dim=dim, group=ProcessGroupManager.get_tensor_parallel_group()
)


class _CopyToTensorParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return x

@staticmethod
def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor:
return _tensor_parallel_all_reduce(x_grad)


class _ReduceFromTensorParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return _tensor_parallel_all_reduce(x)

@staticmethod
def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor:
return x_grad


class _AllGatherFromSequenceParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, dim: int) -> torch.Tensor:
ctx.dim = dim
return _tensor_parallel_all_gather(x, dim=dim)

@staticmethod
def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor:
dim = ctx.dim
return _tensor_parallel_reduce_scatter(x_grad, dim=dim), None


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, dim: int) -> torch.Tensor:
ctx.dim = dim
return _tensor_parallel_reduce_scatter(x, dim=dim)

@staticmethod
def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor:
dim = ctx.dim
return _tensor_parallel_all_gather(x_grad, dim=dim), None


def copy_to_tensor_parallel_region(x: torch.Tensor) -> torch.Tensor:
return _CopyToTensorParallelRegion.apply(x)


def reduce_from_tensor_parallel_region(x: torch.Tensor) -> torch.Tensor:
return _ReduceFromTensorParallelRegion.apply(x)


def all_gather_from_sequence_parallel_region(x: torch.Tensor, dim: int) -> torch.Tensor:
return _AllGatherFromSequenceParallelRegion.apply(x, dim)


def reduce_scatter_to_sequence_parallel_region(x: torch.Tensor, dim: int) -> torch.Tensor:
return _ReduceScatterToSequenceParallelRegion.apply(x, dim)


def use_async_tensor_parallel() -> bool:
return torch._inductor.config._micro_pipeline_tp
12 changes: 1 addition & 11 deletions dolomite_engine/hf_models/modeling_utils_TP/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,4 @@
from .lm_head import LMHead_TP
from .normalization import get_normalization_function_TP
from .position_embedding import Alibi_TP
from .TP import (
all_gather_from_sequence_parallel_region,
copy_to_tensor_parallel_region,
dtensor_to_tensor,
get_module_placements,
modify_state_dict_to_dtensor_dict,
reduce_from_tensor_parallel_region,
reduce_scatter_to_sequence_parallel_region,
tensor_parallel_split_safetensor_slice,
tensor_to_dtensor,
)
from .TP import get_module_placements, tensor_parallel_split_safetensor_slice
3 changes: 2 additions & 1 deletion dolomite_engine/hf_models/modeling_utils_TP/dropout.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
import torch.nn as nn

from ...distributed import dtensor_to_tensor, tensor_to_dtensor
from ...utils import ProcessGroupManager
from .TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor
from .TP import get_module_placements


class Dropout_TP(nn.Dropout):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch.nn as nn

from .TP import modify_state_dict_to_dtensor_dict
from ...distributed import modify_state_dict_to_dtensor_dict


class DTensorModule(nn.Module):
Expand Down
8 changes: 4 additions & 4 deletions dolomite_engine/hf_models/modeling_utils_TP/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import torch
import torch.nn as nn
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Replicate, Shard

from ...distributed import dtensor_to_tensor, tensor_to_dtensor
from ...utils import ProcessGroupManager, divide_if_divisible
from ..modeling_utils import ParameterizedEmbedding
from .dtensor_module import DTensorModule
from .TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor
from .TP import get_module_placements


class Embedding_TP(ParameterizedEmbedding, DTensorModule):
Expand Down Expand Up @@ -43,8 +43,8 @@ def __init__(
placement = Replicate()

self.weight = nn.Parameter(
DTensor.from_local(
self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[placement]
tensor_to_dtensor(
self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=placement
)
)

Expand Down
Loading

0 comments on commit 214f6f9

Please sign in to comment.