Create separate float8 tensor subclass #1636

Closed · wants to merge 1 commit
58 changes: 0 additions & 58 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -337,64 +337,6 @@ def from_hp_to_intx_static(
dtype=input_float.dtype,
)

@classmethod
def from_hp_to_floatx(
cls,
input_float: torch.Tensor,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
_layout: Layout,
scale_dtype: Optional[torch.dtype] = None,
):
"""Convert a high precision tensor to a float8 quantized tensor."""
if target_dtype in FP8_TYPES:
return cls.from_hp_to_intx(
input_float=input_float,
mapping_type=MappingType.SYMMETRIC,
block_size=block_size,
target_dtype=target_dtype,
quant_min=math.ceil(torch.finfo(target_dtype).min),
quant_max=math.ceil(torch.finfo(target_dtype).max),
eps=torch.finfo(torch.float32).eps,
scale_dtype=scale_dtype,
zero_point_dtype=None,
preserve_zero=True,
zero_point_domain=None,
_layout=_layout,
use_hqq=False,
)
else:
raise NotImplementedError(
f"Unsupported dtype {target_dtype} for from_hp_to_floatx"
)

@classmethod
def from_hp_to_floatx_static(
cls,
input_float: torch.Tensor,
scale: torch.Tensor,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
_layout: Layout,
):
"""Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters."""
if target_dtype in FP8_TYPES:
return cls.from_hp_to_intx_static(
input_float=input_float,
scale=scale,
zero_point=None,
block_size=block_size,
target_dtype=target_dtype,
quant_min=math.ceil(torch.finfo(target_dtype).min),
quant_max=math.ceil(torch.finfo(target_dtype).max),
zero_point_domain=None,
_layout=_layout,
)
else:
raise NotImplementedError(
f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
)

@classmethod
def from_hp_to_fpx(
cls,
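The removed constructors above delegated the float8 case to the generic affine-quantization path; the scheme itself is plain symmetric quantization. A minimal standalone sketch of that math (per-tensor scale, torch.float8_e4m3fn chosen as an example dtype; the helper name and details are illustrative, not the torchao implementation):

import torch

def quantize_to_float8_sketch(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # symmetric, per-tensor: map the absolute maximum of x onto the fp8 range
    fp8_max = torch.finfo(dtype).max
    scale = x.abs().amax().float().clamp(min=torch.finfo(torch.float32).eps) / fp8_max
    # rescale, clamp into the representable range, then cast down to float8
    q = (x / scale).clamp(min=-fp8_max, max=fp8_max).to(dtype)
    return q, scale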
239 changes: 147 additions & 92 deletions torchao/dtypes/floatx/float8_layout.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import math

import torch
from torch.utils._python_dispatch import (
@@ -11,15 +12,21 @@
AffineQuantizedTensor,
register_layout,
)
from torchao.dtypes.nf4tensor import implements
from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape
from torchao.float8.inference import (
Float8MMConfig,
_is_rowwise_scaled,
addmm_float8_unwrapped_inference,
preprocess_data,
)
from torchao.utils import _is_float8_type, fill_defaults

from torchao.utils import _is_float8_type, fill_defaults, TorchAOBaseTensor
from torchao.quantization.quant_primitives import (
FP8_TYPES,
MappingType,
choose_qparams_affine_float8,
quantize_affine_float8,
)
aten = torch.ops.aten


@@ -34,13 +41,16 @@ class Float8Layout(Layout):
mm_config: Optional[Float8MMConfig] = None


@register_layout(Float8Layout)
class Float8AQTTensorImpl(AQTTensorImpl):
class Float8Tensor(TorchAOBaseTensor):
"""
TensorImpl for float8 layout affine quantized tensor
Float8Tensor is a torch.Tensor subclass that stores quantized float8 data
together with the scale used to quantize the original high precision tensor.

Note: technically we should not create a new layout for float8 we should merge this into
plain layout
Attributes:
float8_data (torch.Tensor): The float8 data tensor.
scale (torch.Tensor): The scale tensor.
transposed (bool): Whether the tensor is transposed or not.
_layout (Layout): The layout of the tensor.
"""

float8_data: torch.Tensor
@@ -52,7 +62,7 @@ def __new__(
float8_data: torch.Tensor,
scale: torch.Tensor,
transposed: bool,
_layout: Layout,
_layout: Layout = Float8Layout(),
):
kwargs = {}
kwargs["device"] = float8_data.device
@@ -69,7 +79,7 @@ def __init__(
float8_data: torch.Tensor,
scale: torch.Tensor,
transposed: bool,
_layout: Layout,
_layout: Layout = Float8Layout(),
):
self.float8_data = float8_data
self.scale = scale
@@ -108,84 +118,20 @@ def __tensor_unflatten__(
) = tensor_attributes
return cls(float8_data, scale, transposed, _layout)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
elif func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
elif func is aten.t.default:
"""we don't need to repack the weight and just rely on external
shape being changed and record the status of transpose/no-transpose
"""
args[0].transposed = not args[0].transposed
return return_and_correct_aliasing(func, args, kwargs, args[0])
elif func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
# TODO: scale replecation should be dependent on block size
if self.scale.ndim == 1:
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0]._apply_fn_to_data(
lambda x: aten.slice.Tensor(x, dim, start, end, step)
),
)
elif self.scale.ndim == 0:
return return_and_correct_aliasing(
func,
args,
kwargs,
Float8AQTTensorImpl(
aten.slice.Tensor(self.float8_data, dim, start, end, step),
self.scale,
None,
self._layout,
),
)
else:
raise NotImplementedError(
f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
)
elif dim == 1:
return return_and_correct_aliasing(
func,
args,
kwargs,
Float8AQTTensorImpl(
aten.slice.Tensor(
self.float8_data, dim, start, end, step
).contiguous(),
self.scale,
None,
self._layout,
),
)
else:
raise NotImplementedError(
f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
)
else:
raise NotImplementedError(
f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl
def __repr__(self):
float8_data, scale, _ = self.get_plain()
_layout = self.get_layout()
return (
f"{self.__class__.__name__}(\n"
f"float8_data={float8_data},\n"
f"scale={scale},\n"
f"transposed={self.transposed}, "
f"_layout={_layout})"
)

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.float8_data, self.scale, None

def get_layout(self) -> Layout:
return self._layout

@classmethod
def from_plain(
cls,
@@ -203,15 +149,120 @@ def from_plain(
), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}"
return cls(data, scale, False, _layout)

def __repr__(self):
float8_data, scale, _ = self.get_plain()
_layout = self.get_layout()
return (
f"{self.__class__.__name__}(\n"
f"float8_data={float8_data},\n"
f"scale={scale},\n"
f"transposed={self.transposed}, "
f"_layout={_layout})"
@classmethod
def from_hp_to_floatx(
cls,
input_float: torch.Tensor,
target_dtype: torch.dtype,
_layout: Layout = Float8Layout(),
):
"""Convert a high precision tensor to a float8 quantized tensor."""
if target_dtype not in FP8_TYPES:
raise NotImplementedError(
f"Unsupported dtype {target_dtype} for from_hp_to_floatx"
)
scale = choose_qparams_affine_float8(
input_float,
target_dtype,
)
float_data = quantize_affine_float8(
input_float,
scale,
target_dtype,
)

return cls(
float_data,
scale,
False,
_layout,
)

@classmethod
def from_hp_to_floatx_static(
cls,
input_float: torch.Tensor,
scale: torch.Tensor,
target_dtype: torch.dtype,
_layout: Layout,
):
"""Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters."""
if target_dtype not in FP8_TYPES:
raise NotImplementedError(
f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
)
float_data = quantize_affine_float8(
input_float,
scale,
target_dtype,
)

return cls(
float_data,
scale,
False,
_layout,
)

__torch_function__ = torch._C._disabled_torch_function_impl
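
A minimal usage sketch of the two constructors above, assuming this diff is applied (the import path follows the file being modified, and the hand-computed static scale is only for illustration):

import torch
from torchao.dtypes.floatx.float8_layout import Float8Layout, Float8Tensor

weight = torch.randn(256, 512, dtype=torch.bfloat16)

# dynamic: the scale is derived from the input via choose_qparams_affine_float8
w_fp8 = Float8Tensor.from_hp_to_floatx(weight, torch.float8_e4m3fn)

# static: reuse a precomputed per-tensor scale (e.g. from calibration)
scale = weight.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max
w_fp8_static = Float8Tensor.from_hp_to_floatx_static(
    weight, scale, torch.float8_e4m3fn, Float8Layout()
)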


@implements(aten.t.default)
def _(func, types, args, kwargs):
"""we don't need to repack the weight and just rely on external
shape being changed and record the status of transpose/no-transpose
"""
args[0].transposed = not args[0].transposed
return return_and_correct_aliasing(func, args, kwargs, args[0])


@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
# TODO: scale replication should be dependent on block size
if self.scale.ndim == 1:
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0]._apply_fn_to_data(
lambda x: aten.slice.Tensor(x, dim, start, end, step)
),
)
elif self.scale.ndim == 0:
return return_and_correct_aliasing(
func,
args,
kwargs,
Float8Tensor(
aten.slice.Tensor(self.float8_data, dim, start, end, step),
self.scale,
self.transposed,
self._layout,
),
)
else:
raise NotImplementedError(
f"Float8Tensor dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
)
elif dim == 1:
return return_and_correct_aliasing(
func,
args,
kwargs,
Float8Tensor(
aten.slice.Tensor(
self.float8_data, dim, start, end, step
).contiguous(),
self.scale,
self.transposed,
self._layout,
),
)
else:
raise NotImplementedError(
f"Float8Tensor dispatch: attempting to run {func}, with dim={dim}, that is not supported"
)
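
For illustration, with a row-wise (1-D) scale the slice overload above narrows both the float8 payload and the scale along dim 0, but only the payload along dim 1. A sketch under the assumption that the implements registration routes aten.slice to this overload (shapes are hypothetical):

import torch
from torchao.dtypes.floatx.float8_layout import Float8Tensor

# hypothetical row-wise quantized weight: payload (256, 512), scale (256,)
w_rw = Float8Tensor(
    torch.randn(256, 512).to(torch.float8_e4m3fn),
    torch.rand(256) + 0.5,
    False,
)

rows = torch.ops.aten.slice.Tensor(w_rw, 0, 0, 128)  # payload -> (128, 512), scale -> (128,)
cols = torch.ops.aten.slice.Tensor(w_rw, 1, 0, 256)  # payload -> (256, 256), scale unchanged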


@@ -317,3 +368,7 @@ def _linear_fp_act_fp8_weight_impl(
bias: Optional[torch.Tensor],
):
return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)


to_quantized_float8 = Float8Tensor.from_hp_to_floatx
to_quantized_float8_static = Float8Tensor.from_hp_to_floatx_static
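
The fallback linear path shown above dequantizes the weight and calls a regular linear. Under the symmetric scheme sketched earlier, dequantization is just an upcast plus a rescale; a hedged helper for illustration (not the torchao dequantize implementation):

import torch
import torch.nn.functional as F

def dequantize_float8_sketch(
    float8_data: torch.Tensor, scale: torch.Tensor, out_dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor:
    # inverse of quantization: upcast the float8 payload and multiply the scale back in
    return float8_data.to(out_dtype) * scale.to(out_dtype)

# e.g. for a weight produced by to_quantized_float8(w, torch.float8_e4m3fn):
# y = F.linear(x, dequantize_float8_sketch(w_fp8.float8_data, w_fp8.scale), bias)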