
Add CUTLASS-based W4A4 #1515

Merged Feb 5, 2025 (33 commits; changes shown from 27 commits)

Commits
- 1d350d6 add w4a4 (gau-nernst, Jan 7, 2025)
- 7e277df add test (gau-nernst, Jan 7, 2025)
- a44df9e hook up to AQT (gau-nernst, Jan 8, 2025)
- 2487eb9 Merge branch 'main' into w4a4 (gau-nernst, Jan 8, 2025)
- de167f0 fix quant api test (gau-nernst, Jan 8, 2025)
- fe1f0eb fix test (gau-nernst, Jan 8, 2025)
- 908f464 Merge branch 'main' into w4a4 (gau-nernst, Jan 22, 2025)
- 883384b make threadblockswizzle a template param (gau-nernst, Jan 22, 2025)
- ee34bb2 re-use s8s4 cutlass template (gau-nernst, Jan 22, 2025)
- f513523 add Alex's patch and some changes (gau-nernst, Jan 23, 2025)
- 9a1ce25 fix aqt test (gau-nernst, Jan 23, 2025)
- b9db0f1 remove int4_cutlass.cu (gau-nernst, Jan 23, 2025)
- f42fc65 apply alex's patch (gau-nernst, Jan 24, 2025)
- a43f804 Merge branch 'main' into w4a4 (gau-nernst, Jan 24, 2025)
- 5c30303 update benchmark script (gau-nernst, Jan 24, 2025)
- d7c0896 ruff (gau-nernst, Jan 24, 2025)
- 2c5f565 Merge branch 'main' into w4a4 (gau-nernst, Jan 26, 2025)
- fd8dc4e add some tuning (gau-nernst, Jan 26, 2025)
- 5449a56 reduce num_stages to fit shared memory of small GPUs (<100kb) (gau-nernst, Jan 26, 2025)
- c421921 replace torch timer with triton do_bench (gau-nernst, Jan 26, 2025)
- 81a0a13 ruff (gau-nernst, Jan 26, 2025)
- 69e6777 Merge branch 'main' into w4a4 (gau-nernst, Jan 30, 2025)
- c736856 use ZeroPointDomain.NONE (gau-nernst, Jan 30, 2025)
- bdcb85c fix 3.7 typing (gau-nernst, Jan 30, 2025)
- 0c85805 Merge branch 'main' into w4a4 (gau-nernst, Feb 1, 2025)
- 4a19634 merge Aleksandar changes (gau-nernst, Feb 1, 2025)
- 496cec8 run ruff (gau-nernst, Feb 1, 2025)
- 9a0ae7b try replace torch/extension.h with torch/library.h (gau-nernst, Feb 1, 2025)
- 9332ac4 Merge branch 'main' into w4a4 (gau-nernst, Feb 2, 2025)
- 37dc5f7 (alexsamardzic) improve error handling (gau-nernst, Feb 2, 2025)
- c003018 ruff format (gau-nernst, Feb 2, 2025)
- 3b0b32b add note on cutlass naming (gau-nernst, Feb 4, 2025)
- 4613503 Merge branch 'main' into w4a4 (gau-nernst, Feb 4, 2025)
70 changes: 70 additions & 0 deletions benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,70 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3
Comment on lines +12 to +13
Collaborator Author
@alexsamardzic I replaced torchao.utils.benchmark_torch_function_in_microseconds (which is based on torch.utils.benchmark.Timer) with triton's do_bench. This is because I found the PyTorch timer to be unreliable, possibly because it does not clear the L2 cache in between runs.

Old (4090, `torch.utils.benchmark.Timer`):

| m | k | n | fp16_latency (ms) | W4A8 latency (ms) | W4A8 speedup (d/s) |
|---|---|---|---|---|---|
| 1 | 8192 | 8192 | 124.368 | 18.6586 | 6.66545 |
| 1 | 8192 | 10240 | 177.832 | 22.1851 | 8.01584 |
| 1 | 8192 | 57344 | 982.894 | 298.987 | 3.28742 |
| 1 | 28672 | 8192 | 493.988 | 152.724 | 3.23452 |

New (4090, `triton.testing.do_bench`):

| m | k | n | fp16_latency (ms) | W4A8 latency (ms) | W4A8 speedup (d/s) |
|---|---|---|---|---|---|
| 1 | 8192 | 8192 | 166.912 | 73.728 | 2.26389 |
| 1 | 8192 | 10240 | 201.824 | 87.04 | 2.31875 |
| 1 | 8192 | 57344 | 1005.57 | 346.112 | 2.90533 |
| 1 | 28672 | 8192 | 519.136 | 199.68 | 2.59984 |

With the old timer, you can see the unusually high speedup in the first two rows. I think it's because W is cached in L2, so the gains disappear once W becomes larger.

Lmk if it's ok to have this change. Thank you!
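
(For reference, a minimal sketch of the two timing approaches; the Timer-based helper below is a simplified stand-in for `benchmark_torch_function_in_microseconds`, not its exact implementation.)

```python
# Illustrative only: simplified stand-ins for the two timing approaches discussed above.
from torch.utils.benchmark import Timer
from triton.testing import do_bench


def timer_microseconds(f, *args):
    # Old approach: torch.utils.benchmark.Timer. The same inputs stay resident across
    # runs, so small weights can remain in L2 cache and look faster than they are.
    t = Timer(stmt="f(*args)", globals={"f": f, "args": args})
    return t.blocked_autorange().median * 1e6


def do_bench_microseconds(f, *args):
    # New approach: triton's do_bench, which flushes the L2 cache between runs.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3
```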

Collaborator
Sure, that's great!

Contributor
For anyone else, I break it down as:

  1. Do you only care about kernel time? Then use: https://github.com/drisspg/transformer_nuggets/blob/0d1cbe6b3d980451f39174be64c98a143b9cebce/transformer_nuggets/utils/benchmark.py#L55. This is what I have found produces the closest results to NCU (see the sketch after this list).
  2. If the overhead does matter and you are more focused on what you would see in an isolated region of code, then the timer is pretty representative: https://github.com/drisspg/transformer_nuggets/blob/0d1cbe6b3d980451f39174be64c98a143b9cebce/transformer_nuggets/utils/benchmark.py#L44
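
(A rough, hedged sketch of what a kernel-only timing loop with an L2 flush looks like; buffer size, warmup, and iteration count are illustrative, and the linked helpers may differ in details.)

```python
import torch


def bench_kernel_microseconds(f, *args, iters: int = 100) -> float:
    # Large buffer whose zeroing evicts the L2 cache between timed iterations.
    cache = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    f(*args)  # warmup
    for i in range(iters):
        cache.zero_()  # flush L2 so each run sees cold inputs
        starts[i].record()
        f(*args)
        ends[i].record()
    torch.cuda.synchronize()
    times_ms = sorted(s.elapsed_time(e) for s, e in zip(starts, ends))
    return times_ms[len(times_ms) // 2] * 1e3  # median, in microseconds
```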



def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
    assert A_nbits in (4, 8) and B_nbits in (4, 8)

    dev = torch.device("cuda")
    A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randint(
        -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
    )
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
    )

    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (ms)": fp16_time,
        "rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time,
        "s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
        "rowwise_scaled_linear_cutlass_s4s4 latency (ms)": rowwise_scaled_linear_cutlass_s4s4_time,
        "s4s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s4s4_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
52 changes: 0 additions & 52 deletions benchmarks/benchmark_s8s4_cutlass.py

This file was deleted.

2 changes: 2 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -11,6 +11,7 @@
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
    float8_weight_only,
    int4_dynamic_activation_int4_weight,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_weight,
@@ -61,6 +62,7 @@ def get_quantization_functions(
layout=CutlassInt4PackedLayout(),
)
)
base_functions.append(int4_dynamic_activation_int4_weight())

if do_sparse:
base_functions.append(
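
(Not part of the diff above: a hedged sketch of how the new W4A4 config would be applied end to end, assuming torchao's standard `quantize_` entry point, a CUDA device, and fp16 weights.)

```python
# Hedged usage sketch (not from this PR's diff).
import torch
from torchao.quantization import int4_dynamic_activation_int4_weight, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()
quantize_(model, int4_dynamic_activation_int4_weight())

x = torch.randn(16, 1024, dtype=torch.half, device="cuda")
y = model(x)  # should dispatch to the CUTLASS-backed W4A4 kernel
```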
104 changes: 104 additions & 0 deletions test/test_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,104 @@
import itertools

import pytest
import torch

from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric

ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
    itertools.product(
        ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
        ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
        ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
        ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
    )
)


def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias):
    assert xq_bits in [4, 8]
    assert wq_bits in [4, 8]

    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
    w = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

    x_2d = x.view(-1, x.shape[-1])
    xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric(
        x_2d, xq_bits, size_k, dtype
    )
    assert torch.all(xq_2d_zeros == 0)
    xq_s8 = xq_2d_s8.reshape(x.shape)
    if xq_bits == 4:
        # Pack pairs of int4 values into single int8 bytes (low nibble first).
        xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF)
    else:
        xq = xq_s8
    xq_scales = xq_2d_scales.reshape(x.shape[:-1])

    wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric(
        w, wq_bits, size_n, dtype
    )
    assert torch.all(wq_zeros == 0)
    if wq_bits == 4:
        # Same int4 packing for the weight tensor.
        wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF)
    else:
        wq = wq_s8

    # If torch.nn.functional.linear(x, w, bias) were used as reference, the
    # error would be too big. The calculation below is approximately
    # what the rowwise_scaled_linear_cutlass kernel is doing (except that
    # the matrix multiplication is over integers there).
    size_m_2d = x_2d.shape[0]
    output_ref = (
        (xq_2d_s8.float() @ wq_s8.float().T)
        * xq_2d_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq, wq_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias
    )
77 changes: 0 additions & 77 deletions test/test_s8s4_linear_cutlass.py

This file was deleted.

52 changes: 52 additions & 0 deletions torchao/csrc/cuda/rowwise_scaled_linear_cutlass/README.md
@@ -0,0 +1,52 @@
This directory is intended to contain the implementations of all CUTLASS-based
row-wise scaled linear operators, for non-sparse inputs of both the same and
mixed data types.

The implementation uses a single kernel per SM generation, which should
reside in the `rowwise_scaled_linear_kernel_cutlass.cuh` file. At the
moment, only SM8.x architectures are supported, through the
`rowwise_scaled_linear_kernel_cutlass_sm8x` kernel, but SM9.x, and
eventually higher, can and will be supported too.

The remaining source files, besides the
`rowwise_scaled_linear_kernel_cutlass.cuh` file, contain just the
corresponding template instantiation and PyTorch operator declaration
for the given operator.

In order to support a new combination of data types, copy one of the
existing `.cu` files, for example
`rowwise_scaled_linear_kernel_cutlass_s8s4.cu`, rename the new file,
as well as the operator defined inside it, to reflect the data types to
be supported, and change the `using ElementA` and `using ElementB`
directives accordingly.

In the `.cuh` file, looking from the bottom up, the changes needed are
as follows:

1. Optionally, in the `rowwise_scaled_linear_cutlass_check_inputs`
template, changes may be needed at the places where the last dimension
of the first operand is checked. This check has to be updated only for
inputs of mixed data types, where the wider data type is not exactly
twice as wide as the other data type.
2. In the `select_config` template, a section should be added to
choose the optimal configuration(s) for your kernel. Configuration
selection is critical for the performance of any CUTLASS-based kernel,
so this is where most of the time should and will be spent when making
changes.
3. Optionally, in the `rowwise_scaled_linear_kernel_cutlass_sm8x`
template, the `using Operator` directive may need to be adjusted;
namely, for some combinations of operands, `OpMultiplyAdd` may have to
be used.

After making these changes, the test file
`test/test_rowwise_scaled_linear_cutlass.py` should be updated too:
add a test for the new operator, similar to the existing tests.
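
For illustration, a sketch of such a test for a hypothetical
`rowwise_scaled_linear_cutlass_s8s8` operator (the operator name is made up
here; the snippet would go into `test/test_rowwise_scaled_linear_cutlass.py`
and reuse its existing helpers and parameter lists):

```python
# Hypothetical operator name; the structure mirrors the existing s4s4/s8s4 tests.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s8s8(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s8s8, 8, 8, dtype, batch_size, size_mnk, use_bias
    )
```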

To restrict build times, the implementation in the `.cuh` file has some
restrictions at the moment, for example: scale tensors can only be of
`float16` or `bfloat16` data types, the output is produced in the same
data type as the first input scale tensor, scale tensors are not
optional while bias is optional, etc. If any of these restrictions
should be removed, if any similar changes are needed, if support for
other architectures is needed, or if you need any kind of help in
extending this code to support other data type combinations, get in
touch with the developers.