Move Hqq quantization to subclass #1604

Draft
wants to merge 3 commits into base: main
Changes from all commits
4 changes: 3 additions & 1 deletion torchao/dtypes/__init__.py
@@ -4,12 +4,12 @@
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
# experimental, will be merged into floatx in the future
to_affine_quantized_fpx,
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .floatx import (
Float8Layout,
to_affine_quantized_fpx,
)
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
@@ -22,6 +22,7 @@
SemiSparseLayout,
TensorCoreTiledLayout,
UintxLayout,
to_hqq_quantized_intx,
to_marlinqqq_quantized_intx,
)
from .utils import (
@@ -52,4 +53,5 @@
"MarlinQQQLayout",
"Int4CPULayout",
"CutlassInt4PackedLayout",
"to_hqq_quantized_intx",
]
181 changes: 49 additions & 132 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -14,12 +14,8 @@
MappingType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_floatx,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
quantize_affine,
quantize_affine_floatx,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
@@ -36,7 +32,6 @@
"to_affine_quantized_floatx",
"to_affine_quantized_intx_static",
"to_affine_quantized_floatx_static",
"to_affine_quantized_fpx",
]


@@ -126,40 +121,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
if output_dtype is None:
output_dtype = self.dtype

from torchao.dtypes.floatx import FloatxTensorCoreLayout

if isinstance(self._layout, FloatxTensorCoreLayout):
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
int_data,
scale,
self._layout.ebits,
self._layout.mbits,
output_dtype=output_dtype,
)
else:
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout

if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq
if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq

def __tensor_flatten__(self):
return ["tensor_impl"], [
@@ -210,71 +193,34 @@ def from_hp_to_intx(
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)

if use_hqq:
assert (
zero_point_domain == ZeroPointDomain.FLOAT
and mapping_type == MappingType.ASYMMETRIC
and quant_min == 0
), "Invalid input parameters for HQQ quantization."
nbits = int(math.log2(quant_max + 1))
axis = 1 if (block_size[0] == 1) else 0
group_size = max(block_size)
compute_dtype = (
zero_point_dtype
if (zero_point_dtype is not None)
else input_float.dtype
)
device = input_float.device
from torchao.dtypes.uintx import TensorCoreTiledLayout

data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
input_float,
nbits=nbits,
group_size=group_size,
axis=axis,
compute_dtype=compute_dtype,
device=device,
verbose=False,
raw_output=not isinstance(
_layout, (TensorCoreTiledLayout, PlainLayout)
),
# raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
# note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
# zero is preserved.
# TODO uncouple preserve_zero and conversion of zero_point to TensorCoreTiledLayout version
# TODO move the conversion of zero_point out of quant_primitives and into TensorCoreTiledLayout.from_plain
# TODO change PlainLayout to use raw_output.
)
data = data.to(target_dtype)
else:
scale, zero_point = choose_qparams_affine(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain,
)
# choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None if its domain is None
# TODO should probably consolidate ZeroPointDomain.NONE and None
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
data = quantize_affine(
input_float,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
# Note: output will be uint8 tensor for sub byte tensors for now
scale, zero_point = choose_qparams_affine(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain,
)
# choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None if its domain is None
# TODO should probably consolidate ZeroPointDomain.NONE and None
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
data = quantize_affine(
input_float,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
# Note: output will be uint8 tensor for sub byte tensors for now

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
@@ -395,33 +341,6 @@ def from_hp_to_floatx_static(
f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
)

@classmethod
def from_hp_to_fpx(
cls,
input_float: torch.Tensor,
_layout: Layout,
):
from torchao.dtypes.floatx import FloatxTensorCoreLayout

assert isinstance(
_layout, FloatxTensorCoreLayout
), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}"
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
# per axis quantization, where axis = 1
block_size = list(input_float.shape)
block_size[1] = 1

ebits, mbits = _layout.ebits, _layout.mbits
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
floatx_packed = _layout.post_process(floatx_unpacked)

tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

@property
def _layout(self) -> Layout:
return self.tensor_impl._layout
@@ -477,8 +396,6 @@ def _apply_fn_to_data(self, fn):
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
# experimental will be merged in to floatx
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
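
Usage sketch (not part of the diff): with the use_hqq branch removed, from_hp_to_intx now covers only the plain affine path; HQQ gets its own constructor under torchao/dtypes/uintx (the new hqq_tensor.py is not shown on this page). A minimal example of the remaining path, assuming the to_affine_quantized_intx signature on main; shapes and group size are illustrative.

import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

# Hypothetical example: group-wise asymmetric 4-bit quantization of a linear
# weight via the plain affine path (one scale/zero_point per 128 columns).
# Sub-byte values are stored in a uint8 tensor, as noted in the diff.
weight = torch.randn(1024, 1024, dtype=torch.bfloat16)
quantized_weight = to_affine_quantized_intx(
    weight,
    MappingType.ASYMMETRIC,
    block_size=(1, 128),
    target_dtype=torch.uint8,
    quant_min=0,
    quant_max=15,
)
print(quantized_weight.dequantize().shape)  # torch.Size([1024, 1024])
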
2 changes: 2 additions & 0 deletions torchao/dtypes/floatx/__init__.py
@@ -2,6 +2,7 @@
from .floatx_tensor_core_layout import (
FloatxTensorCoreLayout,
from_scaled_tc_floatx,
to_affine_quantized_fpx,
to_scaled_tc_floatx,
)

@@ -10,4 +11,5 @@
"to_scaled_tc_floatx",
"from_scaled_tc_floatx",
"Float8Layout",
"to_affine_quantized_fpx",
]
56 changes: 56 additions & 0 deletions torchao/dtypes/floatx/floatx_tensor_core_layout.py
@@ -11,6 +11,7 @@

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
get_tensor_impl_constructor,
register_layout,
)
from torchao.dtypes.utils import (
@@ -22,6 +23,11 @@
_floatx_unpacked_to_f32,
_n_ones,
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine_floatx,
dequantize_affine_floatx,
quantize_affine_floatx,
)

aten = torch.ops.aten
_ONES_TABLE = [_n_ones(i) for i in range(8)]
@@ -456,6 +462,53 @@ class FloatxTensorCoreLayout(Layout):
mbits: int


class FloatxTensor(AffineQuantizedTensor):
"""
Floatx quantized tensor subclass which inherits from the AffineQuantizedTensor class.

To see what happens during choose_qparams, quantization, and dequantization for floatx quantization,
please check out https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
and check the three quant primitive ops: choose_qparams_affine_floatx, quantize_affine_floatx, and dequantize_affine_floatx.
"""

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
if output_dtype is None:
output_dtype = self.dtype
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
int_data,
scale,
self._layout.ebits,
self._layout.mbits,
output_dtype=output_dtype,
)

@classmethod
def from_hp_to_floatx(
cls,
input_float: torch.Tensor,
_layout: Layout,
):
assert isinstance(
_layout, FloatxTensorCoreLayout
), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}"
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
# per axis quantization, where axis = 1
block_size = list(input_float.shape)
block_size[1] = 1

ebits, mbits = _layout.ebits, _layout.mbits
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
floatx_packed = _layout.post_process(floatx_unpacked)

tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)


@register_layout(FloatxTensorCoreLayout)
class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl):
"""FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b),
@@ -657,3 +710,6 @@ def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias):
out += bias

return out.view(*act.shape[:-1], out_dim).to(act.dtype)


to_affine_quantized_fpx = FloatxTensor.from_hp_to_floatx
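
Usage sketch (not part of the diff): the relocated constructor keeps the same two-argument form shown above — a high-precision float tensor plus a FloatxTensorCoreLayout; the fp6 (ebits=3, mbits=2) configuration and shapes below are illustrative assumptions.

import torch
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout

# fp6-e3m2 weight-only quantization; per-axis (axis=1) scales are chosen
# inside from_hp_to_floatx, as hardcoded in the method above.
weight = torch.randn(4096, 4096, dtype=torch.float16)
fpx_weight = to_affine_quantized_fpx(weight, FloatxTensorCoreLayout(ebits=3, mbits=2))
print(fpx_weight.dequantize(torch.float16).shape)  # torch.Size([4096, 4096])
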
4 changes: 4 additions & 0 deletions torchao/dtypes/uintx/__init__.py
@@ -4,6 +4,9 @@
from .cutlass_int4_packed_layout import (
CutlassInt4PackedLayout,
)
from .hqq_tensor import (
to_hqq_quantized_intx,
)
from .int4_cpu_layout import (
Int4CPULayout,
)
@@ -36,4 +39,5 @@
"MarlinQQQTensor",
"to_marlinqqq_quantized_intx",
"CutlassInt4PackedLayout",
"to_hqq_quantized_intx",
]
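
Usage sketch (not part of the diff): to_hqq_quantized_intx is only re-exported here, and the new hqq_tensor.py is not included on this page, so the argument names below are assumptions modeled on the use_hqq parameters removed from from_hp_to_intx (asymmetric mapping, quant_min=0, group size taken from block_size, nbits = log2(quant_max + 1)).

import torch
from torchao.dtypes import to_hqq_quantized_intx

# Hypothetical call mirroring the old use_hqq path: 4-bit (quant_max=15),
# group size 64 along axis 1; the real signature is defined in the new
# hqq_tensor.py, which this diff page does not show.
weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
hqq_weight = to_hqq_quantized_intx(
    weight,
    block_size=(1, 64),
    target_dtype=torch.uint8,
    quant_min=0,
    quant_max=15,
)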