-
Notifications
You must be signed in to change notification settings - Fork 215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CUTLASS-based W4A4 #1515
Merged
Merged
Add CUTLASS-based W4A4 #1515
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
1d350d6
add w4a4
gau-nernst 7e277df
add test
gau-nernst a44df9e
hook up to AQT
gau-nernst 2487eb9
Merge branch 'main' into w4a4
gau-nernst de167f0
fix quant api test
gau-nernst fe1f0eb
fix test
gau-nernst 908f464
Merge branch 'main' into w4a4
gau-nernst 883384b
make threadblockswizzle a template param
gau-nernst ee34bb2
re-use s8s4 cutlass template
gau-nernst f513523
add Alex's patch and some changes
gau-nernst 9a1ce25
fix aqt test
gau-nernst b9db0f1
remove int4_cutlass.cu
gau-nernst f42fc65
apply alex's patch
gau-nernst a43f804
Merge branch 'main' into w4a4
gau-nernst 5c30303
update benchmark script
gau-nernst d7c0896
ruff
gau-nernst 2c5f565
Merge branch 'main' into w4a4
gau-nernst fd8dc4e
add some tuning
gau-nernst 5449a56
reduce num_stages to fit shared memory of small GPUs (<100kb)
gau-nernst c421921
replace torch timer with triton do_bench
gau-nernst 81a0a13
ruff
gau-nernst 69e6777
Merge branch 'main' into w4a4
gau-nernst c736856
use ZeroPointDomain.NONE
gau-nernst bdcb85c
fix 3.7 typing
gau-nernst 0c85805
Merge branch 'main' into w4a4
gau-nernst 4a19634
merge Aleksandar changes
gau-nernst 496cec8
run ruff
gau-nernst 9a0ae7b
try replace torch/extension.h with torch/library.h
gau-nernst 9332ac4
Merge branch 'main' into w4a4
gau-nernst 37dc5f7
(alexsamardzic) improve error handling
gau-nernst c003018
ruff format
gau-nernst 3b0b32b
add note on cutlass naming
gau-nernst 4613503
Merge branch 'main' into w4a4
gau-nernst File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pandas as pd | ||
import torch | ||
from tqdm import tqdm | ||
from triton.testing import do_bench | ||
|
||
from torchao.ops import ( | ||
rowwise_scaled_linear_cutlass_s4s4, | ||
rowwise_scaled_linear_cutlass_s8s4, | ||
) | ||
|
||
|
||
def benchmark_microseconds(f, *args): | ||
return do_bench(lambda: f(*args), return_mode="median") * 1e3 | ||
|
||
|
||
def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): | ||
assert A_nbits in (4, 8) and B_nbits in (4, 8) | ||
|
||
dev = torch.device("cuda") | ||
A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev) | ||
A_scale = torch.randn((m,), dtype=torch.half, device=dev) | ||
B = torch.randint( | ||
-128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev | ||
) | ||
B_scale = torch.randn((n,), dtype=torch.half, device=dev) | ||
C = None | ||
|
||
return A, A_scale, B, B_scale, C | ||
|
||
|
||
def benchmark(m: int, k: int, n: int): | ||
dev = torch.device("cuda") | ||
A_ref = torch.randn((m, k), dtype=torch.half, device=dev) | ||
B_ref = torch.randn((n, k), dtype=torch.half, device=dev) | ||
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) | ||
|
||
A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4) | ||
rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( | ||
rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C | ||
) | ||
|
||
A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4) | ||
rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds( | ||
rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C | ||
) | ||
|
||
return { | ||
"m": m, | ||
"k": k, | ||
"n": n, | ||
"fp16_latency (ms)": fp16_time, | ||
"rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time, | ||
"s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time, | ||
"rowwise_scaled_linear_cutlass_s4s4 latency (ms)": rowwise_scaled_linear_cutlass_s4s4_time, | ||
"s4s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s4s4_time, | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
k_vals = (8192, 8192, 8192, 28672) | ||
n_vals = (8192, 10240, 57344, 8192) | ||
|
||
results = [] | ||
for m in tqdm([1 << i for i in range(10)]): | ||
for n, k in zip(n_vals, k_vals): | ||
results.append(benchmark(m, k, n)) | ||
|
||
df = pd.DataFrame(results) | ||
df.to_csv("rowwise_scaled_linear_cutlass_time_results.csv", index=False) | ||
print(df.to_markdown(index=False)) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import itertools | ||
|
||
import pytest | ||
import torch | ||
|
||
from torchao.ops import ( | ||
rowwise_scaled_linear_cutlass_s4s4, | ||
rowwise_scaled_linear_cutlass_s8s4, | ||
) | ||
from torchao.quantization.utils import group_quantize_tensor_symmetric | ||
|
||
ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] | ||
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] | ||
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [ | ||
(2, 512, 128), | ||
(3, 2048, 2048), | ||
(4, 3584, 640), | ||
(13, 8704, 8576), | ||
(26, 18944, 1664), | ||
(67, 6656, 1408), | ||
] | ||
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True] | ||
ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list( | ||
itertools.product( | ||
ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE, | ||
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE, | ||
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK, | ||
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS, | ||
) | ||
) | ||
|
||
|
||
def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias): | ||
assert xq_bits in [4, 8] | ||
assert wq_bits in [4, 8] | ||
|
||
size_m, size_n, size_k = size_mnk | ||
|
||
x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") | ||
w = torch.rand((size_n, size_k), dtype=dtype, device="cuda") | ||
bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None | ||
|
||
x_2d = x.view(-1, x.shape[-1]) | ||
xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric( | ||
x_2d, xq_bits, size_k, dtype | ||
) | ||
assert torch.all(xq_2d_zeros == 0) | ||
xq_s8 = xq_2d_s8.reshape(x.shape) | ||
if xq_bits == 4: | ||
xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF) | ||
else: | ||
xq = xq_s8 | ||
xq_scales = xq_2d_scales.reshape(x.shape[:-1]) | ||
|
||
wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric( | ||
w, wq_bits, size_n, dtype | ||
) | ||
assert torch.all(wq_zeros == 0) | ||
if wq_bits == 4: | ||
wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF) | ||
else: | ||
wq = wq_s8 | ||
|
||
# If torch.nn.functional.linear(x, w, bias) used as reference, the | ||
# error would be too big. The calculation below is approximately | ||
# what rowwise_scaled_linear_cutlass kernel is doing (except that | ||
# matrix multiplication is over integers there). | ||
size_m_2d = x_2d.shape[0] | ||
output_ref = ( | ||
(xq_2d_s8.float() @ wq_s8.float().T) | ||
* xq_2d_scales.view(size_m_2d, 1) | ||
* wq_scales.view(1, size_n) | ||
) | ||
if bias is not None: | ||
output_ref += bias | ||
output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,)) | ||
|
||
fn_inputs = (xq, xq_scales, wq, wq_scales, bias) | ||
try: | ||
output = op(*fn_inputs) | ||
except NotImplementedError: | ||
pytest.xfail("operator not implemented") | ||
|
||
torch.testing.assert_close(output, output_ref) | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | ||
@pytest.mark.parametrize( | ||
"dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS | ||
) | ||
def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias): | ||
run_test_for_op( | ||
rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias | ||
) | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | ||
@pytest.mark.parametrize( | ||
"dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS | ||
) | ||
def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias): | ||
run_test_for_op( | ||
rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias | ||
) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@alexsamardzic I replaced
torchao.utils.benchmark_torch_function_in_microseconds
(which is based ontorch.utils.benchmark.Timer
) with triton'sdo_bench
. This is because I found PyTorch timer is unreliable, possibly because it does not clear L2 cache in between runs.Old (4090,
torch.utils.benchmark.Timer
)New (4090,
triton.testing.do_bench
)In the old way, you can see the unusual speedup for the first two rows. I think it's because the W is cached in L2, hence the gains disappear when W becomes larger.
Lmk if it's ok to have this change. Thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, that's great!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For anyone else I break it down as: