Add boilerplate code #1635

Closed
wants to merge 17 commits
2 changes: 1 addition & 1 deletion .github/workflows/doc_build.yml
@@ -9,10 +9,10 @@ on:
tags:
- v[0-9]+.[0-9]+.[0-9]
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
pull_request:
paths:
- 'docs/**'
- '!docs/**'
pull_request:
workflow_dispatch:

concurrency:
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.6
rev: v0.6.8
hooks:
# Run the linter.
- id: ruff
46 changes: 39 additions & 7 deletions docs/source/api_ref_quantization.rst
@@ -6,24 +6,43 @@ torchao.quantization

.. currentmodule:: torchao.quantization

Main Quantization APIs
----------------------

.. autosummary::
:toctree: generated/
:nosignatures:

autoquant
quantize_
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
autoquant

Quantization APIs for quantize_
-------------------------------

.. autosummary::
:toctree: generated/
:nosignatures:

int4_weight_only
int8_weight_only
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
uintx_weight_only
gemlite_uintx_weight_only
intx_quantization_aware_training
from_intx_quantization_aware_training
float8_weight_only
float8_dynamic_activation_float8_weight
float8_static_activation_float8_weight
uintx_weight_only
fpx_weight_only
to_linear_activation_quantized
swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference

Quantization Primitives
-----------------------

.. autosummary::
:toctree: generated/
:nosignatures:

choose_qparams_affine
choose_qparams_affine_with_min_max
choose_qparams_affine_floatx
@@ -40,3 +59,16 @@ torchao.quantization
ZeroPointDomain
TorchAODType

..
TODO: delete these?

Other
-----

.. autosummary::
:toctree: generated/
:nosignatures:

to_linear_activation_quantized
swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference
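
For context on the reshuffled reference above, here is a minimal usage sketch of the main entry point: quantize_ takes a model plus a config callable such as int8_dynamic_activation_int8_weight(), matching the calling convention used in the tests later in this PR. Treat the exact import paths as assumptions about the current torchao layout rather than authoritative documentation.

# Minimal sketch, assuming a recent torch/torchao where these imports exist.
import torch
import torch.nn as nn

from torchao.quantization import (
    int8_dynamic_activation_int8_weight,
    quantize_,
)

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 256)).to(torch.bfloat16)

# Swap eligible nn.Linear layers for int8 dynamic-activation / int8-weight variants in place.
quantize_(model, int8_dynamic_activation_int8_weight())

x = torch.randn(8, 1024, dtype=torch.bfloat16)
y = model(x)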
6 changes: 3 additions & 3 deletions docs/source/api_ref_sparsity.rst
@@ -10,9 +10,9 @@ torchao.sparsity
:toctree: generated/
:nosignatures:

WandaSparsifier
PerChannelNormObserver
apply_fake_sparsity
sparsify_
semi_sparse_weight
int8_dynamic_activation_int8_semi_sparse_weight
apply_fake_sparsity
WandaSparsifier
PerChannelNormObserver
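
Similarly, a hedged sketch of how the reordered sparsity entries are meant to be combined, assuming sparsify_ mirrors the quantize_ convention and that a CUDA device with 2:4 sparse kernel support is available; illustrative only, not a supported recipe.

# Illustrative only; semi-structured (2:4) sparsity needs suitable GPU kernels.
import torch
import torch.nn as nn

from torchao.sparsity import semi_sparse_weight, sparsify_

model = nn.Sequential(nn.Linear(2048, 2048)).half().cuda()

# Replace dense weights with 2:4 semi-sparse representations where supported.
sparsify_(model, semi_sparse_weight())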
276 changes: 11 additions & 265 deletions docs/source/contributor_guide.rst

Large diffs are not rendered by default.

241 changes: 240 additions & 1 deletion docs/source/quantization.rst

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -13,6 +13,7 @@
float8_dynamic_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
from torchao.quantization.observer import PerRow, PerTensor
@@ -166,9 +167,21 @@ def test_tp_gemlite(self, dtype):
return self._test_tp(dtype)


class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
COMMON_DTYPES = [torch.bfloat16]

@common_utils.parametrize("dtype", COMMON_DTYPES)
@with_comms
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tp(self, dtype):
return self._test_tp(dtype)


common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
47 changes: 39 additions & 8 deletions test/integration/test_integration.py
@@ -10,6 +10,7 @@
import logging
import os
import unittest
from functools import partial

import torch
import torch.nn as nn
@@ -48,6 +49,7 @@
quantize_,
)
from torchao.quantization.quant_primitives import (
MappingType,
dequantize_affine,
)
from torchao.quantization.smoothquant import (
@@ -102,6 +104,8 @@

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()


@@ -121,9 +125,18 @@ def _int8wo_groupwise_api(mod):
quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)


def _int8da_int8w_api(mod):
def _int8da_int8w_api(
mod,
act_mapping_type=MappingType.SYMMETRIC,
):
if TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
quantize_(
mod,
int8_dynamic_activation_int8_weight(
act_mapping_type=act_mapping_type,
),
set_inductor_config=False,
)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)
else:
@@ -962,25 +975,43 @@ def _test_lin_weight_subclass_api_impl(
mod[0].weight.tensor_impl.get_plain()

test = mod(x)

self.assertGreater(
SQNR(ref_f, test),
min_sqnr,
f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
)

mod_qc = torch.compile(mod, mode="max-autotune")
test_comp = mod_qc(x)
self.assertGreater(
SQNR(ref_f, test_comp),
min_sqnr,
f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
_int8da_int8w_api, device, 35, test_dtype=dtype
@parameterized.expand(
list(
itertools.product(
COMMON_DEVICES,
COMMON_DTYPES,
ACT_MAPPING_TYPES,
)
)
)
def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping):
if (
not TORCH_VERSION_AT_LEAST_2_5
and dtype in (torch.float16, torch.bfloat16)
and act_mapping is MappingType.ASYMMETRIC
and device == "cpu"
):
self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
api = partial(
_int8da_int8w_api,
act_mapping_type=act_mapping,
)
self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
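
Since the new test matrix now exercises both MappingType.SYMMETRIC and MappingType.ASYMMETRIC for activations, here is a short self-contained illustration of the difference, written with plain PyTorch arithmetic rather than torchao's quantization primitives:

# Illustration only: how symmetric vs. asymmetric affine mapping picks scale and
# zero_point for an int8 tensor. This is not torchao's implementation.
import torch

x = torch.tensor([-0.2, 0.1, 0.9, 1.5])
qmin, qmax = -128, 127

# Symmetric: zero_point is fixed at 0 and the scale covers the largest magnitude.
scale_sym = x.abs().max() / qmax
q_sym = torch.clamp(torch.round(x / scale_sym), qmin, qmax)

# Asymmetric: the scale spans the full [min, max] range and a zero_point shifts it,
# so skewed distributions (e.g. post-ReLU activations) waste fewer int8 codes.
scale_asym = (x.max() - x.min()) / (qmax - qmin)
zero_point = qmin - torch.round(x.min() / scale_asym)
q_asym = torch.clamp(torch.round(x / scale_asym) + zero_point, qmin, qmax)

print(q_sym, q_asym)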
8 changes: 7 additions & 1 deletion test/prototype/mx_formats/test_custom_cast.py
@@ -40,7 +40,7 @@
sem_vals_to_f32,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

torch.manual_seed(0)

@@ -310,6 +310,9 @@ def test_fp4_pack_unpack():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp4_triton_unscaled_cast():
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -320,6 +323,9 @@ def test_fp4_triton_unscaled_cast():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
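
The is_sm_at_least_100 helper used in the new skip markers is presumably a thin wrapper over the same device-capability query as the "Run only on H100" guard in test_affine_quantized_tensor_parallel.py above; a hedged sketch of an equivalent check follows (the real helper lives in torchao.utils and may differ in detail).

# Sketch of a capability gate equivalent to the new skip conditions; the real
# is_sm_at_least_100() in torchao.utils may be implemented differently.
import torch


def sm_at_least(major: int, minor: int = 0) -> bool:
    # True when a CUDA device with at least the given compute capability is present.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (major, minor)


runs_on_hopper_or_newer = sm_at_least(9)   # H100-class devices
skip_on_sm100 = sm_at_least(10)            # CUDA capability 10.0, skipped for the triton tests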
12 changes: 11 additions & 1 deletion test/prototype/mx_formats/test_mx_linear.py
@@ -18,7 +18,11 @@
swap_linear_with_mx_linear,
)
from torchao.quantization.utils import compute_error
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_89,
is_sm_at_least_100,
)

torch.manual_seed(2)

@@ -99,6 +103,9 @@ def test_activation_checkpointing():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
@@ -184,6 +191,9 @@ def test_inference_linear(elem_dtype, bias, input_shape):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_inference_compile_simple(elem_dtype):
"""
11 changes: 10 additions & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -21,7 +21,11 @@
to_dtype,
)
from torchao.quantization.utils import compute_error
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_89,
is_sm_at_least_100,
)

torch.manual_seed(2)

@@ -166,6 +170,8 @@ def test_transpose(elem_dtype, fp4_triton):
"""
if elem_dtype != DTYPE_FP4 and fp4_triton:
pytest.skip("unsupported configuration")
elif fp4_triton and is_sm_at_least_100():
pytest.skip("triton does not work yet on CUDA capability 10.0")

M, K = 128, 256
block_size = 32
@@ -205,6 +211,9 @@ def test_view(elem_dtype):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("all_zeros", [False, True])
37 changes: 33 additions & 4 deletions test/prototype/test_low_bit_optim.py
@@ -260,11 +260,24 @@ def test_optim_4bit_correctness(self, optim_name):
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = _DEVICES[-1]
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
# The first two layers are chosen so that they have a terrible arithmetic density.
# This means long transfers and comparatively quick computation, increasing the chances
# that missing synchronization will lead to test failures.
# The third layer is very small, here to validate non-trainable parameters,
# but shouldn't influence the timings.
model1 = nn.Sequential(
nn.Linear(32, 131072),
nn.ReLU(),
nn.Linear(131072, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, 128),
)
model1.to(device)

# make sure it can work in the presence of non-trainable params
model1[0].requires_grad_(False)
model1[2].requires_grad_(False)
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())
@@ -274,17 +287,33 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
offload_gradients=offload_grad,
)

scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(optim1, 100)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, 100)

rng = torch.Generator(device=device)
rng.manual_seed(42)

# make sure to run both models separately; otherwise, model1 gives additional
# time for operations in model2 to complete, masking potential race conditions.
for _ in range(2):
for _ in range(grad_accum):
x = torch.randn(4, 32, device=device)
x = torch.randn(4, 32, device=device, generator=rng)
model1(x).sum().backward()
model2(x).sum().backward()

optim1.step()
optim1.zero_grad()
scheduler1.step()

# reset the rng
rng.manual_seed(42)
for _ in range(2):
for _ in range(grad_accum):
x = torch.randn(4, 32, device=device, generator=rng)
model2(x).sum().backward()

optim2.step()
optim2.zero_grad()
scheduler2.step()

for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)
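
The restructured correctness test relies on reseeding the torch.Generator so that the second training loop replays exactly the input stream the first loop consumed; a minimal standalone sketch of that pattern (illustrative, not part of the test):

# Resetting a seeded generator replays the identical random inputs, which is what
# lets the two back-to-back loops above see the same batches.
import torch

rng = torch.Generator(device="cpu")

rng.manual_seed(42)
first_pass = [torch.randn(4, 32, generator=rng) for _ in range(3)]

rng.manual_seed(42)  # reset, so the second loop sees the same inputs
second_pass = [torch.randn(4, 32, generator=rng) for _ in range(3)]

for a, b in zip(first_pass, second_pass):
    torch.testing.assert_close(a, b)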