Add CUTLASS-based W4A4 #1515

Merged: 33 commits merged into main from the w4a4 branch on Feb 5, 2025

Commits
1d350d6
add w4a4
gau-nernst Jan 7, 2025
7e277df
add test
gau-nernst Jan 7, 2025
a44df9e
hook up to AQT
gau-nernst Jan 8, 2025
2487eb9
Merge branch 'main' into w4a4
gau-nernst Jan 8, 2025
de167f0
fix quant api test
gau-nernst Jan 8, 2025
fe1f0eb
fix test
gau-nernst Jan 8, 2025
908f464
Merge branch 'main' into w4a4
gau-nernst Jan 22, 2025
883384b
make threadblockswizzle a template param
gau-nernst Jan 22, 2025
ee34bb2
re-use s8s4 cutlass template
gau-nernst Jan 22, 2025
f513523
add Alex's patch and some changes
gau-nernst Jan 23, 2025
9a1ce25
fix aqt test
gau-nernst Jan 23, 2025
b9db0f1
remove int4_cutlass.cu
gau-nernst Jan 23, 2025
f42fc65
apply alex's patch
gau-nernst Jan 24, 2025
a43f804
Merge branch 'main' into w4a4
gau-nernst Jan 24, 2025
5c30303
update benchmark script
gau-nernst Jan 24, 2025
d7c0896
ruff
gau-nernst Jan 24, 2025
2c5f565
Merge branch 'main' into w4a4
gau-nernst Jan 26, 2025
fd8dc4e
add some tuning
gau-nernst Jan 26, 2025
5449a56
reduce num_stages to fit shared memory of small GPUs (<100kb)
gau-nernst Jan 26, 2025
c421921
replace torch timer with triton do_bench
gau-nernst Jan 26, 2025
81a0a13
ruff
gau-nernst Jan 26, 2025
69e6777
Merge branch 'main' into w4a4
gau-nernst Jan 30, 2025
c736856
use ZeroPointDomain.NONE
gau-nernst Jan 30, 2025
bdcb85c
fix 3.7 typing
gau-nernst Jan 30, 2025
0c85805
Merge branch 'main' into w4a4
gau-nernst Feb 1, 2025
4a19634
merge Aleksandar changes
gau-nernst Feb 1, 2025
496cec8
run ruff
gau-nernst Feb 1, 2025
9a0ae7b
try replace torch/extension.h with torch/library.h
gau-nernst Feb 1, 2025
9332ac4
Merge branch 'main' into w4a4
gau-nernst Feb 2, 2025
37dc5f7
(alexsamardzic) improve error handling
gau-nernst Feb 2, 2025
c003018
ruff format
gau-nernst Feb 2, 2025
3b0b32b
add note on cutlass naming
gau-nernst Feb 4, 2025
4613503
Merge branch 'main' into w4a4
gau-nernst Feb 4, 2025
70 changes: 70 additions & 0 deletions benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,70 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3
gau-nernst (Collaborator, Author) commented on lines +12 to +13:
@alexsamardzic I replaced torchao.utils.benchmark_torch_function_in_microseconds (which is based on torch.utils.benchmark.Timer) with triton's do_bench. I found the PyTorch timer to be unreliable, possibly because it does not clear the L2 cache between runs.

Old (4090, torch.utils.benchmark.Timer)

| m | k | n | fp16_latency (ms) | W4A8 latency (ms) | W4A8 speedup (d/s) |
|---:|---:|---:|---:|---:|---:|
| 1 | 8192 | 8192 | 124.368 | 18.6586 | 6.66545 |
| 1 | 8192 | 10240 | 177.832 | 22.1851 | 8.01584 |
| 1 | 8192 | 57344 | 982.894 | 298.987 | 3.28742 |
| 1 | 28672 | 8192 | 493.988 | 152.724 | 3.23452 |

New (4090, triton.testing.do_bench)

| m | k | n | fp16_latency (ms) | W4A8 latency (ms) | W4A8 speedup (d/s) |
|---:|---:|---:|---:|---:|---:|
| 1 | 8192 | 8192 | 166.912 | 73.728 | 2.26389 |
| 1 | 8192 | 10240 | 201.824 | 87.04 | 2.31875 |
| 1 | 8192 | 57344 | 1005.57 | 346.112 | 2.90533 |
| 1 | 28672 | 8192 | 519.136 | 199.68 | 2.59984 |

With the old timer, you can see the unusually high speedup in the first two rows. I think this is because W stays cached in L2, so the gains disappear once W becomes larger.

Lmk if it's ok to have this change. Thank you!
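
For context, the practical difference between the two timers is whether the L2 cache is cleared between measured runs. A minimal, illustrative sketch of both approaches is below; the helper names are made up, and the benchmark script in this PR only uses the do_bench variant:

```python
import torch
from torch.utils.benchmark import Timer
from triton.testing import do_bench


def timer_microseconds(f, *args):
    # Old approach: torch.utils.benchmark.Timer runs the statement back-to-back,
    # so a small weight matrix can stay resident in L2 and look faster than it
    # would in a realistic workload.
    t = Timer(stmt="f(*args)", globals={"f": f, "args": args})
    return t.blocked_autorange().median * 1e6  # seconds -> microseconds


def do_bench_microseconds(f, *args):
    # New approach: triton's do_bench, which flushes the L2 cache between runs
    # and reports a median over many launches.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3  # ms -> us
```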

A collaborator replied:
Sure, that's great!

A contributor commented:
For anyone else, I break it down as:

  1. Do you only care about kernel time? Then use https://github.com/drisspg/transformer_nuggets/blob/0d1cbe6b3d980451f39174be64c98a143b9cebce/transformer_nuggets/utils/benchmark.py#L55. This is what I have found produces results closest to NCU (a rough sketch follows below).
  2. If the overhead does matter and you are more focused on what you would see in an isolated region of code, then the timer is pretty representative: https://github.com/drisspg/transformer_nuggets/blob/0d1cbe6b3d980451f39174be64c98a143b9cebce/transformer_nuggets/utils/benchmark.py#L44
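
A rough sketch of option (1), timing only the kernel body with CUDA events. This is illustrative only, not the transformer_nuggets helper linked above, and the warmup/iteration counts are arbitrary:

```python
import torch


def cuda_event_time_us(f, *args, warmup=5, iters=30):
    # Warm up so one-time costs (compilation, allocator) are excluded.
    for _ in range(warmup):
        f(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        f(*args)
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() is in milliseconds; return the average per call in microseconds.
    return start.elapsed_time(end) / iters * 1e3
```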



def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
    assert A_nbits in (4, 8) and B_nbits in (4, 8)

    dev = torch.device("cuda")
    A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randint(
        -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
    )
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
    )

    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (ms)": fp16_time,
        "rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time,
        "s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
        "rowwise_scaled_linear_cutlass_s4s4 latency (ms)": rowwise_scaled_linear_cutlass_s4s4_time,
        "s4s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s4s4_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
52 changes: 0 additions & 52 deletions benchmarks/benchmark_s8s4_cutlass.py

This file was deleted.

36 changes: 24 additions & 12 deletions setup.py
@@ -240,30 +240,42 @@ def get_extensions():
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
    glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

if use_cuda:
    sources += cuda_sources

use_cutlass = False
if use_cuda and not IS_WINDOWS:
    use_cutlass = True
    cutlass_dir = os.path.join(third_party_path, "cutlass")
    cutlass_include_dir = os.path.join(cutlass_dir, "include")
    cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
if use_cutlass:
    extra_compile_args["nvcc"].extend(
        [
            "-DTORCHAO_USE_CUTLASS",
            "-I" + cutlass_include_dir,
            "-I" + cutlass_extensions_include_dir,
        ]
    )

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
    glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

if use_cuda:
    sources += cuda_sources
else:
    # Remove CUTLASS-based kernels from the cuda_sources list. An
    # assumption is that these files will have "cutlass" in their
    # names.
    cutlass_sources = list(
        glob.glob(
            os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
        )
    )
    sources = [s for s in sources if s not in cutlass_sources]

ext_modules = []
if len(sources) > 0:
2 changes: 2 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -11,6 +11,7 @@
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
    float8_weight_only,
    int4_dynamic_activation_int4_weight,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_weight,
@@ -61,6 +62,7 @@ def get_quantization_functions(
layout=CutlassInt4PackedLayout(),
)
)
base_functions.append(int4_dynamic_activation_int4_weight())

if do_sparse:
base_functions.append(
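For context alongside the test hook-up above, here is a minimal usage sketch of the new config through torchao's quantize_ entry point. This is an assumption-laden example, not taken from the PR: the layer sizes are arbitrary, and the exact dtype/device requirements of the CUTLASS W4A4 kernel may differ.

```python
import torch
from torchao.quantization import int4_dynamic_activation_int4_weight, quantize_

# Assumed: a CUDA device and half precision, which is what the CUTLASS kernels target.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()
quantize_(model, int4_dynamic_activation_int4_weight())

x = torch.randn(16, 1024, dtype=torch.half, device="cuda")
out = model(x)
```
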
104 changes: 104 additions & 0 deletions test/test_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,104 @@
import itertools

import pytest
import torch

from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric

ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
    itertools.product(
        ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
        ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
        ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
        ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
    )
)


def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias):
    assert xq_bits in [4, 8]
    assert wq_bits in [4, 8]

    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
    w = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

    x_2d = x.view(-1, x.shape[-1])
    xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric(
        x_2d, xq_bits, size_k, dtype
    )
    assert torch.all(xq_2d_zeros == 0)
    xq_s8 = xq_2d_s8.reshape(x.shape)
    if xq_bits == 4:
        xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF)
    else:
        xq = xq_s8
    xq_scales = xq_2d_scales.reshape(x.shape[:-1])

    wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric(
        w, wq_bits, size_n, dtype
    )
    assert torch.all(wq_zeros == 0)
    if wq_bits == 4:
        wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF)
    else:
        wq = wq_s8

    # If torch.nn.functional.linear(x, w, bias) were used as the reference, the
    # error would be too large. The calculation below is approximately what the
    # rowwise_scaled_linear_cutlass kernel does (except that the matrix
    # multiplication there is over integers).
    size_m_2d = x_2d.shape[0]
    output_ref = (
        (xq_2d_s8.float() @ wq_s8.float().T)
        * xq_2d_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq, wq_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias
    )
77 changes: 0 additions & 77 deletions test/test_s8s4_linear_cutlass.py

This file was deleted.

3 changes: 2 additions & 1 deletion torchao/csrc/README.md
@@ -8,7 +8,6 @@ The goal is that you can focus on just writing your custom CUDA or C++ kernel an

To learn more about custom ops in PyTorch you can refer to the [PyTorch Custom Operators Landing Page](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)


## How to add your own kernel in ao

We've integrated several kernels which you can use as a template for your own kernels. `tensor_core_tiled_layout` is the most straight-forward to get started with.
@@ -23,6 +22,8 @@ And that's it! Once CI passes and your code merged you'll be able to point peopl

If you'd like to learn more please check out [torch.library](https://pytorch.org/docs/main/library.html)

Note: All CUTLASS-based kernels should have `cutlass` in the name of their `.cu` files, e.g. `rowwise_scaled_linear_cutlass_s4s4.cu`.

## Required dependencies

The important dependencies are already taken care of in our CI so feel free to test in CI directly