
Support power of 2 scaling factors in float8 training and use e4m3 everywhere #1670

Merged: 16 commits, merged Feb 10, 2025

Changes from 4 commits
4 changes: 3 additions & 1 deletion test/float8/test_base.py
@@ -164,7 +164,8 @@ def test_transpose(self):

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("axiswise_dim", [0, -1])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
@pytest.mark.parametrize("power_of_2_scale", [True, False])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim, power_of_2_scale):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
@@ -173,6 +174,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
power_of_2_scale=power_of_2_scale,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
18 changes: 12 additions & 6 deletions test/float8/test_compile.py
@@ -45,11 +45,7 @@
hp_tensor_to_float8_delayed,
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
from torchao.float8.float8_utils import config_has_stateful_scaling
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,13 +416,21 @@ def test_sync_amax_func_cuda_graph_success():
torch.float16,
],
)
def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
@pytest.mark.parametrize(
"power_of_2_scale",
[
True,
False,
],
)
def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool):
scaling_type_weight = ScalingType.DYNAMIC
torch.manual_seed(42)
hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
hp_tensor2 = hp_tensor1.detach().clone()
float8_config = Float8LinearConfig(
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
power_of_2_scale=power_of_2_scale,
)
linear_mm_config = LinearMMConfig(
# output
@@ -456,13 +460,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
power_of_2_scale=float8_config.power_of_2_scale,
)
torch._dynamo.reset()
float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
hp_tensor2,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
power_of_2_scale=float8_config.power_of_2_scale,
)
assert torch.equal(float8_eager._scale, float8_compile._scale)
assert torch.equal(float8_eager._data, float8_compile._data)
9 changes: 9 additions & 0 deletions torchao/float8/config.py
@@ -234,6 +234,13 @@ class Float8LinearConfig:
# tests so that the warning does not spam the CI stdout.
force_recompute_fp8_weight_in_bwd: bool = False

# If this option is enabled, the scaling factor used for float8 quantization
# will be rounded down to the nearest power of 2. This has been shown to help
# reduce quantization error by avoiding rounding errors when multiplying/dividing
# by the scaling factor, as well as ensuring large values are quantized to the
# same value in the forward pass as the backward passes.
power_of_2_scale: bool = False

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
@@ -336,6 +343,8 @@ def recipe_name_to_linear_config(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
# enable power of 2 scaling factors by default for row-wise scaling
power_of_2_scale=True,
)

elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
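To illustrate the comment above about avoiding rounding error: multiplying and then dividing by a power of 2 only changes the exponent bits and round-trips exactly, while an arbitrary scale generally does not. A minimal sketch, not part of this diff (the tensor and scale values are made up):

import torch

x = torch.randn(1024, dtype=torch.float32)
arbitrary_scale = torch.tensor(3.7)
# round the scale down to the nearest power of 2, as the new config option does
pow2_scale = torch.exp2(torch.floor(torch.log2(arbitrary_scale)))  # 2.0

# a power-of-2 scale only shifts the exponent, so the round trip is exact
# (barring overflow/underflow at the extremes of the dynamic range)
assert torch.equal((x * pow2_scale) / pow2_scale, x)
# an arbitrary scale incurs two roundings, so some elements usually differ
print(torch.equal((x * arbitrary_scale) / arbitrary_scale, x))  # typically False

The recipe change above also turns this rounding on by default for the row-wise scaling recipe.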
6 changes: 6 additions & 0 deletions torchao/float8/float8_linear.py
@@ -96,6 +96,7 @@ def forward(
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_input.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

if tensor_already_casted_to_fp8(weight_hp_t):
@@ -112,6 +113,7 @@ def forward(
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_weight.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

# the reshapes are needed in order to make the shapes compatible with
@@ -151,6 +153,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_grad_output.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

if tensor_already_casted_to_fp8(weight_hp_t):
@@ -181,6 +184,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_weight_for_grad_input.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

grad_input = torch.mm(
@@ -216,6 +220,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_grad_output_for_grad_weight.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

if tensor_already_casted_to_fp8(input_hp_reshaped):
@@ -233,6 +238,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_input_for_grad_weight.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

grad_weight = torch.mm(
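For context, a minimal end-to-end sketch, not part of the diff, of how the flag reaches the forward/backward code above (it is read from c.power_of_2_scale on the linear config). Assumes a CUDA device with float8 support; the toy model and shapes are illustrative only:

import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# toy model; float8 matmuls want dims that are multiples of 16
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()

# enable power-of-2 scale rounding for all dynamically scaled casts
config = Float8LinearConfig(power_of_2_scale=True)
convert_to_float8_training(model, config=config)

x = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)
model(x).sum().backward()  # forward and backward casts both use power-of-2 scales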
3 changes: 3 additions & 0 deletions torchao/float8/float8_scaling_utils.py
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
power_of_2_scale: bool = False,
Contributor (Author): Note for reviewer: this param list is getting pretty long, and 4 of the 9 params can be derived from the Float8LinearConfig. Any thoughts on refactoring to pass in the Float8LinearConfig directly?

Contributor: Sounds reasonable.

Contributor (Author): Cool, I'll do that in a follow-up so Less can begin scale testing after we merge this ASAP.
) -> Float8Tensor:
"""
Given a high precision tensor `hp_tensor`,
@@ -51,6 +52,7 @@ def hp_tensor_to_float8_dynamic(
the 3 fwd/bwd gemms of linear
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
"""
scale = tensor_to_scale(
hp_tensor,
Expand All @@ -59,6 +61,7 @@ def hp_tensor_to_float8_dynamic(
device_mesh,
scaling_granularity,
axiswise_dim,
power_of_2_scale,
)
return hp_tensor_and_scale_to_float8(
hp_tensor,
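A small usage sketch of this function with the new parameter, mirroring the updated tests and not part of the diff; with power_of_2_scale=True, rounding the resulting scale down to a power of 2 again is a no-op:

import torch
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig

x = torch.randn(16, 32, dtype=torch.bfloat16)
x_fp8 = hp_tensor_to_float8_dynamic(
    x,
    torch.float8_e4m3fn,  # e4m3 everywhere, per the PR title
    LinearMMConfig(),
    power_of_2_scale=True,
)
s = x_fp8._scale
assert torch.equal(torch.exp2(torch.floor(torch.log2(s))), s)  # already a power of 2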
33 changes: 23 additions & 10 deletions torchao/float8/float8_utils.py
@@ -10,11 +10,7 @@
import torch.distributed as dist
from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce

from torchao.float8.config import (
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -33,11 +29,14 @@


@torch.no_grad()
def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, power_of_2_scale: bool = False
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
"""
# torch.compile and eager show different numerics for 1.0 / float32,
# upcast to float64 to ensure same numeric between compile and eager
Expand All @@ -46,7 +45,9 @@ def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

if power_of_2_scale:
# rounds down to the nearest power of 2.
res = torch.exp2(torch.floor(torch.log2(res)))
Contributor: This should be the same as setting the mantissa to all-zeroes (maybe with some special handling for inf/nan), and can be implemented with bit shifting. Do you want to try to see if that resolves the regression?

Contributor: Didn't test, but something like the following for float32:

res = res.view(torch.uint32_t)
res = (res >> 23) << 23
res = res.view(torch.float)

Contributor (Author): uint32 doesn't support bitshift ops apparently, so I had to use int32. Unit tests pass though, and the TPS regression is gone. Will the sign bit affect anything? I did some manual tests in the interpreter and rounding seemed to work as expected.

[rank0]:2025-02-05 16:11:30,663 - root - INFO - step:  1  loss:  8.2105  memory:  9.69GiB(10.20%)  tps: 610  mfu: 0.33%
[rank0]:2025-02-05 16:11:30,663 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 16:11:30,896 - root - INFO - step:  2  loss:  9.2258  memory: 11.02GiB(11.60%)  tps: 70,207  mfu: 37.73%
[rank0]:2025-02-05 16:11:31,129 - root - INFO - step:  3  loss:  8.5120  memory: 11.02GiB(11.60%)  tps: 70,377  mfu: 37.82%
[rank0]:2025-02-05 16:11:31,361 - root - INFO - step:  4  loss: 11.7253  memory: 11.02GiB(11.60%)  tps: 70,885  mfu: 38.10%
[rank0]:2025-02-05 16:11:31,591 - root - INFO - step:  5  loss:  9.3686  memory: 11.02GiB(11.60%)  tps: 71,365  mfu: 38.35%
[rank0]:2025-02-05 16:11:31,823 - root - INFO - step:  6  loss:  8.5610  memory: 11.02GiB(11.60%)  tps: 70,634  mfu: 37.96%
[rank0]:2025-02-05 16:11:32,059 - root - INFO - step:  7  loss:  7.7763  memory: 11.02GiB(11.60%)  tps: 69,681  mfu: 37.45%
[rank0]:2025-02-05 16:11:32,287 - root - INFO - step:  8  loss:  7.4649  memory: 11.02GiB(11.60%)  tps: 71,963  mfu: 38.68%
[rank0]:2025-02-05 16:11:32,517 - root - INFO - step:  9  loss:  7.2956  memory: 11.02GiB(11.60%)  tps: 71,188  mfu: 38.26%
[rank0]:2025-02-05 16:11:32,749 - root - INFO - step: 10  loss:  7.1085  memory: 11.02GiB(11.60%)  tps: 70,748  mfu: 38.02%

return res.to(torch.float32)
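A runnable sketch of the bit-manipulation variant discussed in the thread above (an illustration of the idea, not necessarily the exact code that was merged): because the scale is positive and finite, viewing it as int32 and clearing the 23 mantissa bits rounds it down to the nearest power of 2.

import torch

def round_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # assumes a positive, finite float32 scale, as produced by amax_to_scale
    assert scale.dtype == torch.float32
    bits = scale.view(torch.int32)   # reinterpret the raw float32 bits
    bits = (bits >> 23) << 23        # zero out the mantissa, keep sign and exponent
    return bits.view(torch.float32)

print(round_down_to_power_of_2(torch.tensor([34.46, 0.3])))  # tensor([32.0000, 0.2500])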


@@ -119,21 +120,33 @@ def tensor_to_amax(

@torch.no_grad()
def tensor_to_scale(
x: torch.Tensor,
hp_tensor: torch.Tensor,
float8_dtype: torch.dtype,
reduce_amax: bool = False,
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
power_of_2_scale: bool = False,
) -> torch.Tensor:
"""
Compute scaling factor for the given high precision tensor.

Args:
hp_tensor: high precision tensor
float8_dtype: the float8 dtype to use
reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
"""
amax = tensor_to_amax(
x,
hp_tensor,
reduce_amax,
device_mesh,
scaling_granularity,
axiswise_dim,
)
return amax_to_scale(amax, float8_dtype)
return amax_to_scale(amax, float8_dtype, power_of_2_scale=power_of_2_scale)


def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
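A short worked example against the updated amax_to_scale signature shown in this diff (values chosen for illustration): for e4m3 the scale is finfo.max / amax = 448 / 13 ≈ 34.46, which rounds down to 32 when power_of_2_scale is set.

import torch
from torchao.float8.float8_utils import amax_to_scale

amax = torch.tensor(13.0)
s_plain = amax_to_scale(amax, torch.float8_e4m3fn)                         # ~34.46
s_pow2 = amax_to_scale(amax, torch.float8_e4m3fn, power_of_2_scale=True)   # 32.0
print(s_plain.item(), s_pow2.item())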