-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add w4a4 * add test * hook up to AQT * fix quant api test * fix test * make threadblockswizzle a template param * re-use s8s4 cutlass template * add Alex's patch and some changes * fix aqt test * remove int4_cutlass.cu * apply alex's patch * update benchmark script * ruff * add some tuning * reduce num_stages to fit shared memory of small GPUs (<100kb) * replace torch timer with triton do_bench * ruff * use ZeroPointDomain.NONE * fix 3.7 typing * merge Aleksandar changes * run ruff * try replace torch/extension.h with torch/library.h * (alexsamardzic) improve error handling * ruff format * add note on cutlass naming
- Loading branch information
1 parent
b2fb664
commit 1a4c8f9
Showing
17 changed files
with
734 additions
and
444 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pandas as pd | ||
import torch | ||
from tqdm import tqdm | ||
from triton.testing import do_bench | ||
|
||
from torchao.ops import ( | ||
rowwise_scaled_linear_cutlass_s4s4, | ||
rowwise_scaled_linear_cutlass_s8s4, | ||
) | ||
|
||
|
||
def benchmark_microseconds(f, *args): | ||
return do_bench(lambda: f(*args), return_mode="median") * 1e3 | ||
|
||
|
||
def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): | ||
assert A_nbits in (4, 8) and B_nbits in (4, 8) | ||
|
||
dev = torch.device("cuda") | ||
A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev) | ||
A_scale = torch.randn((m,), dtype=torch.half, device=dev) | ||
B = torch.randint( | ||
-128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev | ||
) | ||
B_scale = torch.randn((n,), dtype=torch.half, device=dev) | ||
C = None | ||
|
||
return A, A_scale, B, B_scale, C | ||
|
||
|
||
def benchmark(m: int, k: int, n: int): | ||
dev = torch.device("cuda") | ||
A_ref = torch.randn((m, k), dtype=torch.half, device=dev) | ||
B_ref = torch.randn((n, k), dtype=torch.half, device=dev) | ||
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) | ||
|
||
A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4) | ||
rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( | ||
rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C | ||
) | ||
|
||
A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4) | ||
rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds( | ||
rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C | ||
) | ||
|
||
return { | ||
"m": m, | ||
"k": k, | ||
"n": n, | ||
"fp16_latency (ms)": fp16_time, | ||
"rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time, | ||
"s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time, | ||
"rowwise_scaled_linear_cutlass_s4s4 latency (ms)": rowwise_scaled_linear_cutlass_s4s4_time, | ||
"s4s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s4s4_time, | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
k_vals = (8192, 8192, 8192, 28672) | ||
n_vals = (8192, 10240, 57344, 8192) | ||
|
||
results = [] | ||
for m in tqdm([1 << i for i in range(10)]): | ||
for n, k in zip(n_vals, k_vals): | ||
results.append(benchmark(m, k, n)) | ||
|
||
df = pd.DataFrame(results) | ||
df.to_csv("rowwise_scaled_linear_cutlass_time_results.csv", index=False) | ||
print(df.to_markdown(index=False)) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import itertools | ||
|
||
import pytest | ||
import torch | ||
|
||
from torchao.ops import ( | ||
rowwise_scaled_linear_cutlass_s4s4, | ||
rowwise_scaled_linear_cutlass_s8s4, | ||
) | ||
from torchao.quantization.utils import group_quantize_tensor_symmetric | ||
|
||
ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] | ||
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] | ||
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [ | ||
(2, 512, 128), | ||
(3, 2048, 2048), | ||
(4, 3584, 640), | ||
(13, 8704, 8576), | ||
(26, 18944, 1664), | ||
(67, 6656, 1408), | ||
] | ||
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True] | ||
ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list( | ||
itertools.product( | ||
ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE, | ||
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE, | ||
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK, | ||
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS, | ||
) | ||
) | ||
|
||
|
||
def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias): | ||
assert xq_bits in [4, 8] | ||
assert wq_bits in [4, 8] | ||
|
||
size_m, size_n, size_k = size_mnk | ||
|
||
x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") | ||
w = torch.rand((size_n, size_k), dtype=dtype, device="cuda") | ||
bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None | ||
|
||
x_2d = x.view(-1, x.shape[-1]) | ||
xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric( | ||
x_2d, xq_bits, size_k, dtype | ||
) | ||
assert torch.all(xq_2d_zeros == 0) | ||
xq_s8 = xq_2d_s8.reshape(x.shape) | ||
if xq_bits == 4: | ||
xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF) | ||
else: | ||
xq = xq_s8 | ||
xq_scales = xq_2d_scales.reshape(x.shape[:-1]) | ||
|
||
wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric( | ||
w, wq_bits, size_n, dtype | ||
) | ||
assert torch.all(wq_zeros == 0) | ||
if wq_bits == 4: | ||
wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF) | ||
else: | ||
wq = wq_s8 | ||
|
||
# If torch.nn.functional.linear(x, w, bias) used as reference, the | ||
# error would be too big. The calculation below is approximately | ||
# what rowwise_scaled_linear_cutlass kernel is doing (except that | ||
# matrix multiplication is over integers there). | ||
size_m_2d = x_2d.shape[0] | ||
output_ref = ( | ||
(xq_2d_s8.float() @ wq_s8.float().T) | ||
* xq_2d_scales.view(size_m_2d, 1) | ||
* wq_scales.view(1, size_n) | ||
) | ||
if bias is not None: | ||
output_ref += bias | ||
output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,)) | ||
|
||
fn_inputs = (xq, xq_scales, wq, wq_scales, bias) | ||
try: | ||
output = op(*fn_inputs) | ||
except NotImplementedError: | ||
pytest.xfail("operator not implemented") | ||
|
||
torch.testing.assert_close(output, output_ref) | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | ||
@pytest.mark.parametrize( | ||
"dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS | ||
) | ||
def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias): | ||
run_test_for_op( | ||
rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias | ||
) | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | ||
@pytest.mark.parametrize( | ||
"dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS | ||
) | ||
def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias): | ||
run_test_for_op( | ||
rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias | ||
) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.