Commit 5702ea0

Update

[ghstack-poisoned]
vkuzo committed Feb 11, 2025
1 parent d42a590 commit 5702ea0
Showing 2 changed files with 21 additions and 6 deletions.
13 changes: 13 additions & 0 deletions test/quantization/test_quant_api.py
@@ -49,6 +49,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
     unwrap_tensor_subclass,
 )

@@ -806,6 +807,18 @@ def test_workflow_e2e_numerics(self, config):
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
         """
+        if (
+            isinstance(
+                config,
+                (
+                    float8_dynamic_activation_float8_weight,
+                    float8_static_activation_float8_weight,
+                ),
+            )
+            and not is_sm_at_least_89()
+        ):
+            return self.skipTest("requires CUDA capability 8.9 or greater")
+
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
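For reference, a minimal self-contained sketch of this skip pattern. The capability helper below is a simplified stand-in for torchao's `is_sm_at_least_89` (the real helper lives in `torchao.utils`), and the toy test class is illustrative:

```python
import unittest

import torch


def is_sm_at_least_89() -> bool:
    # Simplified stand-in: float8 matmul kernels need CUDA compute
    # capability 8.9+ (Ada / Hopper class GPUs).
    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (8, 9)
    )


class ToyFloat8Test(unittest.TestCase):
    def test_float8_numerics(self):
        if not is_sm_at_least_89():
            # skipTest raises unittest.SkipTest, so the runner reports
            # this test as skipped instead of silently passing it.
            return self.skipTest("requires CUDA capability 8.9 or greater")
        # ... float8 vs. bfloat16 numerics checks would go here ...


if __name__ == "__main__":
    unittest.main()
```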
14 changes: 8 additions & 6 deletions torchao/quantization/quant_api.py
@@ -1208,9 +1208,6 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     mm_config: Optional[Float8MMConfig] = None

     def __post_init__(self):
-        assert (
-            is_sm_at_least_89() or is_MI300()
-        ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)

@@ -1223,6 +1220,10 @@ def __post_init__(self):
 def _float8_dynamic_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
 ):
+    assert (
+        is_sm_at_least_89() or is_MI300()
+    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
     granularity = config.granularity
@@ -1285,9 +1286,6 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     mm_config: Optional[Float8MMConfig] = None

     def __post_init__(self):
-        assert (
-            is_sm_at_least_89() or is_MI300()
-        ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)

@@ -1300,6 +1298,10 @@ def __post_init__(self):
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
 ):
+    assert (
+        is_sm_at_least_89() or is_MI300()
+    ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
+
     scale = config.scale
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
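The quant_api.py change defers the hardware check from config construction (`__post_init__`) to the module transform, so a float8 config can be instantiated on any host (e.g., for serialization or test parametrization) and the capability assert only fires when the config is actually applied. A sketch of what that means for callers, assuming torchao's public `quantize_` entry point and the `float8_dynamic_activation_float8_weight` constructor referenced in the test above:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import (
    float8_dynamic_activation_float8_weight,
)

# After this change, building the config no longer asserts on GPU
# capability, so this line works even on a CPU-only or pre-SM-8.9 host.
config = float8_dynamic_activation_float8_weight()

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()

# The is_sm_at_least_89() / is_MI300() assert now runs here, inside
# _float8_dynamic_activation_float8_weight_transform, and raises an
# AssertionError on unsupported hardware.
quantize_(model, config)
```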
