From c937375543bb482cd1f052b1a7083ec64a4dd07d Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Fri, 1 Nov 2024 15:34:19 +0000
Subject: [PATCH] add FP8 TP speed benchmark with 5% speedup

---
 benchmark/fp8_tp_speed.py                   | 152 +++++++++++++++++++
 src/nanotron/fp8/functional.py              |   2 +-
 src/nanotron/fp8/linear.py                  |  41 +----
 src/nanotron/parallel/tensor_parallel/nn.py | 157 ++------------------
 4 files changed, 169 insertions(+), 183 deletions(-)
 create mode 100644 benchmark/fp8_tp_speed.py

diff --git a/benchmark/fp8_tp_speed.py b/benchmark/fp8_tp_speed.py
new file mode 100644
index 00000000..d0ae3e63
--- /dev/null
+++ b/benchmark/fp8_tp_speed.py
@@ -0,0 +1,152 @@
+import argparse
+import itertools
+
+import pandas as pd
+import torch
+from nanotron.models.base import init_on_device_and_dtype
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
+from nanotron.parallel.tensor_parallel.nn import FP8TensorParallelColumnLinear, TensorParallelColumnLinear
+from torch.utils import benchmark
+
+# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
+h100_peak_flops_float32 = 67e12
+h100_peak_flops_fp16_tc = 989e12
+h100_peak_tops_float8_tc = 1979e12
+
+dtype_to_peak_tops = {
+    torch.float32: h100_peak_flops_float32,
+    torch.float16: h100_peak_flops_fp16_tc,
+    torch.bfloat16: h100_peak_flops_fp16_tc,
+    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
+    torch.float8_e5m2: h100_peak_tops_float8_tc,
+}
+
+
+def benchmark_fn_in_sec(f, *args, **kwargs):
+    # Manual warmup
+    for _ in range(4):
+        f(*args, **kwargs)
+
+    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
+    measurement = t0.blocked_autorange()
+    return measurement.mean
+
+
+def run_fp8_linear(input, M, N, K, parallel_context):
+    # input = torch.randn(M, K, device="cuda", requires_grad=False)
+    column_linear = FP8TensorParallelColumnLinear(
+        in_features=K,
+        out_features=N,
+        pg=parallel_context.tp_pg,
+        mode=TensorParallelLinearMode.ALL_REDUCE,
+        device="cuda",
+        async_communication=False,
+        bias=False,
+    )
+
+    sharded_output = column_linear(input)
+    sharded_output.sum().backward()
+
+    return sharded_output
+
+
+def run_linear(input, M, N, K, parallel_context):
+    # input = torch.randn(M, K, device="cuda", requires_grad=False)
+    with init_on_device_and_dtype(device="cuda", dtype=torch.bfloat16):
+        column_linear = TensorParallelColumnLinear(
+            in_features=K,
+            out_features=N,
+            pg=parallel_context.tp_pg,
+            mode=TensorParallelLinearMode.ALL_REDUCE,
+            device="cuda",
+            async_communication=False,
+            bias=False,
+        )
+
+    sharded_output = column_linear(input)
+    sharded_output.sum().backward()
+    assert sharded_output.dtype == torch.bfloat16, f"Expected bfloat16, got {sharded_output.dtype}"
+    return sharded_output
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions")
+    parser.add_argument("--exp_number", type=str, help="Experiment number")
+    parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size")
+    parser.add_argument(
+        "--dimensions",
+        type=str,
+        default="1024,2048,4096,8192,16384,32768",
+        help="Comma-separated list of dimensions to test",
+    )
+    return parser.parse_args()
+
+
+def benchmark_linear_operations(M, N, K, parallel_context):
+    input = torch.randn(M, K, device="cuda", requires_grad=False)
+    bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=False, dtype=torch.bfloat16)
+
+    # Benchmark FP8
+    fp8_time = 
benchmark_fn_in_sec(run_fp8_linear, input, M, N, K, parallel_context) + + # Benchmark BFloat16 + bfloat16_time = benchmark_fn_in_sec(run_linear, bfloat16_input, M, N, K, parallel_context) + + # Calculate FLOPS + # Each linear operation performs 2*M*N*K FLOPs (multiply-add) + total_flops = 2 * M * N * K + + fp8_tflops = (total_flops / fp8_time) / 1e12 + bfloat16_tflops = (total_flops / bfloat16_time) / 1e12 + + # Calculate efficiency compared to peak performance + fp8_efficiency = (fp8_tflops / (h100_peak_tops_float8_tc / 1e12)) * 100 + bfloat16_efficiency = (bfloat16_tflops / (h100_peak_flops_fp16_tc / 1e12)) * 100 + + return { + "M": M, + "N": N, + "K": K, + "FP8_time_ms": fp8_time * 1000, + "BF16_time_ms": bfloat16_time * 1000, + "FP8_TFLOPS": fp8_tflops, + "BF16_TFLOPS": bfloat16_tflops, + "FP8_efficiency_%": fp8_efficiency, + "BF16_efficiency_%": bfloat16_efficiency, + "Speedup": bfloat16_time / fp8_time, + } + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + + args = parse_args() + + dimensions = [int(d.strip()) for d in args.dimensions.split(",")] + TP_SIZE = args.tp_size + EXP_NUMBER = args.exp_number + + results = [] + total = len(list(itertools.product(dimensions, dimensions, dimensions))) + experiment_count = 0 + parallel_context = ParallelContext(data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=TP_SIZE) + + # Run benchmarks and collect results + results = [] + i = 0 + for M, N, K in itertools.product(dimensions, dimensions, dimensions): + i += 1 + result = benchmark_linear_operations(M, N, K, parallel_context) + results.append(result) + print(f"Experiment {i}/{total} complete") + + # Create DataFrame + df = pd.DataFrame(results) + df = df.round(2) # Round to 2 decimal places + + # Sort by matrix size for better readability + df = df.sort_values(by=["M", "N", "K"]) + + print("\nBenchmark Results:") + print(df.to_string(index=False)) diff --git a/src/nanotron/fp8/functional.py b/src/nanotron/fp8/functional.py index ed07326e..118c9d98 100644 --- a/src/nanotron/fp8/functional.py +++ b/src/nanotron/fp8/functional.py @@ -72,7 +72,7 @@ def linear( # because weight and bias's requires_grad are set to False # so that we can compute the gradients using the fp8 kernels by ourselves phony = torch.empty(0, device=input.device, requires_grad=True) - output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype) + output = torch.empty(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype) output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, recipe, name) # TODO(xrsrke): add support for adding bias in fp8 diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py index 8a7ef194..e0d3668b 100644 --- a/src/nanotron/fp8/linear.py +++ b/src/nanotron/fp8/linear.py @@ -51,12 +51,7 @@ def __init__( # TODO(xrsrke): take initialization dtype from recipe # NOTE: initialize in float32 super().__init__(in_features, out_features, bias, device, dtype=torch.float32) - # TODO(xrsrke): don't fixed dtype, take it from the FP8 recipe - # DTypes.FP8E4M3 - weight_data = self.weight.data - # orig_w_shape = weight_data.shape - # weight_data = weight_data.contiguous().view(-1).contiguous().reshape(orig_w_shape) - quant_w = FP8Parameter(weight_data, dtype=recipe.weight.dtype, interval=recipe.weight.interval) + quant_w = FP8Parameter(self.weight.data, dtype=recipe.weight.dtype, interval=recipe.weight.interval) assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}" 
self.weight = quant_w @@ -65,6 +60,7 @@ def __init__( if self.bias is not None: self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype)) assert self.bias.dtype == recipe.accum_dtype + self.metadatas = FP8LinearMeta() self.recipe = recipe @@ -93,7 +89,6 @@ def forward( output: torch.Tensor, phony: torch.Tensor, metadatas: FP8LinearMeta, - # accum_qtype: DTypes, recipe: FP8LinearRecipe, name, ) -> torch.Tensor: @@ -102,8 +97,6 @@ def forward( from nanotron import constants from nanotron.config.fp8_config import FP8Args - # dist.monitored_barrier(wait_all_ranks=True) - if constants.CONFIG is None: fp8_config = FP8Args() else: @@ -111,11 +104,6 @@ def forward( sync_amax_in_input = fp8_config.sync_amax_in_input - # orig_input_shape = input.shape - # input = input.contiguous().view(-1).contiguous().view(orig_input_shape) - - # input = input.contiguous() - if metadatas.input is None: fp8_input = FP8Tensor( input, dtype=recipe.input.dtype, interval=recipe.input.interval, sync=sync_amax_in_input @@ -129,15 +117,7 @@ def forward( ctx.name = name ctx.recipe = recipe - # accum_output = output.contiguous() accum_output = output - # accum_output = torch.zeros(output.shape, dtype=torch.float16, device="cuda") - - # assert fp8_input.data.is_contiguous() - # assert weight.data.is_contiguous() - # assert accum_output.is_contiguous() - - # dist.monitored_barrier(wait_all_ranks=True) output = fp8_matmul_kernel( # NOTE: that works @@ -190,7 +170,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ fp8_input = cast(FP8Tensor, fp8_input) fp8_weight = cast(FP8Tensor, fp8_weight) - grad_output = grad_output.contiguous() + # grad_output = grad_output.contiguous() ctx.metadatas = cast(FP8LinearMeta, ctx.metadatas) if ctx.metadatas.input_grad is None: @@ -206,7 +186,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ transposed_fp8_weight = fp8_weight.transpose_fp8() - grad_input_temp = torch.zeros( + grad_input_temp = torch.empty( fp8_grad_output.shape[0], transposed_fp8_weight.shape[0], device="cuda", @@ -229,7 +209,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ transposed_fp8_grad_output = fp8_grad_output.transpose_fp8() transposed_fp8_input = fp8_input.transpose_fp8() - grad_weight_temp = torch.zeros( + grad_weight_temp = torch.empty( transposed_fp8_input.shape[0], transposed_fp8_grad_output.shape[0], device="cuda", @@ -251,16 +231,6 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ assert grad_weight.dtype == recipe.accum_dtype # TODO(xrsrke): maintain a persistence metadata across training - # grad_weight = grad_weight.T.contiguous() - # orig_shape = grad_weight.shape - # grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape) - - # grad_weight = grad_weight.T - # orig_shape = grad_weight.shape - # grad_weight = grad_weight.t().view(-1).reshape(orig_shape) - - # NOTE: works - # grad_weight = grad_weight.reshape(grad_weight.T.shape) grad_weight = grad_weight.reshape(grad_weight.shape[::-1]) # NOTE: if use gradient accumulation, then directly keep the high precision weights for later accumulate @@ -299,5 +269,4 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ # NOTE: sanity check assert isinstance(fp8_weight.grad, FP8Tensor) - # return grad_input.contiguous(), None, None, None, None, None, None return grad_input, None, None, None, None, None, None diff --git 
a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 89287169..920dc403 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -41,142 +41,10 @@ row_linear, ) from nanotron.parallel.tied_parameters import create_tied_parameter -from nanotron.utils import post_init - -# class TensorParallelColumnLinear(nn.Linear): -# def __init__( -# self, -# in_features, -# out_features, -# pg: dist.ProcessGroup, -# mode: TensorParallelLinearMode, -# bias=True, -# device=None, -# dtype=None, -# async_communication: bool = False, -# contiguous_chunks: Optional[Tuple[int, ...]] = None, -# tp_recompute_allgather: bool = True, -# ): -# self.pg = pg -# self.world_size = pg.size() - -# assert out_features % self.world_size == 0 - -# self.in_features = in_features -# self.out_features = out_features // self.world_size -# self.tp_recompute_allgather = tp_recompute_allgather - -# super().__init__( -# in_features=self.in_features, -# out_features=self.out_features, -# bias=bias, -# device=device, -# dtype=dtype, -# ) - -# self.mode = mode -# self.async_communication = async_communication - -# if contiguous_chunks is not None: -# assert ( -# sum(contiguous_chunks) == out_features -# ), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to out_features ({out_features})" -# split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks) - -# mark_all_parameters_in_module_as_sharded( -# self, -# pg=self.pg, -# split_config=split_config, -# ) - -# def forward(self, x: torch.Tensor) -> torch.Tensor: -# return column_linear( -# input=x, -# weight=self.weight, -# bias=self.bias, -# group=self.pg, -# tp_mode=self.mode, -# async_communication=self.async_communication, -# tp_recompute_allgather=self.tp_recompute_allgather, -# ) - -# def extra_repr(self) -> str: -# return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}" - - -# class TensorParallelRowLinear(nn.Linear): -# def __init__( -# self, -# in_features, -# out_features, -# pg: dist.ProcessGroup, -# mode: TensorParallelLinearMode, -# bias=True, -# device=None, -# dtype=None, -# async_communication: bool = False, -# contiguous_chunks: Optional[Tuple[int, ...]] = None, -# ): -# self.pg = pg -# self.world_size = pg.size() - -# assert in_features % self.world_size == 0 - -# self.in_features = in_features // self.world_size -# self.out_features = out_features - -# # No need to shard the bias term, only rank 0 would have it -# bias = dist.get_rank(self.pg) == 0 and bias - -# super().__init__( -# in_features=self.in_features, -# out_features=self.out_features, -# bias=bias, -# device=device, -# dtype=dtype, -# ) -# self.mode = mode -# self.async_communication = async_communication -# if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: -# raise ValueError("async_communication is not supported for ALL_REDUCE mode") - -# if contiguous_chunks is not None: -# assert ( -# sum(contiguous_chunks) == in_features -# ), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to in_features ({in_features})" - -# split_config = SplitConfig(split_dim=1, contiguous_chunks=contiguous_chunks) - -# self._mark_all_parameters_in_module_as_sharded(split_config) - -# def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): -# for name, param in list(self.named_parameters()): -# if name == "bias": -# # `bias` only exists in rank 0 because it's 
not sharded -# new_param = NanotronParameter(tensor=param) -# else: -# new_param = create_sharded_parameter_from_config( -# parameter=param, -# pg=self.pg, -# split_config=split_config, -# ) -# setattr(self, name, new_param) - -# def forward(self, x: torch.Tensor) -> torch.Tensor: -# return row_linear( -# input=x, -# weight=self.weight, -# bias=self.bias, -# group=self.pg, -# tp_mode=self.mode, -# async_communication=self.async_communication, -# ) - -# def extra_repr(self) -> str: -# return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_in_features={self.in_features * self.world_size}" - - -@post_init + +# from nanotron.utils import post_init + +# @post_init class _BaseTensorParallelColumnLinear: def __init__( self, @@ -184,9 +52,9 @@ def __init__( out_features, pg: dist.ProcessGroup, mode: TensorParallelLinearMode, - bias=True, - device=None, - dtype=None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: torch.dtype = None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, name: Optional[str] = None, @@ -201,9 +69,6 @@ def __init__( self.out_features = out_features // self.world_size self.name = name - if name == "model.lm_head": - assert 1 == 1 - init_args = { "in_features": self.in_features, "out_features": self.out_features, @@ -245,7 +110,7 @@ def extra_repr(self) -> str: return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}" -@post_init +# @post_init class _BaseTensorParallelRowLinear: def __init__( self, @@ -253,9 +118,9 @@ def __init__( out_features, pg: dist.ProcessGroup, mode: TensorParallelLinearMode, - bias=True, - device=None, - dtype=None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: torch.dtype = None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, name: Optional[str] = None,
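
Note for reviewers: below is a minimal sketch of driving the new benchmark for a single shape. It is not part of the patch, and it assumes two things the patch does not show: the repository root is importable (so `benchmark.fp8_tp_speed` resolves) and the process is started under a distributed launcher such as `torchrun --nproc_per_node=<tp_size>`, so that `ParallelContext` can set up its process groups. The full sweep is driven by the script's own `--tp_size` and `--dimensions` flags.

# Minimal sketch, assuming torch.distributed is already initialized (e.g. via torchrun)
# and that the repo root is on PYTHONPATH so the benchmark module can be imported.
from nanotron.parallel import ParallelContext

from benchmark.fp8_tp_speed import benchmark_linear_operations

parallel_context = ParallelContext(
    data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=1
)

# Time one (M, N, K) shape in both FP8 and BF16; "Speedup" is BF16 time / FP8 time.
result = benchmark_linear_operations(M=4096, N=4096, K=4096, parallel_context=parallel_context)
print(result["FP8_time_ms"], result["BF16_time_ms"], result["Speedup"])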