diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 6a0bd14d1071..388bbde052d2 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -1,3 +1,4 @@ +import os from typing import Any, Optional, Tuple import numpy as np @@ -23,6 +24,24 @@ def wait(self): self.remain_ops() +def process_group_is_intranode(pg): + if pg is None: + from torch.distributed.distributed_c10d import _get_default_group + + pg = _get_default_group() + + local_world_size = None + for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]: + if var in os.environ: + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + if local_world_size is None: + local_world_size = torch.cuda.device_count() + + group_ranks = dist.get_process_group_ranks(pg) + group_ranks_node_ids = [rank // local_world_size for rank in group_ranks] + return min(group_ranks_node_ids) == max(group_ranks_node_ids) + + def cast_to_fp8( inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -92,7 +111,7 @@ def cast_from_fp8( return ret.to(ret_type) -def all_reduce_fp8( +def _all_reduce_fp8( tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False ) -> Optional[Handle]: r""" @@ -159,7 +178,15 @@ def cat_op(): cat_op() -def all_to_all_single_fp8( +def all_reduce_fp8( + tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False +) -> Optional[Handle]: + # fall back to default op due to performance issue + return dist.all_reduce(tensor, op=op, group=group, async_op=async_op) + + +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def _all_to_all_single_fp8( output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False ) -> Optional[Handle]: r""" @@ -222,6 +249,33 @@ def cast_op(): cast_op() +def all_to_all_single_fp8( + output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False +) -> Optional[Handle]: + r""" + This is wrapper for _all_to_all_single_fp8. + """ + if process_group_is_intranode(group): + return dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + else: + return _all_to_all_single_fp8( + output, + input, + fp8_format=fp8_format, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + + def cast_to_fp8_pipeline(inp: Any) -> None: """ Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline. @@ -293,7 +347,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: del inp["dtype"] -def reduce_scatter_fp8( +def _reduce_scatter_fp8( output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False ) -> Optional[Handle]: r""" @@ -338,6 +392,13 @@ def cast_op(): cast_op() +def reduce_scatter_fp8( + output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False +) -> Optional[Handle]: + # fall back to default op due to performance issue + return dist.reduce_scatter(output, input_list, group=group, async_op=async_op) + + def fp8_compress_ddp_grad_comm_hook_async( process_group: dist.ProcessGroup, bucket: dist.GradBucket, @@ -617,10 +678,9 @@ def cast_op(): cast_op() -def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False): - +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False): world_size = dist.get_world_size(group) - input_type = input_list[0].dtype fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 scale_list = [] @@ -651,6 +711,13 @@ def cast_op(): cast_op() +def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False): + if process_group_is_intranode(group): + return dist.all_to_all(output_list, input_list, group=group, async_op=async_op) + else: + return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op) + + def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: world_size = dist.get_world_size(group)