config migration: float*
Summary:

Migrates the float8 inference workflows to the new config system: float8_weight_only, float8_dynamic_activation_float8_weight, and float8_static_activation_float8_weight are replaced by Float8WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig, and Float8StaticActivationFloat8WeightConfig (AOBaseConfig subclasses), each paired with a module transform registered via register_quantize_module_handler. The old function names are kept as backward-compatible aliases of the new config classes, and the tests are updated to exercise the configs end to end.
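
In user code the new surface looks roughly like this (a minimal sketch; the module shape, dtype, and import path are illustrative of how the tests in this diff use it):

import torch
from torchao.quantization.quant_api import Float8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
quantize_(model, Float8WeightOnlyConfig())   # new: pass a config object
# quantize_(model, float8_weight_only())     # old name still works; it now aliases the config class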

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 26d4a037f251363bb27638078134920d279df1a9
ghstack-comment-id: 2649492752
Pull Request resolved: #1694
vkuzo committed Feb 10, 2025
1 parent b678cc6 commit e10e222
Showing 3 changed files with 180 additions and 105 deletions.
14 changes: 11 additions & 3 deletions test/dtypes/test_affine_quantized.py
@@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
def _apply(module, config_or_subclass_inserter):
if isinstance(config_or_subclass_inserter, AOBaseConfig):
quantize_(module, config_or_subclass_inserter)
else:
# TODO(#1690): delete this once config migration is done
module = config_or_subclass_inserter(module)
return module

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.to("cuda")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.to(device="cuda")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.cuda()

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
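For context, a sketch of the device-move behavior test_to_device exercises, outside the test-local _apply shim (the config choice is illustrative):

import torch
from torchao.quantization.quant_api import Float8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, Float8WeightOnlyConfig())

# Moving a module whose weight is now a quantized tensor subclass should work
# with any of the usual device-transfer spellings:
linear.to("cuda")
linear.to(device="cuda")
linear.cuda()
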
25 changes: 20 additions & 5 deletions test/quantization/test_quant_api.py
Expand Up @@ -30,6 +30,9 @@
Quantizer,
TwoStepQuantizer,
_replace_with_custom_fn_if_matches_filter,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
@@ -784,9 +787,21 @@ def test_int4wo_cpu(self, dtype, x_dim):
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

# TODO(#1690): move to new config names
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_int4_weight_only_numerics(self):
@common_utils.parametrize(
"config",
[
int4_weight_only(),
float8_weight_only(),
float8_dynamic_activation_float8_weight(),
float8_static_activation_float8_weight(
scale=torch.tensor([1.0], device="cuda")
),
],
)
def test_workflow_e2e_numerics(self, config):
"""
Simple test of the e2e workflow for the given quantization config, comparing numerics
to a bfloat16 baseline.
@@ -796,16 +811,16 @@ def test_int4_weight_only_numerics(self):
# TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
# is that expected?
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
m_int4_wo = copy.deepcopy(m_ref)
m_q = copy.deepcopy(m_ref)

# quantize
quantize_(m_int4_wo, int4_weight_only())
quantize_(m_q, config)

with torch.no_grad():
y_ref = m_ref(x)
y_int4_wo = m_int4_wo(x)
y_q = m_q(x)

sqnr = compute_error(y_ref, y_int4_wo)
sqnr = compute_error(y_ref, y_q)
assert sqnr >= 20, f"SQNR {sqnr} is too low"


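Outside the test harness, the parameterized flow above corresponds to roughly the following (a sketch based on test_workflow_e2e_numerics; shapes, the unit scale, and the import path are illustrative):

import copy

import torch
from torchao.quantization.quant_api import (
    float8_dynamic_activation_float8_weight,
    float8_static_activation_float8_weight,
    quantize_,
)

x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()

# Dynamic activation + weight float8; the old name now aliases Float8DynamicActivationFloat8WeightConfig.
m_dyn = copy.deepcopy(m_ref)
quantize_(m_dyn, float8_dynamic_activation_float8_weight())

# Static activation float8 requires a precomputed activation scale tensor.
m_static = copy.deepcopy(m_ref)
quantize_(
    m_static,
    float8_static_activation_float8_weight(scale=torch.tensor([1.0], device="cuda")),
)

with torch.no_grad():
    y_ref, y_dyn, y_static = m_ref(x), m_dyn(x), m_static(x)
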
246 changes: 149 additions & 97 deletions torchao/quantization/quant_api.py
@@ -1030,30 +1030,43 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())


def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
@dataclass
class Float8WeightOnlyConfig(AOBaseConfig):
"""
Applies float8 weight-only symmetric per-channel quantization to linear layers.
Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.
Args:
weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
Note:
The actual matmul will be computed in original precision of the weight tensor.
"""
from torchao.dtypes import to_affine_quantized_floatx

def apply_float8wo_quant(weight):
block_size = (1, weight.shape[1])
return to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=None,
_layout=Float8Layout(mm_config=None),
)
weight_dtype: torch.dtype = torch.float8_e4m3fn


# for BC
float8_weight_only = Float8WeightOnlyConfig


@register_quantize_module_handler(Float8WeightOnlyConfig)
def _float8_weight_only_transform(
module: torch.nn.Module, config: Float8WeightOnlyConfig
) -> torch.nn.Module:
from torchao.dtypes import to_affine_quantized_floatx

return _get_linear_subclass_inserter(apply_float8wo_quant)
weight = module.weight
block_size = (1, weight.shape[1])
new_weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=config.weight_dtype,
scale_dtype=None,
_layout=Float8Layout(mm_config=None),
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


_fp8_granularities = Union[PerTensor, PerRow]
@@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool:
return is_compatible


def float8_dynamic_activation_float8_weight(
activation_dtype: torch.dtype = torch.float8_e4m3fn,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
granularity: Optional[
Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
] = None,
mm_config: Optional[Float8MMConfig] = None,
):
@dataclass
class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
"""
Applies float8 dynamic symmetric quantization to both activations and weights of linear layers.
Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.
Args:
activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
@@ -1192,104 +1199,149 @@ def float8_dynamic_activation_float8_weight(
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
"""
assert (
is_sm_at_least_89() or is_MI300()
), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
if mm_config is None:
mm_config = Float8MMConfig(use_fast_accum=True)

activation_granularity, weight_granularity = _normalize_granularity(granularity)
activation_dtype: torch.dtype = torch.float8_e4m3fn
weight_dtype: torch.dtype = torch.float8_e4m3fn
granularity: Optional[
Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
] = None
mm_config: Optional[Float8MMConfig] = None

def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
if not _fp8_mm_compat(weight):
return weight
if isinstance(weight_granularity, PerRow):
assert (
weight.dtype == torch.bfloat16
), "PerRow quantization only works for bfloat16 precision input weight"
def __post_init__(self):
assert (
is_sm_at_least_89() or is_MI300()
), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
if self.mm_config is None:
self.mm_config = Float8MMConfig(use_fast_accum=True)

block_size = get_block_size(weight.shape, weight_granularity)
quantized_weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
)

input_quant_func = _input_activation_quant_func_fp8
input_quant_kwargs = {
"activation_granularity": activation_granularity,
"activation_dtype": activation_dtype,
}
# for bc
float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig

quantized_weight = to_linear_activation_quantized(
quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
)
return quantized_weight

return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)
@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)
def _float8_dynamic_activation_float8_weight_transform(
module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
):
activation_dtype = config.activation_dtype
weight_dtype = config.weight_dtype
granularity = config.granularity
mm_config = config.mm_config
weight = module.weight

activation_granularity, weight_granularity = _normalize_granularity(granularity)

def float8_static_activation_float8_weight(
scale: torch.Tensor,
activation_dtype: torch.dtype = torch.float8_e4m3fn,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
granularity: Optional[
Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
] = None,
mm_config: Optional[Float8MMConfig] = None,
):
if not _fp8_mm_compat(weight):
# TODO(future PR): this should really throw an exception instead of silently
# not doing what the user asked
return module
if isinstance(weight_granularity, PerRow):
assert (
weight.dtype == torch.bfloat16
), "PerRow quantization only works for bfloat16 precision input weight"

block_size = get_block_size(weight.shape, weight_granularity)
quantized_weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
)

input_quant_func = _input_activation_quant_func_fp8
input_quant_kwargs = {
"activation_granularity": activation_granularity,
"activation_dtype": activation_dtype,
}

quantized_weight = to_linear_activation_quantized(
quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
)

module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


@dataclass
class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
"""
Applies float8 static symmetric quantization to both activations and weights of linear layers.
Configuration for applying float8 static symmetric quantization to both activations and weights of linear layers.
Args:
scale (torch.Tensor): The scale tensor for activation quantization.
activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
"""
assert (
is_sm_at_least_89() or is_MI300()
), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
if mm_config is None:
mm_config = Float8MMConfig(use_fast_accum=True)

scale: torch.Tensor
activation_dtype: torch.dtype = torch.float8_e4m3fn
weight_dtype: torch.dtype = torch.float8_e4m3fn
granularity: Optional[
Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
] = None
mm_config: Optional[Float8MMConfig] = None

def __post_init__(self):
assert (
is_sm_at_least_89() or is_MI300()
), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
if self.mm_config is None:
self.mm_config = Float8MMConfig(use_fast_accum=True)


# for bc
float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig


@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
def _float8_static_activation_float8_weight_transform(
module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
):
scale = config.scale
activation_dtype = config.activation_dtype
weight_dtype = config.weight_dtype
granularity = config.granularity
mm_config = config.mm_config

weight = module.weight
activation_granularity, weight_granularity = _normalize_granularity(granularity)
assert isinstance(
activation_granularity, PerTensor
), "Static quantization only supports PerTensor granularity"

def apply_float8_static_activation_quant(weight: torch.Tensor):
if not _fp8_mm_compat(weight):
return weight
block_size = get_block_size(weight.shape, weight_granularity)
quantized_weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
)
if not _fp8_mm_compat(weight):
# TODO(future PR): this should really throw an exception instead of silently
# not doing what the user asked
return module
block_size = get_block_size(weight.shape, weight_granularity)
quantized_weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
)

input_quant_func = _input_activation_quant_func_fp8
input_quant_kwargs = {
"activation_granularity": activation_granularity,
"activation_dtype": activation_dtype,
}

quantized_weight = (
to_weight_tensor_with_linear_activation_quantization_metadata(
quantized_weight,
input_quant_func,
scale=scale,
zero_point=None,
quant_kwargs=input_quant_kwargs,
)
)
return quantized_weight
input_quant_func = _input_activation_quant_func_fp8
input_quant_kwargs = {
"activation_granularity": activation_granularity,
"activation_dtype": activation_dtype,
}

return _get_linear_subclass_inserter(apply_float8_static_activation_quant)
quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
quantized_weight,
input_quant_func,
scale=scale,
zero_point=None,
quant_kwargs=input_quant_kwargs,
)

module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
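For anyone following the rest of the migration stack (#1690), each converted workflow now follows the same pattern: a dataclass config subclassing AOBaseConfig, plus a module transform registered for it with register_quantize_module_handler. A hedged sketch with a hypothetical config (MyNewWorkflowConfig and its handler do not exist in torchao, and the import paths are assumptions based on how quant_api.py uses these symbols):

from dataclasses import dataclass

import torch
from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class MyNewWorkflowConfig(AOBaseConfig):
    """Hypothetical config for a not-yet-migrated workflow."""

    weight_dtype: torch.dtype = torch.float8_e4m3fn


@register_quantize_module_handler(MyNewWorkflowConfig)
def _my_new_workflow_transform(
    module: torch.nn.Module, config: MyNewWorkflowConfig
) -> torch.nn.Module:
    # A real handler would replace module.weight with a quantized tensor here,
    # mirroring the float8 transforms above. Returning the module unchanged
    # keeps this sketch runnable without the real kernels.
    return module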
