config migration: float*
Summary:

TODO write me

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 43c6f7dc8ee071ff040c02c2ba061cb36d65434a
ghstack-comment-id: 2649492752
Pull Request resolved: #1694
vkuzo committed Feb 14, 2025
1 parent 7051ca0 commit da35915
Showing 4 changed files with 198 additions and 99 deletions.
14 changes: 11 additions & 3 deletions test/dtypes/test_affine_quantized.py
@@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
+        def _apply(module, config_or_subclass_inserter):
+            if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                quantize_(module, config_or_subclass_inserter)
+            else:
+                # TODO(#1690): delete this once config migration is done
+                module = config_or_subclass_inserter(module)
+            return module
+
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to("cuda")

         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to(device="cuda")

         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.cuda()

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
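For readers following the migration, a minimal self-contained sketch of the dispatch that _apply performs appears below. ToyConfig and toy_inserter are hypothetical stand-ins for illustration, not torchao APIs, and the real helper calls quantize_ with a torchao config rather than the no-op shown here.

import torch

class AOBaseConfig:  # stand-in for torchao.core.config.AOBaseConfig
    pass

class ToyConfig(AOBaseConfig):  # new style: a plain config object
    pass

def toy_inserter(module):
    # old style: a callable that transforms and returns the module
    return module

def _apply(module, config_or_subclass_inserter):
    if isinstance(config_or_subclass_inserter, AOBaseConfig):
        # new style: the real test calls quantize_(module, config),
        # which mutates the module in place
        pass
    else:
        # legacy path, slated for deletion once the #1690 migration lands
        module = config_or_subclass_inserter(module)
    return module

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
assert _apply(linear, ToyConfig()) is linear
assert _apply(linear, toy_inserter) is linear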
41 changes: 36 additions & 5 deletions test/quantization/test_quant_api.py
@@ -30,6 +30,9 @@
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -46,6 +49,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
     unwrap_tensor_subclass,
 )

@@ -784,28 +788,55 @@ def test_int4wo_cpu(self, dtype, x_dim):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]

     # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_int4_weight_only_numerics(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            int4_weight_only(),
+            float8_weight_only(),
+            float8_dynamic_activation_float8_weight(),
+            float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
+        ],
+    )
+    def test_workflow_e2e_numerics(self, config):
         """
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
         """
+        if (
+            isinstance(
+                config,
+                (
+                    float8_dynamic_activation_float8_weight,
+                    float8_static_activation_float8_weight,
+                ),
+            )
+            and not is_sm_at_least_89()
+        ):
+            return unittest.skip("requires CUDA capability 8.9 or greater")
+
+        # scale has to be moved to cuda here because the parametrization init
+        # code happens before gating for cuda availability
+        if isinstance(config, float8_static_activation_float8_weight):
+            config.scale = config.scale.to("cuda")
+
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
-        m_int4_wo = copy.deepcopy(m_ref)
+        m_q = copy.deepcopy(m_ref)

         # quantize
-        quantize_(m_int4_wo, int4_weight_only())
+        quantize_(m_q, config)

         with torch.no_grad():
             y_ref = m_ref(x)
-            y_int4_wo = m_int4_wo(x)
+            y_q = m_q(x)

-        sqnr = compute_error(y_ref, y_int4_wo)
+        sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 20, f"SQNR {sqnr} is too low"


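Note that the isinstance checks in the test above can only succeed if the old workflow function names are bound to the new config classes. Below is a minimal sketch of that aliasing pattern, written as an assumption about how the migration wires old names to new classes; the constructor field is illustrative, not the real signature.

from dataclasses import dataclass

class AOBaseConfig:  # stand-in for torchao.core.config.AOBaseConfig
    pass

@dataclass
class Float8WeightOnlyConfig(AOBaseConfig):
    weight_dtype: object = None  # illustrative field, not the real signature

# Keep the old name importable as an alias of the class, so existing call
# sites like float8_weight_only() now return a config instance.
float8_weight_only = Float8WeightOnlyConfig

cfg = float8_weight_only()
assert isinstance(cfg, AOBaseConfig)
assert isinstance(cfg, float8_weight_only)  # the aliasing makes this check valid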
6 changes: 6 additions & 0 deletions torchao/quantization/__init__.py
@@ -46,6 +46,9 @@
     AffineQuantizedObserverBase,
 )
 from .quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     Int4WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
@@ -121,6 +124,9 @@
"gemlite_uintx_weight_only",
"swap_conv2d_1x1_to_linear",
"Int4WeightOnlyConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
"Float8StaticActivationFloat8WeightConfig",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
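With these exports in place, callers should be able to import the new config classes directly from torchao.quantization. A hedged usage sketch follows; it assumes a CUDA machine and that the no-argument constructor is valid, mirroring the parametrized test earlier in this commit.

import torch
from torchao.quantization import Float8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
# new-style spelling of the old quantize_(model, float8_weight_only())
quantize_(model, Float8WeightOnlyConfig())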
[diff for the fourth changed file not shown]
