
Commit

fix bug in fp8 tp's unit tests
xrsrke committed Jun 6, 2024
1 parent c4efd94 commit 85cced8
Showing 3 changed files with 113 additions and 86 deletions.
5 changes: 2 additions & 3 deletions src/nanotron/fp8/distributed.py
@@ -7,9 +7,8 @@
 from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8
 from nanotron.parallel.parameters import NanotronParameter
 
-
-def all_reduce(tensor: FP8Tensor, op: dist.ReduceOp, group: dist.ProcessGroup, async_op: bool = False) -> FP8Tensor:
-    pass
+# def all_reduce(tensor: FP8Tensor, op: dist.ReduceOp, group: dist.ProcessGroup, async_op: bool = False) -> FP8Tensor:
+#     pass
 
 
 def all_gather(
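
Note: the commit stubs out all_reduce rather than implementing it; FP8 tensors cannot be reduced directly by the usual collectives, since a bytewise sum ignores their scaling metadata. A minimal sketch of how an FP8 all-reduce could be emulated, assuming convert_tensor_from_fp8 and FP8Tensor.set_data behave as their names suggest (illustrative, not the repository's implementation):

import torch
import torch.distributed as torch_dist

from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8


def all_reduce_fp8_sketch(tensor: FP8Tensor, op: torch_dist.ReduceOp, group: torch_dist.ProcessGroup) -> FP8Tensor:
    # Dequantize to FP16 first: collective reductions do not understand FP8 scaling metadata.
    hp = convert_tensor_from_fp8(tensor, tensor.fp8_meta, torch.float16)
    torch_dist.all_reduce(hp, op=op, group=group)
    # Re-quantize the reduced values back into the FP8 tensor's storage (hypothetical set_data).
    tensor.set_data(hp)
    return tensor
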
24 changes: 12 additions & 12 deletions src/nanotron/fp8/old_version.py
@@ -365,21 +365,21 @@ def forward(
         constants.DEBUG_FP8_OUTPUT = output
         constants.DEBUG_FP8_OUTPUT_COPY = output.clone(memory_format=torch.preserve_format)
 
-        torch.testing.assert_close(output, constants.DEBUG_FP8_OUTPUT_DIRECTLY_FROM_FP8_THAT_WORK, rtol=0.2, atol=0.2)
+        # torch.testing.assert_close(output, constants.DEBUG_FP8_OUTPUT_DIRECTLY_FROM_FP8_THAT_WORK, rtol=0.2, atol=0.2)
 
-        assert output.shape == accum_output.shape
+        # assert output.shape == accum_output.shape
 
-        from nanotron.fp8.tensor import convert_tensor_from_fp8
+        # # from nanotron.fp8.tensor import convert_tensor_from_fp8
 
-        torch.testing.assert_close(
-            output,
-            (
-                constants.DEBUG_FP8_INPUT.to(torch.float16)
-                @ convert_tensor_from_fp8(weight, weight.fp8_meta, torch.float16).T
-            ),
-            rtol=0.1,
-            atol=0.1,
-        )
+        # torch.testing.assert_close(
+        #     output,
+        #     (
+        #         constants.DEBUG_FP8_INPUT.to(torch.float16)
+        #         @ convert_tensor_from_fp8(weight, weight.fp8_meta, torch.float16).T
+        #     ),
+        #     rtol=0.1,
+        #     atol=0.1,
+        # )
 
         # output = _fp8_matmul_kernel_2(
         # # NOTE: that works
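
Every debug assert commented out above follows the same pattern: compare the FP8 GEMM's output against a ground truth recomputed from the dequantized weight. As a standalone hedged sketch of that pattern (names mirror the diff; shapes and metadata handling are assumptions):

import torch

from nanotron.fp8.tensor import convert_tensor_from_fp8


def check_fp8_linear_output(output: torch.Tensor, input_hp: torch.Tensor, fp8_weight, rtol: float = 0.1, atol: float = 0.1) -> None:
    # Dequantize the FP8 weight and recompute y = x @ W.T in FP16 as a reference.
    weight_hp = convert_tensor_from_fp8(fp8_weight, fp8_weight.fp8_meta, torch.float16)
    expected = input_hp.to(torch.float16) @ weight_hp.T
    torch.testing.assert_close(output, expected, rtol=rtol, atol=atol)
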
170 changes: 99 additions & 71 deletions tests/test_tensor_parallel.py
@@ -4,13 +4,9 @@
 import nanotron.fp8.distributed as dist
 import pytest
 import torch
-import torch.nn.functional as F
 from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
 from nanotron import constants
 from nanotron.distributed import get_global_rank
-from nanotron.fp8.constants import FP8LM_RECIPE
-from nanotron.fp8.dtypes import DTypes
-from nanotron.fp8.old_version import fp8_matmul_kernel
 from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
@@ -19,6 +15,7 @@
     FP8TensorParallelRowLinear,
     TensorParallelEmbedding,
 )
+from nanotron.sanity_checks import assert_tensor_synced_across_pg
 from torch import nn

@@ -173,6 +170,10 @@ def _test_column_linear(
         random_input = sharded_random_input
     else:
         ValueError(f"Unsupported mode: {tp_mode}")
+
+    dist.barrier()
+    assert_tensor_synced_across_pg(random_input, pg=parallel_context.tp_pg)
+
     # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage
     sharded_random_input = sharded_random_input.clone()
     sharded_random_input = sharded_random_input.contiguous()
@@ -301,7 +302,7 @@ def _test_row_linear(
     if async_communication:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 
-    torch.backends.cuda.matmul.allow_tf32 = False
+    # torch.backends.cuda.matmul.allow_tf32 = False
 
     # out_features = 3
     # in_features_per_rank = 2
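
For context on the line commented out above: on Ampere and newer GPUs, PyTorch may route float32 matmuls through TF32, which silently drops mantissa bits and can widen test tolerances. Disabling it only matters when the reference path computes in float32; the comparisons in this test now run in float16, which presumably makes the flag irrelevant here. The usual pattern when FP32 references are compared is:

import torch

torch.backends.cuda.matmul.allow_tf32 = False  # force IEEE FP32 matmuls
torch.backends.cudnn.allow_tf32 = False        # same for cuDNN convolutions
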
@@ -310,6 +311,7 @@
     in_features_per_rank = 32
 
     in_features = parallel_context.tp_pg.size() * in_features_per_rank
+    tp_rank = dist.get_rank(parallel_context.tp_pg)
 
     # Sharded
     row_linear = FP8TensorParallelRowLinear(
@@ -339,14 +341,14 @@
     #     ]
     # )
 
-    row_linear.weight.data.set_data(
-        reference_linear.weight[
-            :,
-            dist.get_rank(parallel_context.tp_pg)
-            * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
-            * in_features_per_rank,
-        ]
-    )
+    sharded_weight = reference_linear.weight[
+        :,
+        dist.get_rank(parallel_context.tp_pg)
+        * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
+        * in_features_per_rank,
+    ]
+    torch.save(sharded_weight, f"sharded_weight_{tp_rank}.pt")
+    row_linear.weight.data.set_data(sharded_weight)
 
     if with_bias is True:
         # broadcast bias from rank 0, and the other don't have bias
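
A row-parallel linear shards its weight along in_features (dim 1), which is exactly the slice taken above; the dumped shards should therefore concatenate back into the full reference weight. A quick hedged self-check of that invariant (illustrative helper, not part of the test):

import torch


def check_row_parallel_sharding(full_weight: torch.Tensor, shards: list[torch.Tensor]) -> None:
    # Concatenating the per-rank (out_features, in_features_per_rank) slices
    # along dim 1 must reproduce the full (out_features, in_features) weight.
    torch.testing.assert_close(torch.cat(shards, dim=1), full_weight, rtol=0, atol=0)
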
@@ -378,13 +380,17 @@
     # synchronize random_input across tp
     dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg)
+
+    assert_tensor_synced_across_pg(random_input, pg=parallel_context.tp_pg)
 
     # Row linear receives as input sharded input
     random_sharded_input = random_input[
         :,
         dist.get_rank(parallel_context.tp_pg)
         * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
         * in_features_per_rank,
     ]
+    torch.save(random_input, f"random_input_{tp_rank}.pt")
+    torch.save(random_sharded_input, f"random_sharded_input_tp{tp_rank}.pt")
 
     start_idx = dist.get_rank(parallel_context.tp_pg) * in_features_per_rank
     end_idx = (dist.get_rank(parallel_context.tp_pg) + 1) * in_features_per_rank
@@ -399,52 +405,56 @@
     # Test that we get the same output after forward pass
     # TODO @kunhao: We may want to have our custom error type
     reference_output = ReferenceLinear.apply(random_input, reference_linear.weight, reference_linear.bias)
-    local_output = F.linear(
-        random_sharded_input.to(torch.float16).contiguous(),
-        reference_linear.weight[sharded_portion].to(torch.float16).contiguous(),
-    )
+    torch.save(reference_linear.weight, f"reference_weight_{tp_rank}.pt")
+    torch.save(reference_output, f"reference_output_{tp_rank}.pt")
 
-    fp8_input = FP8Tensor(
-        random_sharded_input,
-        dtype=FP8LM_RECIPE.linear.input.dtype,
-        interval=FP8LM_RECIPE.linear.input.interval,
-        # is_delayed_scaling=FP8LM_RECIPE.linear.input.is_delayed_scaling,
-    )
-    fp8_weight = row_linear.weight.data
-
-    constants.DEBUB_FP8_INPUT_THAT_WORK = fp8_input
-    constants.DEBUB_FP8_WEIGHT_THAT_WORK = fp8_weight
-
-    direct_output = fp8_matmul_kernel(
-        # NOTE: that works
-        mat_a=fp8_weight,
-        transpose_a=True,
-        mat_b=fp8_input,
-        transpose_b=False,
-        output=torch.zeros(local_output.shape, dtype=torch.float16, device="cuda"),
-        use_split_accumulator=FP8LM_RECIPE.linear.split_accumulator.output,
-        accum_qtype=DTypes.KFLOAT16,
-    )
-    constants.DEBUG_FP8_OUTPUT_DIRECTLY_FROM_FP8_THAT_WORK = direct_output
+    # local_output = F.linear(
+    #     random_sharded_input.to(torch.float16).contiguous(),
+    #     reference_linear.weight[sharded_portion].to(torch.float16).contiguous(),
+    # )
+
+    # fp8_input = FP8Tensor(
+    #     random_sharded_input,
+    #     dtype=FP8LM_RECIPE.linear.input.dtype,
+    #     interval=FP8LM_RECIPE.linear.input.interval,
+    #     # is_delayed_scaling=FP8LM_RECIPE.linear.input.is_delayed_scaling,
+    # )
+    # fp8_weight = row_linear.weight.data
+
+    # constants.DEBUB_FP8_INPUT_THAT_WORK = fp8_input
+    # constants.DEBUB_FP8_WEIGHT_THAT_WORK = fp8_weight
+
+    # direct_output = fp8_matmul_kernel(
+    #     # NOTE: that works
+    #     mat_a=fp8_weight,
+    #     transpose_a=True,
+    #     mat_b=fp8_input,
+    #     transpose_b=False,
+    #     output=torch.zeros(local_output.shape, dtype=torch.float16, device="cuda"),
+    #     use_split_accumulator=FP8LM_RECIPE.linear.split_accumulator.output,
+    #     accum_qtype=DTypes.KFLOAT16,
+    # )
+    # constants.DEBUG_FP8_OUTPUT_DIRECTLY_FROM_FP8_THAT_WORK = direct_output
 
-    torch.testing.assert_close(direct_output, local_output, rtol=0.2, atol=0.2)
+    # torch.testing.assert_close(direct_output, local_output, rtol=0.2, atol=0.2)
 
-    reference_output = reference_linear(random_input)
+    # reference_output = reference_linear(random_input)
     sharded_output = row_linear(random_sharded_input)
+    torch.save(sharded_output, f"sharded_output_{tp_rank}.pt")
 
-    torch.cuda.synchronize()
+    # torch.cuda.synchronize()
 
     # manual_output = F.linear(constants.DEBUG_FP8_INPUT.to(torch.float16), convert_tensor_from_fp8(constants.DEBUG_FP8_WEIGHT, constants.DEBUG_FP8_WEIGHT.fp8_meta, torch.float16))
-    manual_output = F.linear(
-        random_sharded_input.to(torch.float16),
-        convert_tensor_from_fp8(constants.DEBUG_FP8_WEIGHT, constants.DEBUG_FP8_WEIGHT.fp8_meta, torch.float16),
-    )
-    constants.REF_MANUAL_OUTPUT = manual_output
+    # manual_output = F.linear(
+    #     random_sharded_input.to(torch.float16),
+    #     convert_tensor_from_fp8(constants.DEBUG_FP8_WEIGHT, constants.DEBUG_FP8_WEIGHT.fp8_meta, torch.float16),
+    # )
+    # constants.REF_MANUAL_OUTPUT = manual_output
 
-    # assert sharded_output.dtype == torch.float16
-    # # NOTE(xrsrke): we expect the output is a raw torch.Tensor, not FP8Paramter, or NanotronParameter
-    # assert sharded_output.__class__ == torch.Tensor
-    # assert sharded_output.requires_grad is True
+    assert sharded_output.dtype == torch.float16
+    # NOTE(xrsrke): we expect the output is a raw torch.Tensor, not FP8Paramter, or NanotronParameter
+    assert sharded_output.__class__ == torch.Tensor
+    assert sharded_output.requires_grad is True
 
     # torch.testing.assert_close(
     #     row_linear.weight.data.orig_data.to(torch.float16),
    #     rtol=0.1,
    #     atol=0.1,
    # )
 
-    torch.testing.assert_close(random_sharded_input, constants.DEBUG_FP8_INPUT, rtol=0.2, atol=0.2)
+    # torch.testing.assert_close(random_sharded_input, constants.DEBUG_FP8_INPUT, rtol=0.2, atol=0.2)
 
     # local_output = F.linear(random_sharded_input.to(torch.float16), reference_linear.weight[sharded_portion].to(torch.float16))
 
-    torch.testing.assert_close(manual_output, local_output, rtol=0.2, atol=0.2)
-
-    torch.testing.assert_close(
-        fp8_matmul_kernel(
-            # NOTE: that works
-            mat_a=constants.DEBUG_FP8_WEIGHT,
-            transpose_a=True,
-            mat_b=constants.DEBUG_FP8_INPUT_AFTER_QUANT,
-            transpose_b=False,
-            output=torch.zeros(manual_output.shape, dtype=torch.float16, device="cuda"),
-            use_split_accumulator=FP8LM_RECIPE.linear.split_accumulator.output,
-            accum_qtype=DTypes.KFLOAT16,
-        ),
-        local_output,
-        rtol=0.2,
-        atol=0.2,
-    )
+    # torch.testing.assert_close(manual_output, local_output, rtol=0.2, atol=0.2)
 
-    torch.testing.assert_close(constants.DEBUG_FP8_OUTPUT, manual_output, rtol=0.2, atol=0.2)
-    torch.testing.assert_close(constants.DEBUG_FP8_OUTPUT, local_output, rtol=0.2, atol=0.2)
+    # torch.testing.assert_close(
+    #     fp8_matmul_kernel(
+    #         # NOTE: that works
+    #         mat_a=constants.DEBUG_FP8_WEIGHT,
+    #         transpose_a=True,
+    #         mat_b=constants.DEBUG_FP8_INPUT_AFTER_QUANT,
+    #         transpose_b=False,
+    #         output=torch.zeros(manual_output.shape, dtype=torch.float16, device="cuda"),
+    #         use_split_accumulator=FP8LM_RECIPE.linear.split_accumulator.output,
+    #         accum_qtype=DTypes.KFLOAT16,
+    #     ),
+    #     local_output,
+    #     rtol=0.2,
+    #     atol=0.2,
+    # )
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         sharded_reference_output = reference_output
@@ … @@
     else:
         raise ValueError(f"Unsupported mode: {tp_mode}")
 
+    # torch.testing.assert_close(
+    #     fp8_matmul_kernel(
+    #         # NOTE: that works
+    #         mat_a=constants.DEBUG_FP8_WEIGHT,
+    #         transpose_a=True,
+    #         mat_b=constants.DEBUG_FP8_INPUT_AFTER_QUANT,
+    #         transpose_b=False,
+    #         output=torch.zeros(reference_output.shape, dtype=torch.float16, device="cuda"),
+    #         use_split_accumulator=FP8LM_RECIPE.linear.split_accumulator.output,
+    #         accum_qtype=DTypes.KFLOAT16,
+    #     ),
+    #     reference_output,
+    #     rtol=0.2,
+    #     atol=0.2,
+    # )
+
+    # torch.testing.assert_close(constants.DEBUG_FP8_OUTPUT_COPY, reference_output.to(torch.float16), rtol=0.2, atol=0.2)
+    torch.save(constants.DEBUG_FP8_OUTPUT_COPY, f"sharded_output_before_allreduce_{tp_rank}.pt")
+    # torch.testing.assert_close(sharded_output, sharded_reference_output, rtol=0.2, atol=0.2)
+    # torch.testing.assert_close(constants.DEBUG_FP8_OUTPUT, local_output, rtol=0.2, atol=0.2)
 
     # TODO @thomasw21: Tune tolerance
     torch.testing.assert_close(sharded_output, sharded_reference_output.to(torch.float16), rtol=0.2, atol=0.2)

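The torch.save calls added throughout this commit dump per-rank tensors so a failing run can be replayed offline. A hedged sketch of such a replay for the ALL_REDUCE path: each rank's pre-all-reduce partial is x_r @ W_r.T, and their sum should reconstruct the reference output (file names follow the diff; the world size and tolerances are assumptions):

import torch

world_size = 2  # assumption: the test ran with tp=2

# Per-rank partial products saved before the all-reduce.
partials = [torch.load(f"sharded_output_before_allreduce_{rank}.pt") for rank in range(world_size)]
reference = torch.load("reference_output_0.pt")

# Row-parallel identity: sum_r (x_r @ W_r.T) == x @ W.T.
combined = sum(partials)
torch.testing.assert_close(combined.to(torch.float16), reference.to(torch.float16), rtol=0.2, atol=0.2)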
