Commit

add speed benchmark with 5% speed up
xrsrke committed Nov 1, 2024
1 parent c827594 commit c937375
Showing 4 changed files with 169 additions and 183 deletions.
152 changes: 152 additions & 0 deletions benchmark/fp8_tp_speed.py
@@ -0,0 +1,152 @@
import argparse
import itertools

import pandas as pd
import torch
from nanotron.models.base import init_on_device_and_dtype
from nanotron.parallel import ParallelContext
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import FP8TensorParallelColumnLinear, TensorParallelColumnLinear
from torch.utils import benchmark

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_fn_in_sec(f, *args, **kwargs):
    # Manual warmup
    for _ in range(4):
        f(*args, **kwargs)

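    # Timer.blocked_autorange() from torch.utils.benchmark runs the statement in repeated blocks
    # until enough samples are collected; Measurement.mean is the average runtime per call, in seconds.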
    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    measurement = t0.blocked_autorange()
    return measurement.mean


def run_fp8_linear(input, M, N, K, parallel_context):
    # input = torch.randn(M, K, device="cuda", requires_grad=False)
    column_linear = FP8TensorParallelColumnLinear(
        in_features=K,
        out_features=N,
        pg=parallel_context.tp_pg,
        mode=TensorParallelLinearMode.ALL_REDUCE,
        device="cuda",
        async_communication=False,
        bias=False,
    )

    sharded_output = column_linear(input)
    sharded_output.sum().backward()

    return sharded_output


def run_linear(input, M, N, K, parallel_context):
    # input = torch.randn(M, K, device="cuda", requires_grad=False)
    with init_on_device_and_dtype(device="cuda", dtype=torch.bfloat16):
        column_linear = TensorParallelColumnLinear(
            in_features=K,
            out_features=N,
            pg=parallel_context.tp_pg,
            mode=TensorParallelLinearMode.ALL_REDUCE,
            device="cuda",
            async_communication=False,
            bias=False,
        )

    sharded_output = column_linear(input)
    sharded_output.sum().backward()
    assert sharded_output.dtype == torch.bfloat16, f"Expected bfloat16, got {sharded_output.dtype}"
    return sharded_output


def parse_args():
    parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions")
    parser.add_argument("--exp_number", type=str, help="Experiment number")
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size")
    parser.add_argument(
        "--dimensions",
        type=str,
        default="1024,2048,4096,8192,16384,32768",
        help="Comma-separated list of dimensions to test",
    )
    return parser.parse_args()


def benchmark_linear_operations(M, N, K, parallel_context):
    input = torch.randn(M, K, device="cuda", requires_grad=False)
    bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=False, dtype=torch.bfloat16)

    # Benchmark FP8
    fp8_time = benchmark_fn_in_sec(run_fp8_linear, input, M, N, K, parallel_context)

    # Benchmark BFloat16
    bfloat16_time = benchmark_fn_in_sec(run_linear, bfloat16_input, M, N, K, parallel_context)

    # Calculate FLOPS
    # Each linear operation performs 2*M*N*K FLOPs (multiply-add)
    total_flops = 2 * M * N * K

    fp8_tflops = (total_flops / fp8_time) / 1e12
    bfloat16_tflops = (total_flops / bfloat16_time) / 1e12

    # Calculate efficiency compared to peak performance
    fp8_efficiency = (fp8_tflops / (h100_peak_tops_float8_tc / 1e12)) * 100
    bfloat16_efficiency = (bfloat16_tflops / (h100_peak_flops_fp16_tc / 1e12)) * 100
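    # Worked example with hypothetical timings: for M = N = K = 8192, total_flops = 2 * 8192**3
    # ≈ 1.1e12 FLOPs; an fp8_time of 1 ms would therefore give ~1100 TFLOPS, i.e. about 56% of
    # the 1979-TFLOPS FP8 tensor-core peak assumed above.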

    return {
        "M": M,
        "N": N,
        "K": K,
        "FP8_time_ms": fp8_time * 1000,
        "BF16_time_ms": bfloat16_time * 1000,
        "FP8_TFLOPS": fp8_tflops,
        "BF16_TFLOPS": bfloat16_tflops,
        "FP8_efficiency_%": fp8_efficiency,
        "BF16_efficiency_%": bfloat16_efficiency,
        "Speedup": bfloat16_time / fp8_time,
    }


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True

    args = parse_args()

    dimensions = [int(d.strip()) for d in args.dimensions.split(",")]
    TP_SIZE = args.tp_size
    EXP_NUMBER = args.exp_number

    total = len(list(itertools.product(dimensions, dimensions, dimensions)))
    parallel_context = ParallelContext(data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=TP_SIZE)

    # Run benchmarks and collect results
    results = []
    for i, (M, N, K) in enumerate(itertools.product(dimensions, dimensions, dimensions), start=1):
        result = benchmark_linear_operations(M, N, K, parallel_context)
        results.append(result)
        print(f"Experiment {i}/{total} complete")

    # Create DataFrame
    df = pd.DataFrame(results)
    df = df.round(2)  # Round to 2 decimal places

    # Sort by matrix size for better readability
    df = df.sort_values(by=["M", "N", "K"])

    print("\nBenchmark Results:")
    print(df.to_string(index=False))
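A possible way to launch this benchmark (a sketch, not part of the commit: it assumes the usual torchrun-style launch that nanotron's ParallelContext expects, with one process per tensor-parallel rank):

torchrun --nproc_per_node=2 benchmark/fp8_tp_speed.py --tp_size 2 --dimensions 1024,2048,4096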
2 changes: 1 addition & 1 deletion src/nanotron/fp8/functional.py
@@ -72,7 +72,7 @@ def linear(
    # because weight and bias's requires_grad are set to False
    # so that we can compute the gradients using the fp8 kernels by ourselves
    phony = torch.empty(0, device=input.device, requires_grad=True)
-   output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype)
+   output = torch.empty(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype)
    output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, recipe, name)

    # TODO(xrsrke): add support for adding bias in fp8
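The zeros → empty change above (and the matching changes to the gradient buffers in linear.py below) drops a redundant initialization: torch.zeros allocates and launches a fill kernel, while torch.empty only allocates, and the buffer appears to be fully overwritten by the FP8 matmul kernel anyway. A minimal illustration of the difference, with a hypothetical shape and dtype rather than the ones used in the commit:

out = torch.zeros(4096, 4096, device="cuda", dtype=torch.bfloat16)  # allocate + launch a fill kernel
out = torch.empty(4096, 4096, device="cuda", dtype=torch.bfloat16)  # allocate only; contents undefined until written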
41 changes: 5 additions & 36 deletions src/nanotron/fp8/linear.py
@@ -51,12 +51,7 @@ def __init__(
        # TODO(xrsrke): take initialization dtype from recipe
        # NOTE: initialize in float32
        super().__init__(in_features, out_features, bias, device, dtype=torch.float32)
-       # TODO(xrsrke): don't fixed dtype, take it from the FP8 recipe
-       # DTypes.FP8E4M3
-       weight_data = self.weight.data
-       # orig_w_shape = weight_data.shape
-       # weight_data = weight_data.contiguous().view(-1).contiguous().reshape(orig_w_shape)
-       quant_w = FP8Parameter(weight_data, dtype=recipe.weight.dtype, interval=recipe.weight.interval)
+       quant_w = FP8Parameter(self.weight.data, dtype=recipe.weight.dtype, interval=recipe.weight.interval)
        assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
        self.weight = quant_w

@@ -65,6 +60,7 @@ def __init__(
        if self.bias is not None:
            self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype))
            assert self.bias.dtype == recipe.accum_dtype
+
        self.metadatas = FP8LinearMeta()
        self.recipe = recipe

@@ -93,7 +89,6 @@ def forward(
        output: torch.Tensor,
        phony: torch.Tensor,
        metadatas: FP8LinearMeta,
-       # accum_qtype: DTypes,
        recipe: FP8LinearRecipe,
        name,
    ) -> torch.Tensor:
@@ -102,20 +97,13 @@ def forward(
        from nanotron import constants
        from nanotron.config.fp8_config import FP8Args

-       # dist.monitored_barrier(wait_all_ranks=True)
-
        if constants.CONFIG is None:
            fp8_config = FP8Args()
        else:
            fp8_config = cast(FP8Args, constants.CONFIG.fp8)

        sync_amax_in_input = fp8_config.sync_amax_in_input

-       # orig_input_shape = input.shape
-       # input = input.contiguous().view(-1).contiguous().view(orig_input_shape)
-
-       # input = input.contiguous()
-
        if metadatas.input is None:
            fp8_input = FP8Tensor(
                input, dtype=recipe.input.dtype, interval=recipe.input.interval, sync=sync_amax_in_input
@@ -129,15 +117,7 @@ def forward(
        ctx.name = name
        ctx.recipe = recipe

-       # accum_output = output.contiguous()
        accum_output = output
-       # accum_output = torch.zeros(output.shape, dtype=torch.float16, device="cuda")
-
-       # assert fp8_input.data.is_contiguous()
-       # assert weight.data.is_contiguous()
-       # assert accum_output.is_contiguous()
-
-       # dist.monitored_barrier(wait_all_ranks=True)

        output = fp8_matmul_kernel(
            # NOTE: that works
@@ -190,7 +170,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[

        fp8_input = cast(FP8Tensor, fp8_input)
        fp8_weight = cast(FP8Tensor, fp8_weight)
-       grad_output = grad_output.contiguous()
+       # grad_output = grad_output.contiguous()

        ctx.metadatas = cast(FP8LinearMeta, ctx.metadatas)
        if ctx.metadatas.input_grad is None:
@@ -206,7 +186,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[

        transposed_fp8_weight = fp8_weight.transpose_fp8()

-       grad_input_temp = torch.zeros(
+       grad_input_temp = torch.empty(
            fp8_grad_output.shape[0],
            transposed_fp8_weight.shape[0],
            device="cuda",
@@ -229,7 +209,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
        transposed_fp8_grad_output = fp8_grad_output.transpose_fp8()
        transposed_fp8_input = fp8_input.transpose_fp8()

-       grad_weight_temp = torch.zeros(
+       grad_weight_temp = torch.empty(
            transposed_fp8_input.shape[0],
            transposed_fp8_grad_output.shape[0],
            device="cuda",
@@ -251,16 +231,6 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
        assert grad_weight.dtype == recipe.accum_dtype
        # TODO(xrsrke): maintain a persistence metadata across training

-       # grad_weight = grad_weight.T.contiguous()
-       # orig_shape = grad_weight.shape
-       # grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
-
-       # grad_weight = grad_weight.T
-       # orig_shape = grad_weight.shape
-       # grad_weight = grad_weight.t().view(-1).reshape(orig_shape)
-
-       # NOTE: works
-       # grad_weight = grad_weight.reshape(grad_weight.T.shape)
        grad_weight = grad_weight.reshape(grad_weight.shape[::-1])

        # NOTE: if use gradient accumulation, then directly keep the high precision weights for later accumulate
@@ -299,5 +269,4 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
        # NOTE: sanity check
        assert isinstance(fp8_weight.grad, FP8Tensor)

-       # return grad_input.contiguous(), None, None, None, None, None, None
        return grad_input, None, None, None, None, None, None
