[wip] configs configs configs!
Summary:

POC for:

* decouple configuration from transformation
* stop passing obscure stateful callables around
* enable printing of configurations
* reduce the amount of context switching needed to navigate from `quantize_`
  to the code that quantizes a single module

TODO more polish before wider discussion.
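As a rough sketch of the call-site change this enables, using the QAT config introduced later in this diff (the import paths are an assumption at this point in the stack, and the `FakeQuantizeConfig` settings follow the existing QAT docs):

```python
import copy

import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import FakeQuantizeConfig, intx_quantization_aware_training
from torchao.quantization.qat.api import IntXQuantizationAwareTrainingConfig

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
m1 = torch.nn.Sequential(torch.nn.Linear(32, 32))
m2 = copy.deepcopy(m1)

# existing spelling: quantize_ is handed what used to be an opaque callable
# (after this change the name is just a BC alias for the config class)
quantize_(m1, intx_quantization_aware_training(activation_config, weight_config))

# new spelling: quantize_ is handed a plain config object; the transform for that
# config type is looked up in a registry, and the config has a readable repr
config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
print(config)
quantize_(m2, config)
```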

Test Plan:

```
pytest test/quantization/test_quant_api.py -s -x -k test_int4_weight_only_numerics
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_standalone
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_convert_path
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 73e9a5c3bf03e2cb645cc0ea43bec162a5f4897e
ghstack-comment-id: 2607756510
Pull Request resolved: #1595
vkuzo committed Jan 22, 2025
1 parent 32d9b0b commit 2307f5b
Showing 8 changed files with 256 additions and 143 deletions.
7 changes: 6 additions & 1 deletion test/dtypes/test_affine_quantized.py
@@ -8,13 +8,15 @@
run_tests,
)

from torchao.core.config import AOBaseWorkflowConfig
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.utils import (
@@ -186,7 +188,10 @@ def test_flatten_unflatten(self, device, dtype):
        apply_quant_list = get_quantization_functions(False, True, device)
        for apply_quant in apply_quant_list:
            linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
            ql = apply_quant(linear)
            if isinstance(apply_quant, AOBaseWorkflowConfig):
                quantize_(linear, apply_quant)
                ql = linear  # quantize_ modifies the module in place
            else:
                ql = apply_quant(linear)
            lp_tensor = ql.weight
            tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
            tensor_data_dict = {
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api(self):
def test_quantize_api_standalone(self):
"""
Test that the following:
25 changes: 25 additions & 0 deletions test/quantization/test_quant_api.py
@@ -40,6 +40,7 @@
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
@@ -761,6 +762,30 @@ def reset_memory():
assert param.is_cuda
self.assertLess(memory_streaming, memory_baseline)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_int4_weight_only_numerics(self):
        """
        Simple test of e2e int4_weight_only workflow, comparing numerics
        to a bfloat16 baseline.
        """
        # set up inputs
        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
        # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
        # is that expected?
        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
        m_int4_wo = copy.deepcopy(m_ref)

        # quantize
        quantize_(m_int4_wo, int4_weight_only())

        with torch.no_grad():
            y_ref = m_ref(x)
            y_int4_wo = m_int4_wo(x)

        sqnr = compute_error(y_ref, y_int4_wo)
        assert sqnr >= 20, f"SQNR {sqnr} is too low"


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
Empty file added torchao/core/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions torchao/core/config.py
@@ -0,0 +1,13 @@
import abc


# directory location for this might need more polish
class AOBaseWorkflowConfig(abc.ABC):
    """
    If a workflow config inherits from this, then `quantize_` knows
    what to do with it.
    TODO write a better docblock.
    """

    pass
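For illustration, a hypothetical workflow config following this pattern (not part of this commit): a plain dataclass that inherits from `AOBaseWorkflowConfig`, which also makes the configuration printable.

```python
from dataclasses import dataclass

from torchao.core.config import AOBaseWorkflowConfig


# hypothetical example config, not part of this commit
@dataclass
class MyInt8WeightOnlyConfig(AOBaseWorkflowConfig):
    group_size: int = 128
    symmetric: bool = True


cfg = MyInt8WeightOnlyConfig(group_size=32)
# dataclass repr: MyInt8WeightOnlyConfig(group_size=32, symmetric=True)
print(cfg)
```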
17 changes: 17 additions & 0 deletions torchao/quantization/_transform_module.py
@@ -0,0 +1,17 @@
from typing import Callable, Dict, Type

import torch

from torchao.core.config import AOBaseWorkflowConfig

# maps a workflow config type to the function that transforms a module for it
_QUANTIZE_CONFIG_HANDLER: Dict[
    Type[AOBaseWorkflowConfig],
    Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module],
] = {}


def register_quantize_module_handler(config_type):
    def decorator(func):
        _QUANTIZE_CONFIG_HANDLER[config_type] = func
        # return the function unchanged so it stays usable at its definition site
        return func

    return decorator
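The corresponding lookup happens inside `quantize_` in `torchao/quantization/quant_api.py`, whose diff is not rendered on this page. Below is a minimal sketch of how the registry might be consulted, ignoring module filtering and recursion; everything here other than `_QUANTIZE_CONFIG_HANDLER` and `AOBaseWorkflowConfig` is a placeholder, not the actual implementation.

```python
import torch

from torchao.core.config import AOBaseWorkflowConfig
from torchao.quantization._transform_module import _QUANTIZE_CONFIG_HANDLER


def _quantize_sketch(model: torch.nn.Module, config: AOBaseWorkflowConfig) -> None:
    # look up the transform registered for this exact config type ...
    handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
    # ... and apply it to each child module, re-attaching whatever it returns
    for name, child in list(model.named_children()):
        setattr(model, name, handler(child, config))
```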
114 changes: 68 additions & 46 deletions torchao/quantization/qat/api.py
@@ -5,10 +5,14 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union
from typing import Any, List, Optional, Union

import torch

from torchao.core.config import AOBaseWorkflowConfig
from torchao.quantization._transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.granularity import (
Granularity,
PerAxis,
@@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any):
super().__setattr__(name, value)


def intx_quantization_aware_training(
activation_config: Optional[FakeQuantizeConfig] = None,
weight_config: Optional[FakeQuantizeConfig] = None,
) -> Callable:
@dataclass
class IntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
activation_config: Optional[FakeQuantizeConfig] = None
weight_config: Optional[FakeQuantizeConfig] = None


# for BC
intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig


@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
def _intx_quantization_aware_training_transform(
module: torch.nn.Module,
config: IntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
"""
Return a function that applies fake quantization to a `torch.nn.Module`.
THIS IS NOT A PUBLIC API - any usage of this outside of torchao
can break at any time.
Apply fake quantization to a `torch.nn.Module`.
to be used with :func:`~torchao.quantization.quant_api.quantize_`.
Example usage::
@@ -267,37 +285,32 @@ def intx_quantization_aware_training(
`torch.nn.Embedding` with an activation config, then we will raise
ValueError as these are not supported.
"""

def _insert_fake_quantize(mod: torch.nn.Module):
"""
Swap the given module with its corresponding fake quantized version.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(
mod,
activation_config,
weight_config,
)
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError(
"Activation fake quantization is not supported for embedding"
)
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
else:
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

mod = module
activation_config = config.activation_config
weight_config = config.weight_config

if isinstance(mod, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(
mod,
activation_config,
weight_config,
)
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError(
"Module of type '%s' does not have QAT support" % type(mod)
"Activation fake quantization is not supported for embedding"
)
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
else:
raise ValueError("Module of type '%s' does not have QAT support" % type(mod))

return _insert_fake_quantize


def from_intx_quantization_aware_training() -> Callable:
class FromIntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
"""
Return a function that converts a model with fake quantized modules,
Object that knows how to convert a model with fake quantized modules,
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
back to model with the original, corresponding modules without
Expand All @@ -313,22 +326,31 @@ def from_intx_quantization_aware_training() -> Callable:
)
"""

def _remove_fake_quantize(mod: torch.nn.Module):
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear
pass


# for BC
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod

return _remove_fake_quantize
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
def _from_intx_quantization_aware_training_transform(
mod: torch.nn.Module,
config: FromIntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod


class ComposableQATQuantizer(TwoStepQuantizer):
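Putting the two QAT configs together, the prepare/convert flow exercised by `test_quantize_api_standalone` and `test_quantize_api_convert_path` looks roughly like this (a sketch; the import paths and `FakeQuantizeConfig` settings are taken from the existing QAT docs and may differ at this point in the stack):

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import FakeQuantizeConfig
from torchao.quantization.qat.api import (
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# prepare: swap nn.Linear / nn.Embedding for their fake-quantized counterparts
quantize_(model, IntXQuantizationAwareTrainingConfig(activation_config, weight_config))

# ... QAT fine-tuning happens here ...

# convert: swap the fake-quantized modules back to the original module types
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
```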
