Merge remote-tracking branch 'origin/main' into gh/jainapurva/3/head
jainapurva committed Feb 4, 2025
2 parents d42c725 + b2fb664 commit 23f4a1c
Showing 39 changed files with 2,546 additions and 437 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/doc_build.yml
@@ -9,10 +9,10 @@ on:
tags:
- v[0-9]+.[0-9]+.[0-9]
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
pull_request:
paths:
- 'docs/**'
- '!docs/**'
pull_request:
workflow_dispatch:

concurrency:
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.6
rev: v0.6.8
hooks:
# Run the linter.
- id: ruff
46 changes: 39 additions & 7 deletions docs/source/api_ref_quantization.rst
@@ -6,24 +6,43 @@ torchao.quantization

.. currentmodule:: torchao.quantization

Main Quantization APIs
----------------------

.. autosummary::
:toctree: generated/
:nosignatures:

autoquant
quantize_
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
autoquant

Quantization APIs for quantize_
-------------------------------

.. autosummary::
:toctree: generated/
:nosignatures:

int4_weight_only
int8_weight_only
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
uintx_weight_only
gemlite_uintx_weight_only
intx_quantization_aware_training
from_intx_quantization_aware_training
float8_weight_only
float8_dynamic_activation_float8_weight
float8_static_activation_float8_weight
uintx_weight_only
fpx_weight_only
to_linear_activation_quantized
swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference

Quantization Primitives
-----------------------

.. autosummary::
:toctree: generated/
:nosignatures:

choose_qparams_affine
choose_qparams_affine_with_min_max
choose_qparams_affine_floatx
@@ -40,3 +59,16 @@ torchao.quantization
ZeroPointDomain
TorchAODType

..
TODO: delete these?
Other
-----

.. autosummary::
:toctree: generated/
:nosignatures:

to_linear_activation_quantized
swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference
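
For readers of the reorganized reference page, a minimal usage sketch of the quantize_ entry point with one of the configs listed above (the toy model, its dtype, and the input shape are illustrative assumptions, not taken from the docs):

import torch
import torch.nn as nn
from torchao.quantization import int8_weight_only, quantize_

# Any nn.Module containing nn.Linear layers is handled the same way.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = model.to(torch.bfloat16)

# Swap eligible linear weights for int8 weight-only quantized tensors, in place.
quantize_(model, int8_weight_only())

out = model(torch.randn(8, 1024, dtype=torch.bfloat16))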
6 changes: 3 additions & 3 deletions docs/source/api_ref_sparsity.rst
@@ -10,9 +10,9 @@ torchao.sparsity
:toctree: generated/
:nosignatures:

WandaSparsifier
PerChannelNormObserver
apply_fake_sparsity
sparsify_
semi_sparse_weight
int8_dynamic_activation_int8_semi_sparse_weight
apply_fake_sparsity
WandaSparsifier
PerChannelNormObserver
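
By analogy with quantize_, a hedged sketch of how the sparsify_ entry point listed above is typically called; the assumption that it mirrors the quantize_(model, config) convention, and the CUDA/half-precision requirements of the 2:4 kernels, are not stated in this diff:

import torch
import torch.nn as nn
from torchao.sparsity import semi_sparse_weight, sparsify_

# 2:4 semi-structured sparsity kernels generally need a CUDA device and fp16/bf16 weights.
model = nn.Sequential(nn.Linear(4096, 4096)).half().cuda()

# Assumed calling convention, mirroring quantize_: replace eligible linear
# weights with semi-structured (2:4) sparse tensors in place.
sparsify_(model, semi_sparse_weight())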
276 changes: 11 additions & 265 deletions docs/source/contributor_guide.rst

Large diffs are not rendered by default.

241 changes: 240 additions & 1 deletion docs/source/quantization.rst

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -13,6 +13,7 @@
float8_dynamic_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
from torchao.quantization.observer import PerRow, PerTensor
@@ -166,9 +167,21 @@ def test_tp_gemlite(self, dtype):
return self._test_tp(dtype)


class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
COMMON_DTYPES = [torch.bfloat16]

@common_utils.parametrize("dtype", COMMON_DTYPES)
@with_comms
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tp(self, dtype):
return self._test_tp(dtype)


common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
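
As a rough, hypothetical illustration of what the new subclass plugs into: a simplified single-device stand-in for the shared harness (the real _test_tp presumably also shards the model across ranks with DTensor, which is omitted here), showing how a QUANT_METHOD_FN hook is consumed:

import torch
import torch.nn as nn
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_


class Int8dqTPSketch:
    """Hypothetical, single-device stand-in for the shared TP test harness."""

    QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
    COMMON_DTYPES = [torch.bfloat16]

    def quantize_model(self, dtype: torch.dtype) -> nn.Module:
        # Build a small linear model and apply the configured quant method,
        # roughly what the real harness does before distributing the weights.
        model = nn.Sequential(nn.Linear(256, 256)).to(dtype)
        quantize_(model, self.QUANT_METHOD_FN())
        return model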
47 changes: 39 additions & 8 deletions test/integration/test_integration.py
@@ -10,6 +10,7 @@
import logging
import os
import unittest
from functools import partial

import torch
import torch.nn as nn
@@ -48,6 +49,7 @@
quantize_,
)
from torchao.quantization.quant_primitives import (
MappingType,
dequantize_affine,
)
from torchao.quantization.smoothquant import (
@@ -102,6 +104,8 @@

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()


@@ -121,9 +125,18 @@ def _int8wo_groupwise_api(mod):
quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)


def _int8da_int8w_api(mod):
def _int8da_int8w_api(
mod,
act_mapping_type=MappingType.SYMMETRIC,
):
if TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
quantize_(
mod,
int8_dynamic_activation_int8_weight(
act_mapping_type=act_mapping_type,
),
set_inductor_config=False,
)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)
else:
@@ -962,25 +975,43 @@ def _test_lin_weight_subclass_api_impl(
mod[0].weight.tensor_impl.get_plain()

test = mod(x)

self.assertGreater(
SQNR(ref_f, test),
min_sqnr,
f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
)

mod_qc = torch.compile(mod, mode="max-autotune")
test_comp = mod_qc(x)
self.assertGreater(
SQNR(ref_f, test_comp),
min_sqnr,
f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
_int8da_int8w_api, device, 35, test_dtype=dtype
@parameterized.expand(
list(
itertools.product(
COMMON_DEVICES,
COMMON_DTYPES,
ACT_MAPPING_TYPES,
)
)
)
def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping):
if (
not TORCH_VERSION_AT_LEAST_2_5
and dtype in (torch.float16, torch.bfloat16)
and act_mapping is MappingType.ASYMMETRIC
and device == "cpu"
):
self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
api = partial(
_int8da_int8w_api,
act_mapping_type=act_mapping,
)
self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
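
The new act_mapping parametrization maps directly onto the public config; a minimal sketch of selecting asymmetric activation quantization outside the test harness (the model size here is an illustrative assumption):

import torch
import torch.nn as nn
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_
from torchao.quantization.quant_primitives import MappingType

model = nn.Sequential(nn.Linear(64, 64)).to(torch.bfloat16)

# The helper above defaults to symmetric activation mapping; passing
# MappingType.ASYMMETRIC exercises the newly parametrized code path.
quantize_(
    model,
    int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
)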
8 changes: 7 additions & 1 deletion test/prototype/mx_formats/test_custom_cast.py
@@ -40,7 +40,7 @@
sem_vals_to_f32,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

torch.manual_seed(0)

@@ -310,6 +310,9 @@ def test_fp4_pack_unpack():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp4_triton_unscaled_cast():
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -320,6 +323,9 @@ def test_fp4_triton_unscaled_cast():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
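
The new skip guards call torchao.utils.is_sm_at_least_100; a hedged guess at what such a capability check amounts to (the real helper may be implemented differently), for anyone replicating the skip condition elsewhere:

import torch


def is_sm_at_least_100_sketch() -> bool:
    # Assumed behavior: True only on CUDA devices whose compute capability is
    # 10.0 or newer (Blackwell-class GPUs), where these triton kernels are
    # reported not to work yet.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)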
12 changes: 11 additions & 1 deletion test/prototype/mx_formats/test_mx_linear.py
@@ -18,7 +18,11 @@
swap_linear_with_mx_linear,
)
from torchao.quantization.utils import compute_error
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_89,
is_sm_at_least_100,
)

torch.manual_seed(2)

@@ -99,6 +103,9 @@ def test_activation_checkpointing():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
@@ -184,6 +191,9 @@ def test_inference_linear(elem_dtype, bias, input_shape):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_inference_compile_simple(elem_dtype):
"""
11 changes: 10 additions & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -21,7 +21,11 @@
to_dtype,
)
from torchao.quantization.utils import compute_error
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_89,
is_sm_at_least_100,
)

torch.manual_seed(2)

@@ -166,6 +170,8 @@ def test_transpose(elem_dtype, fp4_triton):
"""
if elem_dtype != DTYPE_FP4 and fp4_triton:
pytest.skip("unsupported configuration")
elif fp4_triton and is_sm_at_least_100():
pytest.skip("triton does not work yet on CUDA capability 10.0")

M, K = 128, 256
block_size = 32
@@ -205,6 +211,9 @@ def test_view(elem_dtype):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("all_zeros", [False, True])
37 changes: 33 additions & 4 deletions test/prototype/test_low_bit_optim.py
@@ -260,11 +260,24 @@ def test_optim_4bit_correctness(self, optim_name):
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = _DEVICES[-1]
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
# The first two layers are chosen so that they have a terrible arithmetic density.
# this means long transfers and comparatively quick computation, increasing the chances
# that missing synchronization will lead to test failures.
# The third layer is very small, here to validate non-trainable parameters,
# but shouldn't influence the timings
model1 = nn.Sequential(
nn.Linear(32, 131072),
nn.ReLU(),
nn.Linear(131072, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, 128),
)
model1.to(device)

# make sure it can work in the presence of non-trainable params
model1[0].requires_grad_(False)
model1[2].requires_grad_(False)
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())
@@ -274,17 +287,33 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
offload_gradients=offload_grad,
)

scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(optim1, 100)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, 100)

rng = torch.Generator(device=device)
rng.manual_seed(42)

# make sure to run both models separately; otherwise, model1 gives additional
# time for operations in model2 to complete, marking potential race conditions.
for _ in range(2):
for _ in range(grad_accum):
x = torch.randn(4, 32, device=device)
x = torch.randn(4, 32, device=device, generator=rng)
model1(x).sum().backward()
model2(x).sum().backward()

optim1.step()
optim1.zero_grad()
scheduler1.step()

# reset the rng
rng.manual_seed(42)
for _ in range(2):
for _ in range(grad_accum):
x = torch.randn(4, 32, device=device, generator=rng)
model2(x).sum().backward()

optim2.step()
optim2.zero_grad()
scheduler2.step()

for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)
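
The restructured test replays an identical input stream to each model by re-seeding a torch.Generator before the second loop; a small self-contained sketch of that pattern (shapes follow the test above):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
rng = torch.Generator(device=device)

# First pass: record the inputs the first model would see.
rng.manual_seed(42)
first = [torch.randn(4, 32, device=device, generator=rng) for _ in range(2)]

# Re-seed and regenerate: the second model sees byte-identical inputs in its own
# loop, so a missing host/device sync in the offloaded optimizer cannot be
# masked by the other model's work running in between.
rng.manual_seed(42)
second = [torch.randn(4, 32, device=device, generator=rng) for _ in range(2)]

assert all(torch.equal(a, b) for a, b in zip(first, second))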