Add CUTLASS-based W4A4 (#1515)
* add w4a4

* add test

* hook up to AQT

* fix quant api test

* fix test

* make threadblockswizzle a template param

* re-use s8s4 cutlass template

* add Alex's patch and some changes

* fix aqt test

* remove int4_cutlass.cu

* apply alex's patch

* update benchmark script

* ruff

* add some tuning

* reduce num_stages to fit shared memory of small GPUs (<100kb)

* replace torch timer with triton do_bench

* ruff

* use ZeroPointDomain.NONE

* fix 3.7 typing

* merge Aleksandar changes

* run ruff

* try replace torch/extension.h with torch/library.h

* (alexsamardzic) improve error handling

* ruff format

* add note on cutlass naming
gau-nernst authored Feb 5, 2025
1 parent b2fb664 commit 1a4c8f9
Showing 17 changed files with 734 additions and 444 deletions.
70 changes: 70 additions & 0 deletions benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,70 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)


def benchmark_microseconds(f, *args):
    # do_bench reports a median time in milliseconds; convert it to microseconds.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
    assert A_nbits in (4, 8) and B_nbits in (4, 8)

    dev = torch.device("cuda")
    A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randint(
        -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
    )
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
    )

    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16 latency (us)": fp16_time,
        "rowwise_scaled_linear_cutlass_s8s4 latency (us)": rowwise_scaled_linear_cutlass_s8s4_time,
        "s8s4 speedup (x)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
        "rowwise_scaled_linear_cutlass_s4s4 latency (us)": rowwise_scaled_linear_cutlass_s4s4_time,
        "s4s4 speedup (x)": fp16_time / rowwise_scaled_linear_cutlass_s4s4_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
52 changes: 0 additions & 52 deletions benchmarks/benchmark_s8s4_cutlass.py

This file was deleted.

36 changes: 24 additions & 12 deletions setup.py
@@ -240,30 +240,42 @@ def get_extensions():
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

if use_cuda:
sources += cuda_sources

use_cutlass = False
if use_cuda and not IS_WINDOWS:
use_cutlass = True
cutlass_dir = os.path.join(third_party_path, "cutlass")
cutlass_include_dir = os.path.join(cutlass_dir, "include")
cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
if use_cutlass:
extra_compile_args["nvcc"].extend(
[
"-DTORCHAO_USE_CUTLASS",
"-I" + cutlass_include_dir,
"-I" + cutlass_extensions_include_dir,
]
)

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

if use_cuda:
sources += cuda_sources
else:
# Remove CUTLASS-based kernels from the cuda_sources list. An
# assumption is that these files will have "cutlass" in its
# name.
cutlass_sources = list(
glob.glob(
os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
)
)
sources = [s for s in sources if s not in cutlass_sources]

ext_modules = []
if len(sources) > 0:
2 changes: 2 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -11,6 +11,7 @@
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
    float8_weight_only,
    int4_dynamic_activation_int4_weight,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_weight,
@@ -61,6 +62,7 @@ def get_quantization_functions(
                layout=CutlassInt4PackedLayout(),
            )
        )
        base_functions.append(int4_dynamic_activation_int4_weight())

    if do_sparse:
        base_functions.append(
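With the config above registered, W4A4 can be applied through the usual AQT quantization flow. The following is a minimal usage sketch (not part of this diff): it assumes a CUDA device and a half-precision model, with `quantize_` and `int4_dynamic_activation_int4_weight` imported from `torchao.quantization`; the model and shapes are illustrative only.

import torch
from torchao.quantization import int4_dynamic_activation_int4_weight, quantize_

# Illustrative model; any module with nn.Linear layers of supported shapes
# can be quantized the same way.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# Swap eligible Linear weights for W4A4 affine-quantized tensors, intended to
# dispatch to the CUTLASS-based kernel added in this commit.
quantize_(model, int4_dynamic_activation_int4_weight())

x = torch.randn(16, 1024, dtype=torch.half, device="cuda")
y = model(x)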
104 changes: 104 additions & 0 deletions test/test_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,104 @@
import itertools

import pytest
import torch

from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric

ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
    itertools.product(
        ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
        ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
        ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
        ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
    )
)


def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias):
    assert xq_bits in [4, 8]
    assert wq_bits in [4, 8]

    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
    w = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

    x_2d = x.view(-1, x.shape[-1])
    xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric(
        x_2d, xq_bits, size_k, dtype
    )
    assert torch.all(xq_2d_zeros == 0)
    xq_s8 = xq_2d_s8.reshape(x.shape)
    if xq_bits == 4:
        xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF)
    else:
        xq = xq_s8
    xq_scales = xq_2d_scales.reshape(x.shape[:-1])

    wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric(
        w, wq_bits, size_n, dtype
    )
    assert torch.all(wq_zeros == 0)
    if wq_bits == 4:
        wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF)
    else:
        wq = wq_s8

    # If torch.nn.functional.linear(x, w, bias) were used as the reference, the
    # error would be too large. The calculation below is approximately what the
    # rowwise_scaled_linear_cutlass kernel does (except that the matrix
    # multiplication is performed over integers there).
    size_m_2d = x_2d.shape[0]
    output_ref = (
        (xq_2d_s8.float() @ wq_s8.float().T)
        * xq_2d_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq, wq_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias
    )
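For reference, the nibble packing used in the test above (two signed int4 values per byte, low nibble first) can be factored into a small standalone helper. This is a sketch; `pack_int4_rows` is a name used here for illustration only and is not part of the torchao API.

import torch

def pack_int4_rows(t_s8: torch.Tensor) -> torch.Tensor:
    # t_s8 holds int4 values (range [-8, 7]) stored one per int8 element.
    # Adjacent pairs along the last dimension are packed into a single byte:
    # the even-indexed element fills the low nibble, the odd-indexed element
    # the high nibble, matching what the tests feed to the s4s4 operator.
    assert t_s8.dtype == torch.int8 and t_s8.shape[-1] % 2 == 0
    return (t_s8[..., 1::2] << 4) | (t_s8[..., 0::2] & 0xF)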
77 changes: 0 additions & 77 deletions test/test_s8s4_linear_cutlass.py

This file was deleted.

3 changes: 2 additions & 1 deletion torchao/csrc/README.md
@@ -8,7 +8,6 @@ The goal is that you can focus on just writing your custom CUDA or C++ kernel an

To learn more about custom ops in PyTorch you can refer to the [PyTorch Custom Operators Landing Page](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)


## How to add your own kernel in ao

We've integrated several kernels which you can use as a template for your own kernels. `tensor_core_tiled_layout` is the most straight-forward to get started with.
@@ -23,6 +22,8 @@ And that's it! Once CI passes and your code merged you'll be able to point peopl

If you'd like to learn more please check out [torch.library](https://pytorch.org/docs/main/library.html)

Note: All CUTLASS-based kernels should have `cutlass` in the name of their `.cu` files, e.g. `rowwise_scaled_linear_cutlass_s4s4.cu`.

## Required dependencies

The important dependencies are already taken care of in our CI so feel free to test in CI directly