refactor dtensors #54

Merged · 6 commits · Oct 30, 2024
6 changes: 6 additions & 0 deletions dolomite_engine/distributed/__init__.py
@@ -25,6 +25,12 @@
from ..enums import FP8Backend
from ..gradient_checkpointing import apply_gradient_checkpointing
from ..utils import ProcessGroupManager, get_module_class_from_name, log_rank_0, string_to_torch_dtype
from .dtensors import (
dtensor_to_tensor,
modify_state_dict_to_dtensor_dict,
tensor_to_dtensor,
use_async_tensor_parallel,
)
from .fp8 import convert_model_to_transformer_engine


83 changes: 83 additions & 0 deletions dolomite_engine/distributed/dtensors.py
@@ -0,0 +1,83 @@
import torch
import torch.distributed
import torch.nn as nn
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.device_mesh import DeviceMesh


def tensor_to_dtensor(
tensor: torch.Tensor,
device_mesh: DeviceMesh,
current_placement: Placement | list[Placement],
desired_placement: Placement | list[Placement] | None = None,
run_check: bool = False,
) -> DTensor:
if isinstance(tensor, DTensor):
return tensor

if isinstance(current_placement, Placement):
current_placement = [current_placement]

dtensor = DTensor.from_local(tensor, device_mesh=device_mesh, run_check=run_check, placements=current_placement)

if desired_placement is not None:
if isinstance(desired_placement, Placement):
desired_placement = [desired_placement]

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False)

return dtensor


def dtensor_to_tensor(
dtensor: DTensor,
device_mesh: DeviceMesh | None = None,
desired_placement: Placement | list[Placement] | None = None,
grad_placement: Placement | list[Placement] | None = None,
) -> torch.Tensor:
if not isinstance(dtensor, DTensor):
return dtensor

if desired_placement is not None:
if isinstance(desired_placement, Placement):
desired_placement = [desired_placement]

assert device_mesh is not None

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False)

if grad_placement is not None and isinstance(grad_placement, Placement):
grad_placement = [grad_placement]

tensor = dtensor.to_local(grad_placements=grad_placement)

return tensor


@torch.no_grad()
def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefix: str, strip_keys: bool) -> dict:
module_state_dict = module.state_dict()

result = {}
for key, tensor in state_dict.items():
if key.startswith(prefix):
stripped_key = key.split(prefix)[1] if strip_keys and prefix != "" else key

param = module_state_dict[stripped_key]
device_mesh = param.device_mesh
placements = param.placements

if isinstance(tensor, DTensor):
assert tensor.device_mesh == device_mesh
assert tensor.placements == placements

result[key] = tensor
else:
result[key] = tensor_to_dtensor(tensor, device_mesh=device_mesh, current_placement=placements)

return result


def use_async_tensor_parallel() -> bool:
return torch._inductor.config._micro_pipeline_tp
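
For quick reference, a minimal usage sketch of the two conversion helpers added above (my own illustration, not part of this PR; it assumes torch.distributed and the tensor-parallel mesh have already been set up through ProcessGroupManager):

import torch
from torch.distributed._tensor.placement_types import Replicate, Shard

from dolomite_engine.distributed import dtensor_to_tensor, tensor_to_dtensor
from dolomite_engine.utils import ProcessGroupManager

mesh = ProcessGroupManager.get_tensor_parallel_mesh()

# wrap a replicated local tensor into a DTensor and redistribute it to a shard along dim 0
# (device="cuda" assumes a CUDA device mesh)
x = torch.randn(8, 4, device="cuda")
dt = tensor_to_dtensor(x, device_mesh=mesh, current_placement=Replicate(), desired_placement=Shard(0))

# convert back to a plain tensor, gathering the shards first
y = dtensor_to_tensor(dt, device_mesh=mesh, desired_placement=Replicate())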
5 changes: 2 additions & 3 deletions dolomite_engine/finetune.py
@@ -13,7 +13,7 @@
from .checkpointing import load_checkpoint_for_training, save_checkpoint
from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer, log_model_optimizer_container
from .data import ResumableDataLoader, custom_iterator, get_dataloader, get_next_batch
from .distributed import wrap_model_container_for_distributed_training
from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training
from .enums import DatasetSplit, FP8Backend, Mode, TuningMethod
from .model_wrapper import ModelWrapperForFinetuning, get_model_container
from .optimization import get_optimizer_container, get_scheduler_container
@@ -197,8 +197,7 @@ def evaluate(
metrics_tracker = metrics_tracker / num_steps

for key in metrics_tracker:
if isinstance(metrics_tracker[key], DTensor):
metrics_tracker[key] = metrics_tracker[key].to_local()
metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key])

metrics_tracker = all_reduce_metrics_tracker(metrics_tracker)
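
Because dtensor_to_tensor returns non-DTensor inputs unchanged (see dtensors.py above), the explicit isinstance check in evaluate() is no longer needed. A tiny, purely illustrative sketch of the pass-through behavior:

import torch

from dolomite_engine.distributed import dtensor_to_tensor

t = torch.zeros(1)
# plain tensors are returned as-is; DTensors are converted via to_local()
assert dtensor_to_tensor(t) is t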

3 changes: 2 additions & 1 deletion dolomite_engine/hf_models/mixins/dense_TP/main.py
@@ -9,10 +9,11 @@
from transformers import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from ....distributed import dtensor_to_tensor, tensor_to_dtensor
from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible
from ...config import CommonConfig
from ...enums import PositionEmbeddingType
from ...modeling_utils_TP import LMHead_TP, dtensor_to_tensor, tensor_to_dtensor
from ...modeling_utils_TP import LMHead_TP
from ..dense import CausalLMModelMixin
from .base import PreTrainedModelMixin_TP

2 changes: 1 addition & 1 deletion dolomite_engine/hf_models/mixins/moe_TP/main.py
@@ -3,7 +3,7 @@
from transformers import DynamicCache
from transformers.modeling_outputs import MoeCausalLMOutputWithPast

from ...modeling_utils_TP import dtensor_to_tensor, tensor_to_dtensor
from ....distributed import dtensor_to_tensor, tensor_to_dtensor
from ..dense_TP import CausalLMModelMixin_TP
from ..moe import CausalLMMoEModelMixin, MoeModelOutputWithPastAndAuxLoss

164 changes: 0 additions & 164 deletions dolomite_engine/hf_models/modeling_utils_TP/TP.py
@@ -1,10 +1,6 @@
import torch
import torch.distributed
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh

from ...utils import ProcessGroupManager, divide_if_divisible

@@ -41,79 +37,6 @@ def tensor_parallel_split_safetensor_slice(slice, dim: int, start_end: tuple[int
return output


def tensor_to_dtensor(
tensor: torch.Tensor,
device_mesh: DeviceMesh,
current_placement: Placement | list[Placement],
desired_placement: Placement | list[Placement] | None = None,
run_check: bool = False,
) -> DTensor:
if isinstance(tensor, DTensor):
return tensor

if isinstance(current_placement, Placement):
current_placement = [current_placement]

dtensor = DTensor.from_local(tensor, device_mesh=device_mesh, run_check=run_check, placements=current_placement)

if desired_placement is not None:
if isinstance(desired_placement, Placement):
desired_placement = [desired_placement]

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False)

return dtensor


def dtensor_to_tensor(
dtensor: DTensor,
device_mesh: DeviceMesh | None = None,
desired_placement: Placement | list[Placement] | None = None,
grad_placement: Placement | list[Placement] | None = None,
) -> torch.Tensor:
if not isinstance(dtensor, DTensor):
return dtensor

if desired_placement is not None:
if isinstance(desired_placement, Placement):
desired_placement = [desired_placement]

assert device_mesh is not None

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=False)

if grad_placement is not None and isinstance(grad_placement, Placement):
grad_placement = [grad_placement]

tensor = dtensor.to_local(grad_placements=grad_placement)

return tensor


@torch.no_grad()
def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefix: str, strip_keys: bool) -> dict:
module_state_dict = module.state_dict()

result = {}
for key, tensor in state_dict.items():
if key.startswith(prefix):
stripped_key = key.split(prefix)[1] if strip_keys and prefix != "" else key

param = module_state_dict[stripped_key]
device_mesh = param.device_mesh
placements = param.placements

if isinstance(tensor, DTensor):
assert tensor.device_mesh == device_mesh
assert tensor.placements == placements

result[key] = tensor
else:
result[key] = tensor_to_dtensor(tensor, device_mesh=device_mesh, current_placement=placements)

return result


def get_module_placements(use_padding_free_transformer: bool, sequence_parallel: bool) -> Placement:
if sequence_parallel:
if use_padding_free_transformer:
@@ -124,90 +47,3 @@ def get_module_placements(use_padding_free_transformer: bool, sequence_parallel:
placement = Replicate()

return placement


def _tensor_parallel_all_reduce(x: torch.Tensor) -> torch.Tensor:
if ProcessGroupManager.get_tensor_parallel_world_size() == 1:
return x

return funcol.all_reduce(x, reduceOp="sum", group=ProcessGroupManager.get_tensor_parallel_group())


def _tensor_parallel_all_gather(x: torch.Tensor, dim: int) -> torch.Tensor:
if ProcessGroupManager.get_tensor_parallel_world_size() == 1:
return x

return funcol.all_gather_tensor(x, gather_dim=dim, group=ProcessGroupManager.get_tensor_parallel_group())


def _tensor_parallel_reduce_scatter(x: torch.Tensor, dim: int) -> torch.Tensor:
if ProcessGroupManager.get_tensor_parallel_world_size() == 1:
return x

return funcol.reduce_scatter_tensor(
x, reduceOp="sum", scatter_dim=dim, group=ProcessGroupManager.get_tensor_parallel_group()
)


class _CopyToTensorParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return x

@staticmethod
def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor:
return _tensor_parallel_all_reduce(x_grad)


class _ReduceFromTensorParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return _tensor_parallel_all_reduce(x)

@staticmethod
def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor:
return x_grad


class _AllGatherFromSequenceParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, dim: int) -> torch.Tensor:
ctx.dim = dim
return _tensor_parallel_all_gather(x, dim=dim)

@staticmethod
def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor:
dim = ctx.dim
return _tensor_parallel_reduce_scatter(x_grad, dim=dim), None


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, dim: int) -> torch.Tensor:
ctx.dim = dim
return _tensor_parallel_reduce_scatter(x, dim=dim)

@staticmethod
def backward(ctx, x_grad: torch.Tensor) -> torch.Tensor:
dim = ctx.dim
return _tensor_parallel_all_gather(x_grad, dim=dim), None


def copy_to_tensor_parallel_region(x: torch.Tensor) -> torch.Tensor:
return _CopyToTensorParallelRegion.apply(x)


def reduce_from_tensor_parallel_region(x: torch.Tensor) -> torch.Tensor:
return _ReduceFromTensorParallelRegion.apply(x)


def all_gather_from_sequence_parallel_region(x: torch.Tensor, dim: int) -> torch.Tensor:
return _AllGatherFromSequenceParallelRegion.apply(x, dim)


def reduce_scatter_to_sequence_parallel_region(x: torch.Tensor, dim: int) -> torch.Tensor:
return _ReduceScatterToSequenceParallelRegion.apply(x, dim)


def use_async_tensor_parallel() -> bool:
return torch._inductor.config._micro_pipeline_tp
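
The autograd-based collective wrappers are removed from this module along with the dtensor helpers. For orientation only, a forward-pass analogue of the removed all_gather_from_sequence_parallel_region can be expressed with the relocated DTensor helpers; this is a sketch under my own assumptions, not code from this PR:

import torch
from torch.distributed._tensor.placement_types import Replicate, Shard

from dolomite_engine.distributed import dtensor_to_tensor, tensor_to_dtensor
from dolomite_engine.utils import ProcessGroupManager


def all_gather_via_dtensor(x: torch.Tensor, dim: int) -> torch.Tensor:
    # each rank holds a shard of x along `dim`; redistributing Shard(dim) -> Replicate
    # performs the all-gather over the tensor-parallel mesh
    mesh = ProcessGroupManager.get_tensor_parallel_mesh()
    dt = tensor_to_dtensor(x, device_mesh=mesh, current_placement=Shard(dim))
    return dtensor_to_tensor(dt, device_mesh=mesh, desired_placement=Replicate())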
12 changes: 1 addition & 11 deletions dolomite_engine/hf_models/modeling_utils_TP/__init__.py
@@ -6,14 +6,4 @@
from .lm_head import LMHead_TP
from .normalization import get_normalization_function_TP
from .position_embedding import Alibi_TP
from .TP import (
all_gather_from_sequence_parallel_region,
copy_to_tensor_parallel_region,
dtensor_to_tensor,
get_module_placements,
modify_state_dict_to_dtensor_dict,
reduce_from_tensor_parallel_region,
reduce_scatter_to_sequence_parallel_region,
tensor_parallel_split_safetensor_slice,
tensor_to_dtensor,
)
from .TP import get_module_placements, tensor_parallel_split_safetensor_slice
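
With the slimmed re-export above, call sites switch where they import the dtensor helpers from, as the surrounding diffs show. A representative sketch:

# before this PR (re-exported via modeling_utils_TP)
# from dolomite_engine.hf_models.modeling_utils_TP import dtensor_to_tensor, tensor_to_dtensor

# after this PR
from dolomite_engine.distributed import dtensor_to_tensor, tensor_to_dtensor
from dolomite_engine.hf_models.modeling_utils_TP import get_module_placements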
3 changes: 2 additions & 1 deletion dolomite_engine/hf_models/modeling_utils_TP/dropout.py
@@ -1,8 +1,9 @@
import torch
import torch.nn as nn

from ...distributed import dtensor_to_tensor, tensor_to_dtensor
from ...utils import ProcessGroupManager
from .TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor
from .TP import get_module_placements


class Dropout_TP(nn.Dropout):
dolomite_engine/hf_models/modeling_utils_TP/dtensor_module.py
@@ -2,7 +2,7 @@

import torch.nn as nn

from .TP import modify_state_dict_to_dtensor_dict
from ...distributed import modify_state_dict_to_dtensor_dict


class DTensorModule(nn.Module):
8 changes: 4 additions & 4 deletions dolomite_engine/hf_models/modeling_utils_TP/embedding.py
@@ -2,13 +2,13 @@

import torch
import torch.nn as nn
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Replicate, Shard

from ...distributed import dtensor_to_tensor, tensor_to_dtensor
from ...utils import ProcessGroupManager, divide_if_divisible
from ..modeling_utils import ParameterizedEmbedding
from .dtensor_module import DTensorModule
from .TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor
from .TP import get_module_placements


class Embedding_TP(ParameterizedEmbedding, DTensorModule):
@@ -43,8 +43,8 @@ def __init__(
placement = Replicate()

self.weight = nn.Parameter(
DTensor.from_local(
self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[placement]
tensor_to_dtensor(
self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=placement
)
)
