[fp8] disable all_to_all_fp8 in intranode (hpcaitech#6045)
* enhance all_to_all_fp8 with internode comm control

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable some fp8 ops due to performance issue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
BurkeHulk and pre-commit-ci[bot] authored Sep 9, 2024
1 parent 26e5539 commit 5ce6dd7
Showing 1 changed file with 73 additions and 6 deletions.
79 changes: 73 additions & 6 deletions colossalai/quantization/fp8.py
@@ -1,3 +1,4 @@
import os
from typing import Any, Optional, Tuple

import numpy as np
@@ -23,6 +24,24 @@ def wait(self):
        self.remain_ops()


def process_group_is_intranode(pg):
    if pg is None:
        from torch.distributed.distributed_c10d import _get_default_group

        pg = _get_default_group()

    local_world_size = None
    for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]:
        if var in os.environ:
            local_world_size = int(os.environ[var])
            break
    if local_world_size is None:
        local_world_size = torch.cuda.device_count()

    group_ranks = dist.get_process_group_ranks(pg)
    group_ranks_node_ids = [rank // local_world_size for rank in group_ranks]
    return min(group_ranks_node_ids) == max(group_ranks_node_ids)
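
The helper maps each global rank to a node index (rank // local_world_size) and calls the group intranode only when every rank lands on the same node. A minimal standalone sketch of that logic, with illustrative ranks and an assumed local world size of 8 (both hypothetical, not part of the commit):

def node_ids(group_ranks, local_world_size=8):
    # Same arithmetic as process_group_is_intranode, minus the env probing.
    return [rank // local_world_size for rank in group_ranks]

print(node_ids([0, 1, 2, 3]))  # [0, 0, 0, 0]: min == max, intranode
print(node_ids([4, 12]))       # [0, 1]: min != max, internode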


def cast_to_fp8(
    inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -92,7 +111,7 @@ def cast_from_fp8(
    return ret.to(ret_type)


-def all_reduce_fp8(
def _all_reduce_fp8(
    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
) -> Optional[Handle]:
    r"""
@@ -159,7 +178,15 @@ def cat_op():
        cat_op()


-def all_to_all_single_fp8(
def all_reduce_fp8(
    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
) -> Optional[Handle]:
    # fall back to default op due to performance issue
    return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)
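
The public name keeps its fp8 signature so existing call sites need no changes, but the call now lowers to a plain all-reduce. A hedged usage sketch, assuming torch.distributed is already initialized and each rank owns one CUDA device (the tensor shape is illustrative):

import torch
from colossalai.quantization.fp8 import all_reduce_fp8

t = torch.randn(1024, device="cuda")
# fp8_format is still accepted for compatibility, but the call now
# dispatches to dist.all_reduce with the default ReduceOp.SUM.
all_reduce_fp8(t, fp8_format="e4m3")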


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def _all_to_all_single_fp8(
    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> Optional[Handle]:
r"""
@@ -222,6 +249,33 @@ def cast_op():
        cast_op()


def all_to_all_single_fp8(
    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> Optional[Handle]:
    r"""
    A wrapper for _all_to_all_single_fp8 that falls back to the plain
    dist.all_to_all_single when every rank in the group is on the same node.
    """
    if process_group_is_intranode(group):
        return dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
    else:
        return _all_to_all_single_fp8(
            output,
            input,
            fp8_format=fp8_format,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
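
A hedged usage sketch of the wrapper, assuming an initialized process group with one CUDA device per rank; the tensor sizes are illustrative:

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_to_all_single_fp8

world_size = dist.get_world_size()
# Equal splits: each rank exchanges world_size chunks of 4 rows.
inp = torch.randn(world_size * 4, 16, device="cuda")
out = torch.empty_like(inp)
all_to_all_single_fp8(out, inp, fp8_format="e5m2")
# Intranode groups take the uncompressed dist.all_to_all_single path;
# groups that span nodes go through the compiled fp8 variant.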


def cast_to_fp8_pipeline(inp: Any) -> None:
    """
    Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
@@ -293,7 +347,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
        del inp["dtype"]


-def reduce_scatter_fp8(
def _reduce_scatter_fp8(
    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
    r"""
@@ -338,6 +392,13 @@ def cast_op():
        cast_op()


def reduce_scatter_fp8(
    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
    # fall back to default op due to performance issue
    return dist.reduce_scatter(output, input_list, group=group, async_op=async_op)


def fp8_compress_ddp_grad_comm_hook_async(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
@@ -617,10 +678,9 @@ def cast_op():
        cast_op()


-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
    world_size = dist.get_world_size(group)
    input_type = input_list[0].dtype
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    scale_list = []
@@ -651,6 +711,13 @@ def cast_op():
        cast_op()


def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
    if process_group_is_intranode(group):
        return dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
    else:
        return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)
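
The list-based collective follows the same dispatch rule. A sketch under the same assumptions (initialized group, one GPU per rank, illustrative shapes):

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_to_all_fp8

world_size = dist.get_world_size()
input_list = [torch.randn(8, 8, device="cuda") for _ in range(world_size)]
output_list = [torch.empty(8, 8, device="cuda") for _ in range(world_size)]
# Single-node group: plain dist.all_to_all; multi-node: fp8-compressed path.
all_to_all_fp8(output_list, input_list, fp8_format="e5m2")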


def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:

    world_size = dist.get_world_size(group)
