Add CUTLASS-based W4A4 #1515
Conversation
The CUDA code looks fine; of course there are lots of dots left to connect on the Python side. The difference from #880 is that this is not a mixed-data-types GEMM but a regular GEMM, so this operator could probably be made much more generic, to support other integer and maybe even some floating-point input data types. I'm currently making some minor changes to this PyTorch operator, and I would strongly recommend modelling the CUDA code in a similar way: it simply looks nice, it makes extending the kernel to other data types much easier, it has extensive checks on the operands, etc.

Moreover, I think it would make sense at this point to discuss having a single CUTLASS-based kernel for GEMMs with both weights and activations scaled, put in a single source file and handling both same- and mixed-data-types GEMMs, at least for SM 8.x archs. That would minimize code duplication and ease maintenance in the future.

As far as configurations (tile sizes, number of stages, etc.) are concerned, I'd suggest looking here instead, in the unit tests, and also comparing performance against the results reported by the CUTLASS profiler for the given combination of data types. I believe some sort of tuning of the configuration on the input shapes is a must in order to achieve decent performance; but I have to admit that in #880 the tuning is mostly ad hoc (for comparison, I find this approach more elaborate and meaningful). Thus, I think that coming up with a systematic approach in that regard would be the most beneficial contribution toward eventual future use of CUTLASS-based kernels in torchao. (@drisspg: your comments welcome here.)
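(A side note on the "same and mixed data types" point: CUTLASS lets one kernel template cover both cases by selecting the MMA operator tag at compile time. A minimal sketch, assuming CUTLASS's operator tags and with a made-up alias name, could look like this; it mirrors what the rowwise-scaled kernels later in this thread end up doing.)

#include <type_traits>
#include <cutlass/arch/mma.h>
#include <cutlass/numeric_types.h>

// Select the MMA operator so one GEMM template serves both same-dtype
// (e.g. S4/S4, S8/S8) and mixed-dtype (e.g. S8/S4) cases on SM 8.x.
template <typename ElementA, typename ElementB>
using MmaOperator = std::conditional_t<
    std::is_same_v<ElementA, ElementB>,
    cutlass::arch::OpMultiplyAddSaturate,           // same data types
    cutlass::arch::OpMultiplyAddMixedInputUpcast>;  // mixed data types

// MmaOperator<cutlass::int4b_t, cutlass::int4b_t> -> OpMultiplyAddSaturate
// MmaOperator<int8_t, cutlass::int4b_t>           -> OpMultiplyAddMixedInputUpcast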
One thing on finding optimal params: @yifuwang was recently working on finding better configs for an AsyncMM. He did some manual elimination of configs that never seemed to be performant, then fit a simple decision tree on a big sweep over M/K/N shapes, which could be easily modeled in C++. This is similar to what is done in the RowWise scaling. I think a little flow for this would be helpful; I can make an issue to track it. No major comments.
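(As a rough illustration of such shape-based selection distilled into plain C++: the thresholds and tile sizes below echo the minimal heuristic from the patch attached later in this thread, not a fitted decision tree, so treat them as placeholders.)

#include <cstdint>

// A tiny shape-based configuration table of the kind a decision tree fitted
// over an (M, N, K) sweep could be reduced to.
struct KernelConfig {
  int tb_m, tb_n, tb_k;        // threadblock tile
  int warp_m, warp_n, warp_k;  // warp tile
  int num_stages;
};

inline KernelConfig select_kernel_config(int64_t m, int64_t /*n*/, int64_t /*k*/) {
  if (m <= 16) {
    return {16, 128, 128, 16, 32, 128, 6};
  } else if (m <= 32) {
    return {32, 128, 128, 32, 32, 128, 5};
  }
  return {64, 128, 128, 64, 32, 128, 4};
}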
Thank you for the feedback.
Though this is nice on paper, I think Triton is the better alternative for other data types (INT8, FP8, ...): it's more flexible, and the autotuner also saves us some headache. It's only because of the lack of INT4 support in Triton that we have to use CUTLASS, especially for INT4 tensor cores, unless we can show that there are cases where Triton cannot reach the performance of CUTLASS (in the context of this PR, I'm thinking only about INT8 for SM8x, and additionally FP8 for SM89). Having said that, I'm OK with following a certain style/structure. Just point me to which one it should be, and I will make the modifications accordingly.
torchao/ops.py (Outdated)

    Returns:
        output: result tensor, in row-major layout.
    """
    assert A.dtype == B.dtype == torch.int8
Should add the alignment constraints as well, right?
How should I check for data alignment from Python? I guess in C++, I can check by testing divisibility of the memory address? (or perhaps there is a util function somewhere that I'm not aware of...)
Hmm, I think there is a restriction that k needs to be a multiple of 32, right?
Or at least 16 packed int4s.
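(For what it's worth, a hedged sketch of such checks done on the C++ side; the exact divisibility and alignment numbers depend on the CUTLASS alignment chosen for the kernel, so the values here are placeholders, and the helper name is made up.)

#include <cstdint>
#include <torch/extension.h>

// Operand checks for packed-int4 inputs stored as int8: K must be a multiple
// of 32 int4 values (16 packed bytes), and pointers must be 16-byte aligned
// for 128-bit vectorized loads.
inline void check_int4_operands(const at::Tensor& A, const at::Tensor& B) {
  TORCH_CHECK(A.dtype() == at::kChar && B.dtype() == at::kChar,
              "expected int8 storage for packed int4 operands");
  const int64_t k = A.size(1) * 2;  // two int4 values per int8 byte
  TORCH_CHECK(k % 32 == 0, "K (= ", k, ") must be a multiple of 32");
  TORCH_CHECK(reinterpret_cast<std::uintptr_t>(A.data_ptr()) % 16 == 0 &&
                  reinterpret_cast<std::uintptr_t>(B.data_ptr()) % 16 == 0,
              "operands must be 16-byte aligned for 128-bit loads");
}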
torchao/csrc/cuda/int4_cutlass.cu (Outdated)

  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
  // static int const kStages = 3;
  using ElementC = int32_t;
  using Gemm = cutlass::gemm::device::Gemm<
Do you know if the universal GEMM API can be used?
Will look into it. I wrote this quite some time ago...
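(For reference, switching to the universal GEMM API would presumably amount to replacing the device::Gemm alias with something along the following lines; this is an untested sketch that keeps the epilogue, swizzle, stages and alignments at their defaults, with the template-parameter order following cutlass::gemm::device::GemmUniversal.)

#include <cutlass/cutlass.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/device/gemm_universal.h>

// Plain int4 x int4 -> int32 GEMM via the universal GEMM API.
using ElementA = cutlass::int4b_t;
using ElementB = cutlass::int4b_t;
using ElementC = int32_t;
using ElementAccumulator = int32_t;

using GemmUniversal = cutlass::gemm::device::GemmUniversal<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 256, 128>,  // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 128>,    // warp tile
    cutlass::gemm::GemmShape<16, 8, 64>>;     // int4 tensor-core MMA shape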
Attached is a minor patch that will change ...

The structure of CUTLASS-based kernels is typically always the same (see also rowwise scaled MM in PyTorch, mentioned in my previous comment, as well as my CUTLASS-based mixed-data-types and 2:4 sparsity kernels in PyTorch): from the bottom up, there is always an operator implementation function that checks the inputs and then starts a dispatching chain (where run-time data types etc. are translated to compile-time template arguments), which ends in a typical CUTLASS-based GEMM kernel (that is boilerplate). As also mentioned in my previous comment, while rowwise scaled MM is very similar in structure, I like how it looks the most: because of the clever use of variadic template arguments to reduce the clutter, and because of the clean extraction of input checks and configuration selection into separate functions, etc. So I'd suggest we integrate your C++ code in the way sketched by the attached diff, and then also make minor changes to the C++ code so that it looks closer to the rowwise scaled MM implementation. (Of course, the operator name and some other stuff on the Python side will have to be changed too.)
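(For orientation, a bare skeleton of that layering: operator entry point, then input checks, then dtype dispatch, then the boilerplate CUTLASS kernel. All names below are placeholders; the attached diff is the real version.)

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

namespace torchao {

// Bottom of the chain: the boilerplate CUTLASS GEMM instantiation (tile
// shapes, epilogue visitor tree, Gemm::Arguments, can_implement()/run()).
template <typename ElementA, typename ElementB, typename ElementOutput>
void run_cutlass_kernel(const at::Tensor& xq, const at::Tensor& x_scale,
                        const at::Tensor& wq, const at::Tensor& w_scale,
                        at::Tensor& out) {
  // ... CUTLASS boilerplate goes here ...
}

// Middle of the chain: translate run-time scalar types into compile-time
// template arguments.
template <typename ElementA, typename ElementB>
void dispatch_on_scale_dtype(const at::Tensor& xq, const at::Tensor& x_scale,
                             const at::Tensor& wq, const at::Tensor& w_scale,
                             at::Tensor& out) {
  if (x_scale.scalar_type() == at::ScalarType::Half) {
    run_cutlass_kernel<ElementA, ElementB, at::Half>(xq, x_scale, wq, w_scale, out);
  } else if (x_scale.scalar_type() == at::ScalarType::BFloat16) {
    run_cutlass_kernel<ElementA, ElementB, at::BFloat16>(xq, x_scale, wq, w_scale, out);
  } else {
    TORCH_CHECK(false, "unsupported scale dtype ", x_scale.scalar_type());
  }
}

// Top of the chain: the operator implementation function, with operand
// checks done before starting the dispatch.
at::Tensor example_scaled_linear(const at::Tensor& xq, const at::Tensor& x_scale,
                                 const at::Tensor& wq, const at::Tensor& w_scale) {
  TORCH_CHECK(xq.dim() >= 2 && wq.dim() == 2, "unexpected operand dimensions");
  at::Tensor out = x_scale.new_empty({xq.size(0), wq.size(0)});
  dispatch_on_scale_dtype</*ElementA=*/int8_t, /*ElementB=*/int8_t>(
      xq, x_scale, wq, w_scale, out);
  return out;
}

}  // namespace torchao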
As far as performance of the various implementations is concerned: I'd say there are in general three ways to implement kernels: Triton-based, CUTLASS-based, and custom, i.e. from scratch (like Marlin-based kernels). In my experience so far (all on the Ampere arch), CUTLASS-based kernels are oftentimes somewhat faster than Triton-based kernels, while for some corner-case input tensor sizes, custom kernels (well, Marlin-based at least) can be significantly faster than CUTLASS-based ones.

Furthermore, with Triton there is the least amount of flexibility regarding upstream changes (they just don't support some input data types, they don't support 2:4 sparsity, etc.); with CUTLASS it's somewhat easier to have the changes we may need accepted; and for custom kernels this is obviously not an issue at all. However, Triton kills it when it comes to compilation, in particular regarding fusing GEMM with other kernels; CUTLASS has some support for compilation, but doing fusion is rather cumbersome at the moment; and there is obviously no compilation support of any kind for custom kernels. Also, writing custom kernels would probably lead to lots of code duplication, and with CUTLASS this may be an issue too, even if to a smaller extent. Etc. - so it's all a matter of trade-offs.

Still, having auto-tuning and auto-quantization in mind, I believe it may be good to have as many different kernels in torchao as possible, so I'd expect more CUTLASS-based kernels to be written besides these W4A8 and W4A4 kernels - and this is the exact reason that, as discussed above, I'd prefer to have as much code as possible shared between these kernels.
Might be interesting to try out QAT with this setting. cc @andrewor14
I've made these changes to the existing CUTLASS-based W4A8 kernel in #1545, so it should now be easier to eventually include the W4A4 functionality there.
@alexsamardzic Thank you for your prompt feedback. I was going to work on this a little bit more before asking for another round of review. I wasn't sure how much code sharing you intended to have. I was trying to work with a single file, but the biggest blocker is this one:

    if (tensor_a.scalar_type() == at::ScalarType::Char) {
      if (tensor_b.scalar_type() == at::ScalarType::Char) {
        if (tensor_a.size(1) == 2 * tensor_b.size(1)) {

In ... That's why I decided to move the inner-most template to a separate header file for sharing. Wanted to move ... I have another question. Currently you have ...

Side comments: I don't mind code duplication. Realistically speaking, there are not that many useful combinations, so code duplication is not too bad. And, as you probably already know, with the CUTLASS device-level API not all combinations work (and when they don't work, the errors are mysterious 😅).
I realized, while making changes, that your point on dispatching on input/weight tensor data types is right. Namely, PyTorch doesn't provide sub-byte data types, so the only way to differentiate between data type combinations like S4/S4 and S8/S8 is to have this information encoded somewhere else. Your approach was to encode it in the name of the operator, and this kind of approach could also solve the issue I was worried about - that not all templates get instantiated from the single source file. I made some further changes to maximize C++ code reuse. I don't mind code duplication either, but in this case it's (in my opinion) really cumbersome boilerplate code that, for sanity, I'd much prefer to keep in a single place. So I came up with the following patch, to be applied on top of the current state of this branch:

Patch:

diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
similarity index 73%
rename from benchmarks/benchmark_s8s4_cutlass.py
rename to benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
index fbf07eb..00bcb0a 100644
--- a/benchmarks/benchmark_s8s4_cutlass.py
+++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -2,7 +2,7 @@ import pandas as pd
import torch
from tqdm import tqdm
-from torchao.ops import s8s4_linear_cutlass
+from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
from torchao.utils import benchmark_torch_function_in_microseconds
@@ -24,8 +24,8 @@ def benchmark(m: int, k: int, n: int):
A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)
fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
- s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
- s8s4_linear_cutlass, A, A_scale, B, B_scale, C
+ rowwise_scaled_linear_cutlass_s8s4_time = benchmark_torch_function_in_microseconds(
+ rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
)
return {
@@ -33,8 +33,8 @@ def benchmark(m: int, k: int, n: int):
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
- "s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time,
- "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time,
+ "rowwise_scaled_linear_cutlass latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time,
+ "speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
}
@@ -48,5 +48,5 @@ if __name__ == "__main__":
results.append(benchmark(m, k, n))
df = pd.DataFrame(results)
- df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
+ df.to_csv("rowwise_scaled_linear_cutlass_s8s4_time_results.csv", index=False)
print(df.to_markdown(index=False))
diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py
new file mode 100644
index 0000000..422b100
--- /dev/null
+++ b/test/test_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,128 @@
+import itertools
+
+import pytest
+import torch
+
+from torchao.ops import (
+ rowwise_scaled_linear_cutlass_s4s4,
+ rowwise_scaled_linear_cutlass_s8s4,
+)
+from torchao.quantization.utils import group_quantize_tensor_symmetric
+from torchao.utils import compute_max_diff
+
+ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
+ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
+ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
+ (2, 512, 128),
+ (3, 2048, 2048),
+ (4, 3584, 640),
+ (13, 8704, 8576),
+ (26, 18944, 1664),
+ (67, 6656, 1408),
+]
+ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
+ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
+ itertools.product(
+ ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
+ ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
+ ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
+ ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
+ )
+)
+
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
+ "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
+)
+def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
+ size_m, size_n, size_k = size_mnk
+
+ input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+ weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+ bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+ input_2d = input.view(-1, input.shape[-1])
+ input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
+ input_2d, 4, size_k, dtype
+ )
+ assert torch.all(input_2d_zeros == 0)
+ input_s8 = input_2d_s8.reshape(input.shape)
+ input_s4 = ((input_s8[:, :, 1::2] & 0xF) << 4) | (input_s8[:, :, 0::2] & 0xF)
+ input_scales = input_2d_scales.reshape(input.shape[:-1])
+
+ weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
+ weight, 4, size_n, dtype
+ )
+ assert torch.all(weight_zeros == 0)
+ weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
+
+ # If torch.nn.functional.linear(input, weight, bias) used as
+ # reference, the error would be too big. The calculation below is
+ # approximately what rowwise_scaled_linear_cutlass kernel is doing
+ # (except that matrix multiplication is over integers there).
+ size_m_2d = input_2d.shape[0]
+ output_ref = (
+ (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
+ * input_2d_scales.view(size_m_2d, 1)
+ * weight_scales.view(1, size_n)
+ )
+ if bias is not None:
+ output_ref += bias
+ output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
+
+ fn_inputs = (input_s4, input_scales, weight_s4, weight_scales, bias)
+ try:
+ output = rowwise_scaled_linear_cutlass_s4s4(*fn_inputs)
+ except NotImplementedError:
+ pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented")
+
+ max_diff = compute_max_diff(output, output_ref)
+ assert max_diff < 1e-3
+
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
+ "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
+)
+def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
+ size_m, size_n, size_k = size_mnk
+
+ input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+ weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+ bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+ input_2d = input.view(-1, input.shape[-1])
+ input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
+ input_2d, 8, size_k, dtype
+ )
+ assert torch.all(input_2d_zeros == 0)
+ input_s8 = input_2d_s8.reshape(input.shape)
+ input_scales = input_2d_scales.reshape(input.shape[:-1])
+
+ weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
+ weight, 4, size_n, dtype
+ )
+ assert torch.all(weight_zeros == 0)
+ weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
+
+ # If torch.nn.functional.linear(input, weight, bias) used as
+ # reference, the error would be too big. The calculation below is
+ # approximately what rowwise_scaled_linear_cutlass kernel is doing
+ # (except that matrix multiplication is over integers there).
+ size_m_2d = input_2d.shape[0]
+ output_ref = (
+ (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
+ * input_2d_scales.view(size_m_2d, 1)
+ * weight_scales.view(1, size_n)
+ )
+ if bias is not None:
+ output_ref += bias
+ output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
+
+ fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
+ try:
+ output = rowwise_scaled_linear_cutlass_s8s4(*fn_inputs)
+ except NotImplementedError:
+ pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented")
+
+ max_diff = compute_max_diff(output, output_ref)
+ assert max_diff < 5e-3
diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py
deleted file mode 100644
index 6510ada..0000000
--- a/test/test_s8s4_linear_cutlass.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import itertools
-
-import pytest
-import torch
-
-from torchao.ops import s8s4_linear_cutlass
-from torchao.quantization.utils import group_quantize_tensor_symmetric
-from torchao.utils import compute_max_diff
-
-S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
-S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
-S8S4_LINEAR_CUTLASS_SIZE_MNK = [
- (2, 512, 128),
- (3, 2048, 2048),
- (4, 3584, 640),
- (13, 8704, 8576),
- (26, 18944, 1664),
- (67, 6656, 1408),
-]
-S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
-S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
- itertools.product(
- S8S4_LINEAR_CUTLASS_DTYPE,
- S8S4_LINEAR_CUTLASS_BATCH_SIZE,
- S8S4_LINEAR_CUTLASS_SIZE_MNK,
- S8S4_LINEAR_CUTLASS_USE_BIAS,
- )
-)
-
-
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
- "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
-)
-def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
- size_m, size_n, size_k = size_mnk
-
- input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
- weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
- bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
-
- input_2d = input.view(-1, input.shape[-1])
- input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
- input_2d, 8, size_k, dtype
- )
- assert torch.all(input_2d_zeros == 0)
- input_s8 = input_2d_s8.reshape(input.shape)
- input_scales = input_2d_scales.reshape(input.shape[:-1])
-
- weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
- weight, 4, size_n, dtype
- )
- assert torch.all(weight_zeros == 0)
- weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
-
- # If torch.nn.functional.linear(input, weight, bias) used as
- # reference, the error would be too big. The calculation below is
- # approximately what s8s4_linear_cutlass kernel is doing (except
- # that matrrix multiplication is over integers there)).
- size_m_2d = input_2d.shape[0]
- output_ref = (
- (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
- * input_2d_scales.view(size_m_2d, 1)
- * weight_scales.view(1, size_n)
- )
- if bias is not None:
- output_ref += bias
- output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
-
- fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
- try:
- output = s8s4_linear_cutlass(*fn_inputs)
- except NotImplementedError:
- pytest.xfail("s8s4_linear_cutlass() op not implemented")
-
- max_diff = compute_max_diff(output, output_ref)
- assert max_diff < 5e-3
diff --git a/torchao/csrc/cuda/int4_cutlass.cu b/torchao/csrc/cuda/int4_cutlass.cu
deleted file mode 100644
index 452abcc..0000000
--- a/torchao/csrc/cuda/int4_cutlass.cu
+++ /dev/null
@@ -1,231 +0,0 @@
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-
-// copied from s8s4_linear_cutlass.cu
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
- defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_INT4_MM_CUTLASS
-#endif
-
-#if defined(BUILD_INT4_MM_CUTLASS)
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/device/gemm_universal.h"
-#include "cutlass/gemm/device/gemm.h"
-#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
-#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-
-#define CUTLASS_STATUS_CHECK(status) \
- { \
- TORCH_CHECK(status == cutlass::Status::kSuccess, \
- __func__, " : Got CUTLASS error: ", \
- cutlassGetStatusString(status)); \
- }
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_INT4_MM_CUTLASS)
-// define common params
-using ElementA = cutlass::int4b_t;
-using ElementB = cutlass::int4b_t;
-using ElementAccumulator = int32_t;
-using OpClass = cutlass::arch::OpClassTensorOp;
-using ArchTag = cutlass::arch::Sm80;
-
-// how many elements to load at a time -> load 128-bit = 32 x 4-bit
-constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
-constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
-#endif
-
-// we will do input checks in python. A and B are stored as int8
-torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) {
-#if defined(BUILD_INT4_MM_CUTLASS)
- int M = A.size(0);
- int K = A.size(1) * 2;
- int N = B.size(1);
- torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32));
-
- // some configs for int4 mma
- // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
- // using default config. this can be tuned.
- using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
- using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
- // static int const kStages = 3;
- using ElementC = int32_t;
- using Gemm = cutlass::gemm::device::Gemm<
- ElementA, cutlass::layout::RowMajor, // A matrix
- ElementB, cutlass::layout::ColumnMajor, // B matrix
- ElementC, cutlass::layout::RowMajor, // C matrix
- ElementAccumulator, OpClass, ArchTag,
- ThreadblockShape, WarpShape, InstructionShape
- >;
- Gemm::Arguments args {
- {M, N, K},
- {reinterpret_cast<ElementA *>(A.data_ptr<int8_t>()), K},
- {reinterpret_cast<ElementB *>(B.data_ptr<int8_t>()), K},
- {C.data_ptr<ElementC>(), N},
- {C.data_ptr<ElementC>(), N},
- {1, 0} // epilogue
- };
- Gemm gemm_op;
- CUTLASS_STATUS_CHECK(gemm_op(args));
- return C;
-#else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
-#endif
-}
-
-template<
- typename ElementC,
- typename ThreadblockShape,
- typename WarpShape,
- typename InstructionShape,
- int numStages>
-void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) {
- // problem shape
- int M = A.size(0);
- int K = A.size(1) * 2;
- int N = B.size(1);
-
- constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // 8 for BF16/FP16
- using ElementEpilogue = float;
- constexpr int numEpilogueStages = 1;
-
- // build epilogue visitor tree
- using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages
- >;
-
- using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
- constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest;
- using Multiply = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode
- >;
-
- // (1, N)
- using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
- OutputTileThreadMap, ElementC,
- cute::Stride<cute::_0, cute::_1, int32_t> // MNL
- >;
- using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, Accum, ColScale>;
-
- // (M, 1)
- using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
- OutputTileThreadMap, ElementC,
- cute::Stride<cute::_1, cute::_0, int32_t> // MNL
- >;
- using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, EVTCompute0, RowScale>;
-
- using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
- OutputTileThreadMap, ElementC, RoundMode,
- cute::Stride<int64_t, cute::_1, int64_t> // MNL
- >;
- using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<Output, EVTCompute1>;
-
- using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
- ElementA, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, AlignmentA,
- ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB,
- ElementC, cutlass::layout::RowMajor, AlignmentC,
- ElementAccumulator, ElementEpilogue, OpClass, ArchTag,
- ThreadblockShape, WarpShape, InstructionShape,
- EVTOutput,
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
- numStages,
- cutlass::arch::OpMultiplyAddSaturate, // OpMultiplyAdd does not work
- numEpilogueStages
- >::GemmKernel;
- using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
-
- // col_scale, row_scale, and C must have the same dtype
- const ElementA *A_ptr = reinterpret_cast<ElementA *>(A.data_ptr<int8_t>());
- const ElementB *B_ptr = reinterpret_cast<ElementB *>(B.data_ptr<int8_t>());
- const ElementC *col_scale_ptr = reinterpret_cast<ElementC *>(col_scale.data_ptr());
- const ElementC *row_scale_ptr = reinterpret_cast<ElementC *>(row_scale.data_ptr());
- ElementC *C_ptr = reinterpret_cast<ElementC *>(C.data_ptr());
-
- typename EVTOutput::Arguments callback_args{
- {
- {
- {}, // Accum
- {col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}}, // ColScale
- {} // Multiply
- }, // EVTCompute0
- {row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}}, // RowScale
- {} // Multiply
- }, // EVTCompute1
- {C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}} // EVTOutput
- };
-
- typename DeviceGemm::Arguments args(
- cutlass::gemm::GemmUniversalMode::kGemm,
- cutlass::gemm::GemmCoord{M, N, K},
- 1, // batch_split
- callback_args,
- A_ptr, B_ptr, nullptr, nullptr, // unsued C_ptr and D_ptr
- M * K, N * K, 0, 0, // batch_stride A, B, C, D
- K, K, 0, 0 // stride A, B, C, D
- );
-
- DeviceGemm gemm_op;
- auto stream = at::cuda::getCurrentCUDAStream();
- CUTLASS_STATUS_CHECK(gemm_op.can_implement(args));
- CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream));
-}
-
-// we will do input checks in python. A and B are stored as int8
-// this function is based on the following cutlass example
-// https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu
-// also with the help of emitted code from cutlass Python
-torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) {
-#if defined(BUILD_INT4_MM_CUTLASS)
- int M = A.size(0);
- int N = B.size(1);
- torch::Tensor C = torch::empty({M, N}, row_scale.options());
-
- // some configs for int4 mma
- // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
- // using default config. this can be tuned.
- using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
- using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
- constexpr int numStages = 3;
-
- AT_DISPATCH_SWITCH(
- row_scale.scalar_type(),
- "scaled_int4_mm_cutlass",
- AT_DISPATCH_CASE(
- torch::ScalarType::Half,
- [&]() {
- using ElementC = cutlass::half_t;
- scaled_int4_mm_cutlass_dispatch<
- ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
- A, B, row_scale, col_scale, C);
- }
- )
- AT_DISPATCH_CASE(
- torch::ScalarType::BFloat16,
- [&]() {
- using ElementC = cutlass::bfloat16_t;
- scaled_int4_mm_cutlass_dispatch<
- ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
- A, B, row_scale, col_scale, C);
- }
- )
- );
-
- return C;
-#else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
- m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass);
- m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass);
-}
-
-} // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh
new file mode 100644
index 0000000..d1969bc
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh
@@ -0,0 +1,578 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <c10/util/Exception.h>
+
+#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
+ defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
+#define BUILD_ROWWISE_SCALED_LINEAR_CUTLASS
+#endif
+
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+#include <cuda_runtime.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm_universal.h>
+#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
+#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
+#include <cutlass/gemm/device/gemm_universal_adapter.h>
+
+#define CUTLASS_STATUS_CHECK(status) \
+ { \
+ TORCH_CHECK(status == cutlass::Status::kSuccess, \
+ __func__, " : Got CUTLASS error: ", \
+ cutlassGetStatusString(status)); \
+ }
+#endif
+
+namespace torchao {
+
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+template<
+ typename ThreadblockShape,
+ typename WarpShape,
+ typename InstructionShape,
+ typename ThreadblockSwizzle,
+ int NumStages,
+ typename ElementA,
+ typename ElementB,
+ typename ElementOutput,
+ typename ElementC,
+ typename UseTensorC,
+ typename ElementAScale,
+ typename ElementBScale>
+void rowwise_scaled_linear_kernel_cutlass_sm8x(
+ const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+ const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+ const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+ static_assert((cutlass::sizeof_bits<ElementA>::value >= 8 ||
+ 8 % cutlass::sizeof_bits<ElementA>::value == 0) &&
+ (cutlass::sizeof_bits<ElementB>::value >= 8 ||
+ 8 % cutlass::sizeof_bits<ElementB>::value == 0));
+
+ using SmArch = cutlass::arch::Sm80;
+
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::ColumnMajor;
+ using LayoutOutput = cutlass::layout::RowMajor;
+
+ using ElementAccumulator = int32_t;
+ using Operator =
+ std::conditional_t<std::is_same<ElementA, ElementB>::value,
+ cutlass::arch::OpMultiplyAddSaturate,
+ cutlass::arch::OpMultiplyAddMixedInputUpcast>;
+
+ using ElementEpilogue = float;
+
+ constexpr auto NumEVTEpilogueStages = 1;
+
+ const int m = tensor_a.size(0);
+ const int n = tensor_b.size(0);
+ int k = tensor_a.size(1);
+ if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
+ k *= 8 / cutlass::sizeof_bits<ElementA>::value;
+ }
+
+ constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+ constexpr int AlignmentAScale =
+ 128 / cutlass::sizeof_bits<ElementAScale>::value;
+ constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+ constexpr int AlignmentBScale =
+ 128 / cutlass::sizeof_bits<ElementBScale>::value;
+ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+ constexpr int AlignmentOutput =
+ 128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+ // Check for current CUTLASS limitations w.r.t. alignments.
+ TORCH_CHECK(k % AlignmentA == 0,
+ __func__, " : Number of columns of tensor A must be divisible ",
+ "by ", AlignmentA);
+ TORCH_CHECK(k % AlignmentB == 0,
+ __func__, " : Number of columns of tensor B must be divisible ",
+ "by ", AlignmentB);
+ TORCH_CHECK(n % AlignmentC == 0,
+ __func__, " : Number of columns of tensor C must be divisible ",
+ "by ", AlignmentC);
+
+ using TensorAScaleTileThreadMap =
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
+ ThreadblockShape,
+ WarpShape,
+ ElementAScale,
+ AlignmentAScale,
+ NumEVTEpilogueStages>;
+ using TensorBScaleTileThreadMap =
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
+ ThreadblockShape,
+ WarpShape,
+ ElementBScale,
+ AlignmentBScale,
+ NumEVTEpilogueStages>;
+ using TensorCTileThreadMap =
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
+ ThreadblockShape,
+ WarpShape,
+ ElementC,
+ AlignmentC,
+ NumEVTEpilogueStages>;
+ using OutputTileThreadMap =
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
+ ThreadblockShape,
+ WarpShape,
+ ElementOutput,
+ AlignmentOutput,
+ NumEVTEpilogueStages>;
+
+ using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+ using TensorAScale =
+ cutlass::epilogue::threadblock::VisitorColBroadcast<
+ TensorAScaleTileThreadMap,
+ ElementAScale,
+ cute::Stride<cute::_1, cute::_0, int64_t>>;
+ using TensorAScaleArguments = typename TensorAScale::Arguments;
+
+ using TensorBScale =
+ cutlass::epilogue::threadblock::VisitorRowBroadcast<
+ TensorBScaleTileThreadMap,
+ ElementBScale,
+ cute::Stride<cute::_0, cute::_1, int64_t>>;
+ using TensorBScaleArguments = typename TensorBScale::Arguments;
+
+ using TensorCScalar =
+ cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
+ using TensorCTensor =
+ cutlass::epilogue::threadblock::VisitorRowBroadcast<
+ TensorCTileThreadMap,
+ ElementC,
+ cute::Stride<cute::_0, cute::_1, int64_t>>;
+ using TensorC =
+ std::conditional_t<UseTensorC::value, TensorCTensor, TensorCScalar>;
+ using TensorCArguments = typename TensorC::Arguments;
+
+ using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest
+ >;
+ using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
+ ApplyAScale,
+ Accum,
+ TensorAScale>;
+
+ using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest
+ >;
+ using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
+ ApplyBScale,
+ EVTApplyAScale,
+ TensorBScale>;
+
+ using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::plus, ElementEpilogue, ElementEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest
+ >;
+ using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
+ ApplySum,
+ EVTApplyBScale,
+ TensorC>;
+
+ using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
+ OutputTileThreadMap, ElementOutput,
+ cutlass::FloatRoundStyle::round_to_nearest,
+ cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
+ >;
+
+ using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
+ Output,
+ EVTApplySum>;
+
+ using EVTKernel =
+ typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
+ ElementOutput, LayoutOutput, AlignmentOutput,
+ ElementAccumulator,
+ ElementEpilogue,
+ cutlass::arch::OpClassTensorOp,
+ SmArch,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EVTOutput,
+ ThreadblockSwizzle,
+ NumStages,
+ Operator,
+ NumEVTEpilogueStages
+ >::GemmKernel;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
+
+ cutlass::gemm::GemmCoord problem_size(m, n, k);
+ constexpr auto SplitKFactor = 1;
+
+ TensorAScaleArguments tensor_a_scale_arguments{
+ (ElementAScale*)tensor_a_scale.data_ptr(),
+ ElementAScale(1),
+ {cute::_1{}, cute::_0{}, problem_size.m()}
+ };
+ TensorBScaleArguments tensor_b_scale_arguments{
+ (ElementBScale*)tensor_b_scale.data_ptr(),
+ ElementBScale(1),
+ {cute::_0{}, cute::_1{}, problem_size.n()}
+ };
+ TensorCArguments tensor_c_arguments{
+ [&]() -> TensorCArguments {
+ if constexpr (UseTensorC::value) {
+ return {(ElementC*)tensor_c.data_ptr(),
+ ElementC(0),
+ {cute::_0{}, cute::_1{}, problem_size.n()}};
+ } else {
+ return {ElementC(0)};
+ }
+ }()
+ };
+ typename Output::Arguments output_arguments{
+ (ElementOutput*)tensor_d.data_ptr(),
+ {problem_size.n(), cute::_1{}, problem_size.mn().product()}
+ };
+ typename EVTOutput::Arguments callback_arguments{
+ {
+ {
+ {
+ {}, // Accum
+ tensor_a_scale_arguments, // TensorAScale
+ {} // ApplyAScale
+ }, // EVTApplyAScale
+ tensor_b_scale_arguments, // TensorBScale
+ {}, // ApplyBScale
+ }, // EVTApplyBScale
+ tensor_c_arguments, // TensorC
+ {} // ApplySum
+ }, // EVTApplySum
+ output_arguments // Output
+ }; // EVTOutput
+
+ typename Gemm::Arguments arguments(
+ cutlass::gemm::GemmUniversalMode::kGemm,
+ problem_size,
+ SplitKFactor,
+ callback_arguments, // arguments of EVT callbacks
+ (ElementA*)tensor_a.data_ptr(),
+ (ElementB*)tensor_b.data_ptr(),
+ nullptr, // ptr C (unused)
+ nullptr, // ptr D (unused)
+ problem_size.mk().product(), // batch stride A
+ problem_size.nk().product(), // batch stride B
+ 0, // batch stride C (unused)
+ 0, // batch stride D (unused)
+ problem_size.k(), // stride A
+ problem_size.k(), // stride B
+ 0, // stride C (unused)
+ 0 // stride D (unused)
+ );
+
+ Gemm gemm_op;
+
+ cutlass::Status status;
+
+ // Verify that GEMM operation with given arguments can be performed
+ // by CUTLASS.
+ status = gemm_op.can_implement(arguments);
+ CUTLASS_STATUS_CHECK(status);
+
+ // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
+ const auto workspace_size = Gemm::get_workspace_size(arguments);
+ auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
+ at::TensorOptions().dtype(at::kByte));
+
+ // Initialize CUTLASS mixed datatypes GEMM object.
+ status = gemm_op.initialize(arguments, workspace.data_ptr(),
+ at::cuda::getCurrentCUDAStream());
+ CUTLASS_STATUS_CHECK(status);
+
+ // Perform mixed datatypes GEMM operation.
+ status = gemm_op.run(at::cuda::getCurrentCUDAStream());
+ CUTLASS_STATUS_CHECK(status);
+
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename ElementA, typename ElementB, typename... Types>
+static void select_config(
+ const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+ const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+ const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+ const auto dprops = at::cuda::getCurrentDeviceProperties();
+ const auto is_sm8x = dprops->major == 8;
+
+ if (is_sm8x) {
+ if constexpr (std::is_same<ElementA, cutlass::int4b_t>::value &&
+ std::is_same<ElementB, cutlass::int4b_t>::value) {
+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
+ using ThreadblockSwizzle =
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
+ constexpr auto NumStages = 3;
+ rowwise_scaled_linear_kernel_cutlass_sm8x<
+ ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+ NumStages, ElementA, ElementB, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ return;
+ } else if constexpr (std::is_same<ElementA, int8_t>::value &&
+ std::is_same<ElementB, cutlass::int4b_t>::value) {
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+ using ThreadblockSwizzle =
+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
+
+ // A minimal heuristic to improve performance for cases with a small
+ // number of inputs.
+ if (tensor_a.size(0) <= 16) {
+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
+ constexpr auto NumStages = 6;
+ rowwise_scaled_linear_kernel_cutlass_sm8x<
+ ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+ NumStages, ElementA, ElementB, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ } else if (tensor_a.size(0) <= 32) {
+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
+ constexpr auto NumStages = 5;
+ rowwise_scaled_linear_kernel_cutlass_sm8x<
+ ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+ NumStages, ElementA, ElementB, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ } else {
+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
+ constexpr auto NumStages = 4;
+ rowwise_scaled_linear_kernel_cutlass_sm8x<
+ ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+ NumStages, ElementA, ElementB, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ }
+ return;
+ }
+ }
+
+ TORCH_CHECK(false,
+ __func__, " : Operator not supported on SM", dprops->major, ".",
+ dprops->minor, " for given operands");
+}
+
+template<
+ typename ElementA,
+ typename ElementB,
+ typename ElementOutput,
+ typename... Types>
+static void
+dispatch_on_tensor_c(
+ const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+ const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+ const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+ if (tensor_c.numel() == 0) {
+ using ElementC = ElementOutput;
+ using UseTensorC = std::false_type;
+ select_config<
+ ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ return;
+ }
+
+ using UseTensorC = std::true_type;
+ if (tensor_c.scalar_type() == at::ScalarType::Half) {
+ using ElementC = cutlass::half_t;
+ select_config<
+ ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ return;
+ } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
+ using ElementC = cutlass::bfloat16_t;
+ select_config<
+ ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ return;
+ }
+
+ TORCH_CHECK(false,
+ __func__, " : Operator not supported for datatype ",
+ tensor_c.scalar_type(), " for addend");
+}
+
+template<typename ElementA, typename ElementB>
+static void
+dispatch_on_tensor_a_scale_and_tensor_b_scale(
+ const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+ const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+ const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+ TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
+ __func__, " : Operator not supported for output datatype ",
+ tensor_d.scalar_type(), " as it's different from the first ",
+ " operand scale datatype ", tensor_a_scale.scalar_type());
+
+ if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
+ tensor_b_scale.scalar_type() == at::ScalarType::Half) {
+ using ElementAScale = cutlass::half_t;
+ using ElementBScale = cutlass::half_t;
+ using ElementOutput = cutlass::half_t;
+ dispatch_on_tensor_c<ElementA, ElementB, ElementOutput, ElementAScale,
+ ElementBScale>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+ return;
+ } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
+ tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
+ using ElementAScale = cutlass::bfloat16_t;
+ using ElementBScale = cutlass::bfloat16_t;
+ using ElementOutput = cutlass::bfloat16_t;
+ dispatch_on_tensor_c<ElementA, ElementB, ElementOutput, ElementAScale,
+ ElementBScale>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+ return;
+ }
+
+ TORCH_CHECK(false,
+ __func__, " : Operator not supported for combination of data ",
+ "types ", tensor_a_scale.scalar_type(),
+ " for first operand scale and ", tensor_b_scale.scalar_type(),
+ " for second operand scale");
+}
+
+template<typename ElementA, typename ElementB>
+void
+rowwise_scaled_linear_cutlass_check_inputs(
+ const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+ const at::Tensor& w_scale, const at::Tensor& bias) {
+ // Validate layouts of arguments.
+ TORCH_CHECK(xq.dim() >= 2,
+ __func__, " : Expected xq argument to be 2D or "
+ "higher-dimensional tensor, got ", xq.dim(), " dims");
+ TORCH_CHECK(xq.layout() == at::Layout::Strided,
+ __func__, " : Expected xq argument to be strided, got layout ",
+ xq.layout());
+ TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
+ __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
+ "D tensor, got ", x_scale.dim(), " dims");
+ TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
+ __func__, " : Expected xq scale argument to be strided, got "
+ "layout ", x_scale.layout());
+ TORCH_CHECK(wq.dim() == 2,
+ __func__, " : Expected wq argument to be 2D tensor, got ",
+ wq.dim(), " dims");
+ TORCH_CHECK(wq.layout() == at::Layout::Strided,
+ __func__, " : Expected wq argument to be strided, got layout ",
+ wq.layout());
+ TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
+ __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
+ "got ", w_scale.dim(), " dims");
+ TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
+ __func__, " : Expected wq scale argument to be strided, got "
+ "layout ", w_scale.layout());
+ if (bias.numel() > 0) {
+ TORCH_CHECK(bias.dim() == 1,
+ __func__, " : Expected bias argument to be 1D tensor, got ",
+ bias.dim(), " dims");
+ TORCH_CHECK(bias.layout() == at::Layout::Strided,
+ __func__, " : Expected bias argument to be strided, got ",
+ "layout ", bias.layout());
+ }
+
+ // Validate sizes of arguments.
+ const auto xq_sizes = xq.sizes().vec();
+ TORCH_CHECK(xq_sizes.back() == wq.size(1) ||
+ xq_sizes.back() == 2 * wq.size(1),
+ __func__, " : Expected xq argument to have ", wq.size(1), " or ",
+ 2 * wq.size(1), " columns, but got ", xq_sizes.back());
+ const auto x_scale_sizes = x_scale.sizes().vec();
+ for (auto i = 0; i < x_scale_sizes.size(); ++i)
+ TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
+ __func__, " : Expected xq scale argument size at position ",
+ i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
+ TORCH_CHECK(w_scale.numel() == wq.size(0),
+ __func__, " : Expected wq scale argument to have ", wq.size(0),
+ " elements, got ", w_scale.numel(), " elements");
+ if (bias.numel() > 0) {
+ TORCH_CHECK(bias.numel() == wq.size(0),
+ __func__, " : Expected bias argument to have ", wq.size(0),
+ " elements, got ", bias.numel(), " elements");
+ }
+
+ // Validate strides of arguments.
+ const auto xq_strides = xq.strides();
+ TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
+ __func__, " : Expected xq argument in row-major layout");
+ auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
+ for (int i = xq_strides.size() - 3; i >= 0; --i) {
+ xq_stride_expected *= xq_sizes[i + 1];
+ TORCH_CHECK(xq_strides[i] == xq_stride_expected,
+ __func__, " : Expected xq argument in row-major layout");
+ }
+ TORCH_CHECK(x_scale.is_contiguous(),
+ __func__, " : Expected xq scale argument to be contiguous");
+ const auto wq_strides = wq.strides();
+ TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
+ __func__, " : Expected wq argument in row-major layout");
+ TORCH_CHECK(w_scale.is_contiguous(),
+ __func__, " : Expected wq scale argument to be contiguous");
+ if (bias.numel() > 0) {
+ const auto bias_strides = bias.strides();
+ TORCH_CHECK(bias_strides[0] == 1,
+ __func__, " : Expected bias argument to be contiguous");
+ }
+}
+#endif
+
+// Perform linear operation, using corresponding CUTLASS datatypes
+// GEMM kernel, to given arguments - result produced is:
+// (tensor_a * tensor_a_scale) @ (tensor_b * tensor_b_scale).T + tensor_c
+//
+// Notes: The "tensor_a" and "tensor_b" are expected to be 2D tensors.
+// The "tensor_a_scale" tensor is expected to be a vector, of size
+// equal to number of rows of "tensor_a" tensor. The "tensor_b_scale"
+// tensor is expected to be a vector, of size equal to number of rows
+// of "tensor_b" tensor. The "tensor_c" tensor is expected to be a
+// vector, of size equal to number of rows of "tensor_b" tensor.
+template <typename ElementA, typename ElementB>
+at::Tensor
+rowwise_scaled_linear_cutlass(
+ const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+ const at::Tensor& w_scale, const at::Tensor& bias) {
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+ // Check inputs.
+ rowwise_scaled_linear_cutlass_check_inputs<ElementA, ElementB>(
+ xq, x_scale, wq, w_scale, bias);
+
+ // Squash the input tensors as appropriate.
+ const auto xq_sizes = xq.sizes().vec();
+ const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
+ const auto x_scale_1d = x_scale.reshape({-1});
+ const auto w_scale_1d = w_scale.reshape({-1});
+
+ // Create result tensor.
+ at::Tensor result =
+ x_scale.new_empty({xq_2d.size(0), wq.size(0)});
+
+ // Dispatch to appropriate kernel template.
+ dispatch_on_tensor_a_scale_and_tensor_b_scale<ElementA, ElementB>(
+ xq_2d, x_scale_1d, wq, w_scale_1d, bias, result);
+
+ // Reshape and return result tensor.
+ auto result_sizes = xq_sizes;
+ result_sizes.back() = wq.size(0);
+ return result.reshape(result_sizes);
+#else
+ TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+ return at::Tensor{};
+#endif
+}
+
+} // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
new file mode 100644
index 0000000..9a64b2b
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#include "rowwise_scaled_linear_cutlass.cuh"
+
+namespace torchao {
+
+at::Tensor
+rowwise_scaled_linear_cutlass_s4s4(
+ const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+ const at::Tensor& w_scale, const at::Tensor& bias) {
+ // Validate input datatypes.
+ TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar,
+ __func__, " : The input datatypes combination ", xq.dtype(),
+ " for xq and ", wq.dtype(), " for wq is not supported");
+
+ // Dispatch to appropriate kernel template.
+ using ElementA = cutlass::int4b_t;
+ using ElementB = cutlass::int4b_t;
+ return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+ xq, x_scale, wq, w_scale, bias);
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+ m.impl("torchao::rowwise_scaled_linear_cutlass_s4s4",
+ &rowwise_scaled_linear_cutlass_s4s4);
+}
+
+} // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
new file mode 100644
index 0000000..752c557
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#include "rowwise_scaled_linear_cutlass.cuh"
+
+namespace torchao {
+
+at::Tensor
+rowwise_scaled_linear_cutlass_s8s4(
+ const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+ const at::Tensor& w_scale, const at::Tensor& bias) {
+ // Validate input datatypes.
+ TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar,
+ __func__, " : The input datatypes combination ", xq.dtype(),
+ " for xq and ", wq.dtype(), " for wq is not supported");
+
+ // Dispatch to appropriate kernel template.
+ using ElementA = int8_t;
+ using ElementB = cutlass::int4b_t;
+ return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+ xq, x_scale, wq, w_scale, bias);
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+ m.impl("torchao::rowwise_scaled_linear_cutlass_s8s4",
+ &rowwise_scaled_linear_cutlass_s8s4);
+}
+
+} // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu
deleted file mode 100644
index 8faf13c..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu
+++ /dev/null
@@ -1,268 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
- defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_S4S4_LINEAR_CUTLASS
-#endif
-
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-#include "scaled_linear.h"
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-
-template<typename... Types>
-static void select_config(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- const auto dprops = at::cuda::getCurrentDeviceProperties();
- const auto is_sm8x = dprops->major >= 8;
-
- if (is_sm8x) {
- using ElementA = cutlass::int4b_t;
- using ElementB = cutlass::int4b_t;
- using ElementAccumulator = int32_t;
-
- using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
- using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
- constexpr auto NumStages = 3;
- using Operator = cutlass::arch::OpMultiplyAddSaturate;
- // using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; // this does not work
- using ThreadblockSwizzle =
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
-
- scaled_linear_kernel_cutlass_sm8x<
- ThreadblockShape, WarpShape, InstructionShape, NumStages,
- ThreadblockSwizzle, ElementA, ElementB, ElementAccumulator, Operator,
- Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported on SM", dprops->major, ".",
- dprops->minor, " for given operands");
-}
-
-template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
-static void
-dispatch_on_tensor_c(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- if (tensor_c.numel() == 0) {
- using ElementC = ElementOutput;
- using UseTensorC = std::false_type;
- select_config<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- using UseTensorC = std::true_type;
- if (tensor_c.scalar_type() == at::ScalarType::Half) {
- using ElementC = cutlass::half_t;
- select_config<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
- using ElementC = cutlass::bfloat16_t;
- select_config<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for datatype ",
- tensor_c.scalar_type(), " for addend");
-}
-
-static void
-dispatch_on_tensor_a_scale_and_tensor_b_scale(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
- __func__, " : Operator not supported for output datatype ",
- tensor_d.scalar_type(), " as it's different from the first ",
- " operand scale datatype ", tensor_a_scale.scalar_type());
-
- if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
- tensor_b_scale.scalar_type() == at::ScalarType::Half) {
- using ElementAScale = cutlass::half_t;
- using ElementBScale = cutlass::half_t;
- using ElementOutput = cutlass::half_t;
- dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
- return;
- } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
- tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
- using ElementAScale = cutlass::bfloat16_t;
- using ElementBScale = cutlass::bfloat16_t;
- using ElementOutput = cutlass::bfloat16_t;
- dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for combination of data ",
- "types ", tensor_a_scale.scalar_type(),
- " for first operand scale and ", tensor_b_scale.scalar_type(),
- " for second operand scale");
-}
-
-static void
-check_inputs(
- const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
- const at::Tensor& w_scale, const at::Tensor& bias) {
- // Validate layouts of arguments.
- TORCH_CHECK(xq.dim() >= 2,
- __func__, " : Expected xq argument to be 2D or "
- "higher-dimensional tensor, got ", xq.dim(), " dims");
- TORCH_CHECK(xq.layout() == at::Layout::Strided,
- __func__, " : Expected xq argument to be strided, got layout ",
- xq.layout());
- TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
- __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
- "D tensor, got ", x_scale.dim(), " dims");
- TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
- __func__, " : Expected xq scale argument to be strided, got "
- "layout ", x_scale.layout());
- TORCH_CHECK(wq.dim() == 2,
- __func__, " : Expected wq argument to be 2D tensor, got ",
- wq.dim(), " dims");
- TORCH_CHECK(wq.layout() == at::Layout::Strided,
- __func__, " : Expected wq argument to be strided, got layout ",
- wq.layout());
- TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
- __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
- "got ", w_scale.dim(), " dims");
- TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
- __func__, " : Expected wq scale argument to be strided, got "
- "layout ", w_scale.layout());
- if (bias.numel() > 0) {
- TORCH_CHECK(bias.dim() == 1,
- __func__, " : Expected bias argument to be 1D tensor, got ",
- bias.dim(), " dims");
- TORCH_CHECK(bias.layout() == at::Layout::Strided,
- __func__, " : Expected bias argument to be strided, got ",
- "layout ", bias.layout());
- }
-
- // Validate sizes of arguments.
- const auto xq_sizes = xq.sizes().vec();
- TORCH_CHECK(xq_sizes.back() == wq.size(1),
- __func__, " : Expected xq argument to have ", wq.size(1),
- " columns, but got ", xq_sizes.back());
- const auto x_scale_sizes = x_scale.sizes().vec();
- for (auto i = 0; i < x_scale_sizes.size(); ++i)
- TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
- __func__, " : Expected xq scale argument size at position ",
- i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
- TORCH_CHECK(w_scale.numel() == wq.size(0),
- __func__, " : Expected wq scale argument to have ", wq.size(0),
- " elements, got ", w_scale.numel(), " elements");
- if (bias.numel() > 0) {
- TORCH_CHECK(bias.numel() == wq.size(0),
- __func__, " : Expected bias argument to have ", wq.size(0),
- " elements, got ", bias.numel(), " elements");
- }
-
- // Validate strides of arguments.
- const auto xq_strides = xq.strides();
- TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
- __func__, " : Expected xq argument in row-major layout");
- auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
- for (int i = xq_strides.size() - 3; i >= 0; --i) {
- xq_stride_expected *= xq_sizes[i + 1];
- TORCH_CHECK(xq_strides[i] == xq_stride_expected,
- __func__, " : Expected xq argument in row-major layout");
- }
- TORCH_CHECK(x_scale.is_contiguous(),
- __func__, " : Expected xq scale argument to be contiguous");
- const auto wq_strides = wq.strides();
- TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
- __func__, " : Expected wq argument in row-major layout");
- TORCH_CHECK(w_scale.is_contiguous(),
- __func__, " : Expected wq scale argument to be contiguous");
- if (bias.numel() > 0) {
- const auto bias_strides = bias.strides();
- TORCH_CHECK(bias_strides[0] == 1,
- __func__, " : Expected bias argument to be contiguous");
- }
-}
-#endif
-
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-// result = (xq * x_scale) @ (wq * w_scale).T + bias
-// Notes: The "x_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "xq" tensor. The "w_scale" tensor is
-// expected to be a vector, of size equal to number of rows of "wq"
-// tensor. The "bias" tensor is expected to be a vector, of size equal
-// to number of rows of "wq" tensor.
-at::Tensor
-s4s4_linear_cutlass(
- const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
- const at::Tensor& w_scale, const at::Tensor& bias) {
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
- // Check inputs.
- check_inputs(xq, x_scale, wq, w_scale, bias);
-
- // Squash the input tensors as appropriate.
- const auto xq_sizes = xq.sizes().vec();
- const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
- const auto x_scale_sizes = x_scale.sizes().vec();
- const auto x_scale_1d = x_scale.reshape({-1});
- const auto w_scale_1d = w_scale.reshape({-1});
-
- // Introduce alias names for arguments, according to the CUTLASS
- // naming conventions.
- const auto& tensor_a = xq_2d;
- const auto& tensor_a_scale = x_scale_1d;
- const auto& tensor_b = wq;
- const auto& tensor_b_scale = w_scale_1d;
- const auto& tensor_c = bias;
-
- // Create output tensor.
- at::Tensor tensor_d =
- tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
-
- // Dispatch to appropriate kernel template.
- dispatch_on_tensor_a_scale_and_tensor_b_scale(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-
- // Reshape and return output tensor.
- auto tensor_d_sizes = xq_sizes;
- tensor_d_sizes.back() = wq.size(0);
- return tensor_d.reshape(tensor_d_sizes);
-#else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
- m.impl("torchao::s4s4_linear_cutlass", &s4s4_linear_cutlass);
-}
-
-} // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
deleted file mode 100644
index 53eaf53..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
+++ /dev/null
@@ -1,315 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
- defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_S8S4_LINEAR_CUTLASS
-#endif
-
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-#include "scaled_linear.h"
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-
-template<typename ElementA, typename ElementB, typename... Types>
-static void select_config(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- const auto dprops = at::cuda::getCurrentDeviceProperties();
- const auto is_sm8x = dprops->major == 8;
-
- if (is_sm8x) {
- if constexpr (std::is_same<ElementA, int8_t>::value &&
- std::is_same<ElementB, cutlass::int4b_t>::value) {
- using ThreadblockSwizzle =
- cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-
- // A minimal heuristic to improve performance for small number
- // of inputs cases.
- if (tensor_a.size(0) <= 16) {
- using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
- using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
- constexpr auto NumStages = 6;
- scaled_linear_kernel_cutlass_sm8x<
- ThreadblockShape, WarpShape, InstructionShape, NumStages,
- ThreadblockSwizzle, ElementA, ElementB, Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- } else if (tensor_a.size(0) <= 32) {
- using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
- using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
- constexpr auto NumStages = 5;
- scaled_linear_kernel_cutlass_sm8x<
- ThreadblockShape, WarpShape, InstructionShape, NumStages,
- ThreadblockSwizzle, ElementA, ElementB, Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- } else {
- using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
- using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
- constexpr auto NumStages = 4;
- scaled_linear_kernel_cutlass_sm8x<
- ThreadblockShape, WarpShape, InstructionShape, NumStages,
- ThreadblockSwizzle, ElementA, ElementB, Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- }
- return;
- }
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported on SM", dprops->major, ".",
- dprops->minor, " for given operands");
-}
-
-template<typename... Types>
-static void
-dispatch_on_tensor_a_and_tensor_b(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- if (tensor_a.scalar_type() == at::ScalarType::Char) {
- if (tensor_b.scalar_type() == at::ScalarType::Char) {
- if (tensor_a.size(1) == 2 * tensor_b.size(1)) {
- using ElementA = int8_t;
- using ElementB = cutlass::int4b_t;
- using ElementAccumulator = int32_t;
- using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast;
- select_config<
- ElementA, ElementB, ElementAccumulator, Operator, Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- }
- return;
- }
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for combination of data ",
- "types ", tensor_a.scalar_type(), " for first operand and ",
- tensor_b.scalar_type(), " for second operand");
-}
-
-
-template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
-static void
-dispatch_on_tensor_c(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- if (tensor_c.numel() == 0) {
- using ElementC = ElementOutput;
- using UseTensorC = std::false_type;
- dispatch_on_tensor_a_and_tensor_b<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- using UseTensorC = std::true_type;
- if (tensor_c.scalar_type() == at::ScalarType::Half) {
- using ElementC = cutlass::half_t;
- dispatch_on_tensor_a_and_tensor_b<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
- using ElementC = cutlass::bfloat16_t;
- dispatch_on_tensor_a_and_tensor_b<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for datatype ",
- tensor_c.scalar_type(), " for addend");
-}
-
-static void
-dispatch_on_tensor_a_scale_and_tensor_b_scale(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
- __func__, " : Operator not supported for output datatype ",
- tensor_d.scalar_type(), " as it's different from the first ",
- " operand scale datatype ", tensor_a_scale.scalar_type());
-
- if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
- tensor_b_scale.scalar_type() == at::ScalarType::Half) {
- using ElementAScale = cutlass::half_t;
- using ElementBScale = cutlass::half_t;
- using ElementOutput = cutlass::half_t;
- dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
- return;
- } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
- tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
- using ElementAScale = cutlass::bfloat16_t;
- using ElementBScale = cutlass::bfloat16_t;
- using ElementOutput = cutlass::bfloat16_t;
- dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for combination of data ",
- "types ", tensor_a_scale.scalar_type(),
- " for first operand scale and ", tensor_b_scale.scalar_type(),
- " for second operand scale");
-}
-
-static void
-check_inputs(
- const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
- const at::Tensor& w_scale, const at::Tensor& bias) {
- // Validate layouts of arguments.
- TORCH_CHECK(xq.dim() >= 2,
- __func__, " : Expected xq argument to be 2D or "
- "higher-dimensional tensor, got ", xq.dim(), " dims");
- TORCH_CHECK(xq.layout() == at::Layout::Strided,
- __func__, " : Expected xq argument to be strided, got layout ",
- xq.layout());
- TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
- __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
- "D tensor, got ", x_scale.dim(), " dims");
- TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
- __func__, " : Expected xq scale argument to be strided, got "
- "layout ", x_scale.layout());
- TORCH_CHECK(wq.dim() == 2,
- __func__, " : Expected wq argument to be 2D tensor, got ",
- wq.dim(), " dims");
- TORCH_CHECK(wq.layout() == at::Layout::Strided,
- __func__, " : Expected wq argument to be strided, got layout ",
- wq.layout());
- TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
- __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
- "got ", w_scale.dim(), " dims");
- TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
- __func__, " : Expected wq scale argument to be strided, got "
- "layout ", w_scale.layout());
- if (bias.numel() > 0) {
- TORCH_CHECK(bias.dim() == 1,
- __func__, " : Expected bias argument to be 1D tensor, got ",
- bias.dim(), " dims");
- TORCH_CHECK(bias.layout() == at::Layout::Strided,
- __func__, " : Expected bias argument to be strided, got ",
- "layout ", bias.layout());
- }
-
- // Validate sizes of arguments.
- const auto xq_sizes = xq.sizes().vec();
- TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1),
- __func__, " : Expected xq argument to have ", 2 * wq.size(1),
- " columns, but got ", xq_sizes.back());
- const auto x_scale_sizes = x_scale.sizes().vec();
- for (auto i = 0; i < x_scale_sizes.size(); ++i)
- TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
- __func__, " : Expected xq scale argument size at position ",
- i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
- TORCH_CHECK(w_scale.numel() == wq.size(0),
- __func__, " : Expected wq scale argument to have ", wq.size(0),
- " elements, got ", w_scale.numel(), " elements");
- if (bias.numel() > 0) {
- TORCH_CHECK(bias.numel() == wq.size(0),
- __func__, " : Expected bias argument to have ", wq.size(0),
- " elements, got ", bias.numel(), " elements");
- }
-
- // Validate strides of arguments.
- const auto xq_strides = xq.strides();
- TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
- __func__, " : Expected xq argument in row-major layout");
- auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
- for (int i = xq_strides.size() - 3; i >= 0; --i) {
- xq_stride_expected *= xq_sizes[i + 1];
- TORCH_CHECK(xq_strides[i] == xq_stride_expected,
- __func__, " : Expected xq argument in row-major layout");
- }
- TORCH_CHECK(x_scale.is_contiguous(),
- __func__, " : Expected xq scale argument to be contiguous");
- const auto wq_strides = wq.strides();
- TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
- __func__, " : Expected wq argument in row-major layout");
- TORCH_CHECK(w_scale.is_contiguous(),
- __func__, " : Expected wq scale argument to be contiguous");
- if (bias.numel() > 0) {
- const auto bias_strides = bias.strides();
- TORCH_CHECK(bias_strides[0] == 1,
- __func__, " : Expected bias argument to be contiguous");
- }
-}
-#endif
-
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-// result = (xq * x_scale) @ (wq * w_scale).T + bias
-// Notes: The "x_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "xq" tensor. The "w_scale" tensor is
-// expected to be a vector, of size equal to number of rows of "wq"
-// tensor. The "bias" tensor is expected to be a vector, of size equal
-// to number of rows of "wq" tensor.
-at::Tensor
-s8s4_linear_cutlass(
- const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
- const at::Tensor& w_scale, const at::Tensor& bias) {
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
- // Check inputs.
- check_inputs(xq, x_scale, wq, w_scale, bias);
-
- // Squash the input tensors as appropriate.
- const auto xq_sizes = xq.sizes().vec();
- const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
- const auto x_scale_sizes = x_scale.sizes().vec();
- const auto x_scale_1d = x_scale.reshape({-1});
- const auto w_scale_1d = w_scale.reshape({-1});
-
- // Introduce alias names for arguments, according to the CUTLASS
- // naming conventions.
- const auto& tensor_a = xq_2d;
- const auto& tensor_a_scale = x_scale_1d;
- const auto& tensor_b = wq;
- const auto& tensor_b_scale = w_scale_1d;
- const auto& tensor_c = bias;
-
- // Create output tensor.
- at::Tensor tensor_d =
- tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
-
- // Dispatch to appropriate kernel template.
- dispatch_on_tensor_a_scale_and_tensor_b_scale(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-
- // Reshape and return output tensor.
- auto tensor_d_sizes = xq_sizes;
- tensor_d_sizes.back() = wq.size(0);
- return tensor_d.reshape(tensor_d_sizes);
-#else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
- m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass);
-}
-
-} // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h b/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h
deleted file mode 100644
index 991384b..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h
+++ /dev/null
@@ -1,288 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
-#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-
-#define CUTLASS_STATUS_CHECK(status) \
- { \
- TORCH_CHECK(status == cutlass::Status::kSuccess, \
- __func__, " : Got CUTLASS error: ", \
- cutlassGetStatusString(status)); \
- }
-
-namespace torchao {
-
-template<
- typename ThreadblockShape,
- typename WarpShape,
- typename InstructionShape,
- int NumStages,
- typename ThreadblockSwizzle,
- typename ElementA,
- typename ElementB,
- typename ElementAccumulator,
- typename Operator,
- typename ElementAScale,
- typename ElementBScale,
- typename ElementC,
- typename UseTensorC,
- typename ElementOutput>
-void scaled_linear_kernel_cutlass_sm8x(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- using SmArch = cutlass::arch::Sm80;
-
- using LayoutA = cutlass::layout::RowMajor;
- using LayoutB = cutlass::layout::ColumnMajor;
- using LayoutOutput = cutlass::layout::RowMajor;
-
- using ElementEpilogue = float;
- constexpr auto NumEVTEpilogueStages = 1;
-
- const int m = tensor_a.size(0);
- const int n = tensor_b.size(0);
- const int k = std::is_same<ElementA, cutlass::int4b_t>::value ?
- tensor_a.size(1) * 2 :
- tensor_a.size(1);
-
- constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
- constexpr int AlignmentAScale =
- 128 / cutlass::sizeof_bits<ElementAScale>::value;
- constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
- constexpr int AlignmentBScale =
- 128 / cutlass::sizeof_bits<ElementBScale>::value;
- constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
- constexpr int AlignmentOutput =
- 128 / cutlass::sizeof_bits<ElementOutput>::value;
-
- // Check for current CUTLASS limitations w.r.t. alignments.
- TORCH_CHECK(k % AlignmentA == 0,
- __func__, " : Number of columns of tensor A must be divisible ",
- "by ", AlignmentA);
- TORCH_CHECK(k % AlignmentB == 0,
- __func__, " : Number of columns of tensor B must be divisible ",
- "by ", AlignmentB);
- TORCH_CHECK(n % AlignmentC == 0,
- __func__, " : Number of columns of tensor C must be divisible ",
- "by ", AlignmentC);
-
- using TensorAScaleTileThreadMap =
- cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape,
- WarpShape,
- ElementAScale,
- AlignmentAScale,
- NumEVTEpilogueStages>;
- using TensorBScaleTileThreadMap =
- cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape,
- WarpShape,
- ElementBScale,
- AlignmentBScale,
- NumEVTEpilogueStages>;
- using TensorCTileThreadMap =
- cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape,
- WarpShape,
- ElementC,
- AlignmentC,
- NumEVTEpilogueStages>;
- using OutputTileThreadMap =
- cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape,
- WarpShape,
- ElementOutput,
- AlignmentOutput,
- NumEVTEpilogueStages>;
-
- using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
-
- using TensorAScale =
- cutlass::epilogue::threadblock::VisitorColBroadcast<
- TensorAScaleTileThreadMap,
- ElementAScale,
- cute::Stride<cute::_1, cute::_0, int64_t>>;
- using TensorAScaleArguments = typename TensorAScale::Arguments;
-
- using TensorBScale =
- cutlass::epilogue::threadblock::VisitorRowBroadcast<
- TensorBScaleTileThreadMap,
- ElementBScale,
- cute::Stride<cute::_0, cute::_1, int64_t>>;
- using TensorBScaleArguments = typename TensorBScale::Arguments;
-
- using TensorCScalar =
- cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
- using TensorCTensor =
- cutlass::epilogue::threadblock::VisitorRowBroadcast<
- TensorCTileThreadMap,
- ElementC,
- cute::Stride<cute::_0, cute::_1, int64_t>>;
- using TensorC =
- std::conditional_t<UseTensorC::value, TensorCTensor, TensorCScalar>;
- using TensorCArguments = typename TensorC::Arguments;
-
- using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::multiplies, ElementEpilogue, ElementEpilogue,
- cutlass::FloatRoundStyle::round_to_nearest
- >;
- using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
- ApplyAScale,
- Accum,
- TensorAScale>;
-
- using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::multiplies, ElementEpilogue, ElementEpilogue,
- cutlass::FloatRoundStyle::round_to_nearest
- >;
- using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
- ApplyBScale,
- EVTApplyAScale,
- TensorBScale>;
-
- using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::plus, ElementEpilogue, ElementEpilogue,
- cutlass::FloatRoundStyle::round_to_nearest
- >;
- using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
- ApplySum,
- EVTApplyBScale,
- TensorC>;
-
- using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
- OutputTileThreadMap, ElementOutput,
- cutlass::FloatRoundStyle::round_to_nearest,
- cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
- >;
-
- using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
- Output,
- EVTApplySum>;
-
- using EVTKernel =
- typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
- ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
- ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
- ElementOutput, LayoutOutput, AlignmentOutput,
- ElementAccumulator,
- ElementEpilogue,
- cutlass::arch::OpClassTensorOp,
- SmArch,
- ThreadblockShape,
- WarpShape,
- InstructionShape,
- EVTOutput,
- ThreadblockSwizzle,
- NumStages,
- Operator,
- NumEVTEpilogueStages
- >::GemmKernel;
-
- // GemmUniversalBase doesn't work with W4A4
- // using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
- using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
-
- cutlass::gemm::GemmCoord problem_size(m, n, k);
- constexpr auto SplitKFactor = 1;
-
- TensorAScaleArguments tensor_a_scale_arguments{
- (ElementAScale*)tensor_a_scale.data_ptr(),
- ElementAScale(1),
- {cute::_1{}, cute::_0{}, problem_size.m()}
- };
- TensorBScaleArguments tensor_b_scale_arguments{
- (ElementBScale*)tensor_b_scale.data_ptr(),
- ElementBScale(1),
- {cute::_0{}, cute::_1{}, problem_size.n()}
- };
- TensorCArguments tensor_c_arguments{
- [&]() -> TensorCArguments {
- if constexpr (UseTensorC::value) {
- return {(ElementC*)tensor_c.data_ptr(),
- ElementC(0),
- {cute::_0{}, cute::_1{}, problem_size.n()}};
- } else {
- return {ElementC(0)};
- }
- }()
- };
- typename Output::Arguments output_arguments{
- (ElementOutput*)tensor_d.data_ptr(),
- {problem_size.n(), cute::_1{}, problem_size.mn().product()}
- };
- typename EVTOutput::Arguments callback_arguments{
- {
- {
- {
- {}, // Accum
- tensor_a_scale_arguments, // TensorAScale
- {} // ApplyAScale
- }, // EVTApplyAScale
- tensor_b_scale_arguments, // TensorBScale
- {}, // ApplyBScale
- }, // EVTApplyBScale
- tensor_c_arguments, // TensorC
- {} // ApplySum
- }, // EVTApplySum
- output_arguments // Output
- }; // EVTOutput
- // constexpr auto AvailSms = -1;
-
- typename Gemm::Arguments arguments(
- cutlass::gemm::GemmUniversalMode::kGemm,
- problem_size,
- SplitKFactor,
- callback_arguments, // arguments of EVT callbacks
- (ElementA*)tensor_a.data_ptr(),
- (ElementB*)tensor_b.data_ptr(),
- nullptr, // ptr C (unused)
- nullptr, // ptr D (unused)
- problem_size.mk().product(), // batch stride A
- problem_size.nk().product(), // batch stride B
- 0, // batch stride C (unused)
- 0, // batch stride D (unused)
- problem_size.k(), // stride A
- problem_size.k(), // stride B
- 0, // stride C (unused)
- 0
- // , // stride D (unused)
- // AvailSms // GemmUniversalBase requires passing AvailSms, but GemmUniversalAdapter doesn't
- );
-
- Gemm gemm_op;
-
- cutlass::Status status;
-
- // Verify that GEMM operation with given arguments can be performed
- // by CUTLASS.
- status = gemm_op.can_implement(arguments);
- CUTLASS_STATUS_CHECK(status);
-
- // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
- const auto workspace_size = Gemm::get_workspace_size(arguments);
- auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
- at::TensorOptions().dtype(at::kByte));
-
- // Initialize CUTLASS mixed datatypes GEMM object.
- status = gemm_op.initialize(arguments, workspace.data_ptr(),
- at::cuda::getCurrentCUDAStream());
- CUTLASS_STATUS_CHECK(status);
-
- // Perform mixed datatypes GEMM operation.
- status = gemm_op.run(at::cuda::getCurrentCUDAStream());
- CUTLASS_STATUS_CHECK(status);
-
- C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-} // namespace torchao
diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
index d7374c8..bccbf80 100644
--- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
+++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -144,14 +144,14 @@ def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias
def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
- from torchao.ops import s8s4_linear_cutlass
+ from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
weight = weight_tensor.tensor_impl.int_data
weight_scale = weight_tensor.tensor_impl.scale
input = input_tensor.tensor_impl.int_data
input_scale = input_tensor.tensor_impl.scale
- out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias)
+ out = rowwise_scaled_linear_cutlass_s8s4(input, input_scale, weight, weight_scale, bias)
return out
diff --git a/torchao/ops.py b/torchao/ops.py
index 840dbc0..272b358 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -20,11 +20,10 @@ lib.define(
"marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor"
)
lib.define(
- "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
+ "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
)
-lib.define("int4_mm_cutlass(Tensor A, Tensor B) -> Tensor")
lib.define(
- "scaled_int4_mm_cutlass(Tensor A, Tensor B, Tensor row_scale, Tensor col_scale) -> Tensor"
+ "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
)
@@ -518,7 +517,7 @@ def _(
return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device)
-def s8s4_linear_cutlass(
+def rowwise_scaled_linear_cutlass_s8s4(
input: Tensor,
input_scale: Tensor,
weight: Tensor,
@@ -526,23 +525,23 @@ def s8s4_linear_cutlass(
bias: Tensor,
) -> Tensor:
"""
- CUTLASS-based W4A8 linear operator.
+ CUTLASS-based row-wise scaled linear operator.
Args:
- input: input tensor, quantized to 8-bit integer values.
+ input: quantized input tensor, in row-major layout.
input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension.
- weight: weight matrix, quantized to 4-bit integer values, in row-major layout.
+ weight: quantized weight matrix, in row-major layout.
weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension).
bias: a vector of size equal to number of rows of weight tensor, or None.
Returns:
output: result tensor, in row-major layout.
"""
- return torch.ops.torchao.s8s4_linear_cutlass.default(
+ return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default(
input, input_scale, weight, weight_scale, bias
)
-@register_custom_op("torchao::s8s4_linear_cutlass")
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s8s4")
def _(
input: Tensor,
input_scale: Tensor,
@@ -550,6 +549,8 @@ def _(
weight_scale: Tensor,
bias: Tensor,
) -> Tensor:
+ # FIXME: update this!!!
+
# Validate dtypes.
torch._check(
input.dtype == torch.int8,
@@ -621,29 +622,8 @@ def _(
)
-def int4_mm_cutlass(A: Tensor, B: Tensor) -> Tensor:
- """
- CUTLASS-based W4A4 matmul.
- Args:
- A: first INT4 tensor, packed in INT8 dtype, row-major layout.
- B: second INT4 tensor, packed in INT8 dtype, column-major layout.
- Returns:
- output: result tensor, in row-major layout.
- """
- assert A.dtype == B.dtype == torch.int8
- assert A.ndim == B.ndim == 2
- assert A.shape[1] == B.shape[0]
- assert A.is_contiguous() and B.T.is_contiguous()
- return torch.ops.torchao.int4_mm_cutlass.default(A, B)
-
-
-@register_custom_op("torchao::int4_mm_cutlass")
-def _(A: Tensor, B: Tensor) -> Tensor:
- return A.new_empty(A.shape[0], B.shape[1], dtype=torch.int32)
-
-
-def scaled_int4_mm_cutlass(
- A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor
+def rowwise_scaled_linear_cutlass_s4s4(
+ A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor
) -> Tensor:
"""
CUTLASS-based W4A4 scaled-matmul.
@@ -656,15 +636,16 @@ def scaled_int4_mm_cutlass(
output: result tensor, in row-major layout.
"""
assert A.dtype == B.dtype == torch.int8
- assert A.ndim == B.ndim == 2
- assert A.shape[1] == B.shape[0]
- assert A.is_contiguous() and B.T.is_contiguous()
- assert row_scale.ndim == col_scale.ndim == 1
+ assert A.ndim >= 2
+ assert B.ndim == 2
+ assert A.shape[-1] == B.shape[-1]
+ assert A.is_contiguous() and B.is_contiguous()
+ assert row_scale.ndim == col_scale.ndim == 2
assert row_scale.dtype == col_scale.dtype
assert row_scale.dtype in (torch.float16, torch.bfloat16)
- return torch.ops.torchao.scaled_int4_mm_cutlass.default(A, B, row_scale, col_scale)
+ return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default(A, row_scale, B, col_scale, bias)
-@register_custom_op("torchao::scaled_int4_mm_cutlass")
-def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor:
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s4s4")
+def _(A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor) -> Tensor:
    return row_scale.new_empty(A.shape[0], B.shape[1])
I did some renaming, so please take a look and let me know what you think. Regarding your question about the possibility of different types for scales and bias - this pretty much comes for free, and I would be inclined to extend it further so that each of the input/weight scales and/or the bias could be of a different type. On the other side, I have a related question here: looking at your
My TODO list regarding the attached patch:
I'll continue on this tomorrow. |
Thank you for the patch! Will apply and work on it later today.
I don't think there is a standalone real use case for INT4xINT4->INT32 (yet), similarly for INT8xINT8->INT32 i.e.
From my observation, it's a good idea to have input checks for the meta implementation used by torch.compile for shape tracing (otherwise we won't catch errors at compile time, only at runtime) -> we need some kind of input checks in Python anyway. Hence, I usually don't bother with input checks on the C++ (or Triton) side, and put all checks in Python. I place the Python checks in the Python wrapper so that they are shared between the meta and CUDA implementations (actually it's regardless of the implementation) -> technically the op doesn't have input checks, but the Python wrapper does. e.g. ao/torchao/prototype/quantized_training/int8_mm.py Lines 137 to 150 in 5d1444b
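Concretely, the pattern looks roughly like this - a condensed sketch based on the s4s4 wrapper in this PR's ops.py (the real wrapper has more checks; register_custom_op is the torchao helper already used there to register the meta/fake implementation):
import torch
from torch import Tensor
from torchao.ops import register_custom_op

def rowwise_scaled_linear_cutlass_s4s4(
    A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor
) -> Tensor:
    # Checks live in the Python wrapper, so they run for both the real CUDA kernel
    # and the meta implementation that torch.compile uses for shape tracing.
    assert A.dtype == B.dtype == torch.int8
    assert A.ndim >= 2 and B.ndim == 2
    assert A.shape[-1] == B.shape[-1]
    assert row_scale.dtype == col_scale.dtype
    return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default(
        A, row_scale, B, col_scale, bias
    )

@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s4s4")
def _(A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor) -> Tensor:
    # The meta/fake implementation only needs to return a correctly shaped empty tensor.
    return row_scale.new_empty(A.shape[0], B.shape[1])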
That's just my personal view (and I prefer writing Python over C++ any time 😆). |
Found a typo in the patch:
+ int k = tensor_a.size(1);
+ if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
+   k *= 8 % cutlass::sizeof_bits<ElementA>::value;
+ }
It should be (integer) division instead of modulo - the corrected form is spelled out right after this comment. After I fixed that, the tests pass. I also took the liberty to slightly change the test: matmul, scaling, and bias addition are now done in FP32, which should match the numerics of the fused CUTLASS kernel better, especially for the scaling and bias addition parts, since FP16/BF16 matmul is accumulated in FP32 anyway. With this change, we can use
With your refactoring, it should be trivial to add W8A8. Should we add W8A8 also? (We might not add it to AQT since AQT currently uses torch.compile to codegen this path. Though I think inductor only fuses 1 scaling (either input scale or weight scale), and we can get a bit more perf by fusing both input and weight scales.)
Other ideas:
|
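For reference, the corrected computation mentioned above (a sketch - only the operator changes, from modulo to division):
// ElementA narrower than 8 bits (e.g. int4) is packed into 8-bit storage, so the
// logical K dimension is the storage extent times the packing factor.
int k = tensor_a.size(1);
if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
  k *= 8 / cutlass::sizeof_bits<ElementA>::value;
}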
def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3
@alexsamardzic I replaced torchao.utils.benchmark_torch_function_in_microseconds (which is based on torch.utils.benchmark.Timer) with triton's do_bench. This is because I found PyTorch timer is unreliable, possibly because it does not clear L2 cache in between runs.
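For context, here is roughly how one row of the tables below is measured with this helper (a sketch - shapes are illustrative, and the call to the quantized op is elided):
import torch
from triton.testing import do_bench

def benchmark_microseconds(f, *args):
    # Same helper as in the reviewed lines: median latency from do_bench, scaled by 1e3.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3

m, k, n = 1, 8192, 8192
A = torch.randn(m, k, dtype=torch.half, device="cuda")
W = torch.randn(n, k, dtype=torch.half, device="cuda")
fp16_latency = benchmark_microseconds(torch.nn.functional.linear, A, W)
# w4a8_latency = benchmark_microseconds(<quantized linear op>, ...)  # measured the same way
# speedup = fp16_latency / w4a8_latency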
Old (4090, torch.utils.benchmark.Timer)
m | k | n | fp16_latency (ms) | W4A8 latency (ms) | W4A8 speedup (d/s) |
---|---|---|---|---|---|
1 | 8192 | 8192 | 124.368 | 18.6586 | 6.66545 |
1 | 8192 | 10240 | 177.832 | 22.1851 | 8.01584 |
1 | 8192 | 57344 | 982.894 | 298.987 | 3.28742 |
1 | 28672 | 8192 | 493.988 | 152.724 | 3.23452 |
New (4090, triton.testing.do_bench)
m | k | n | fp16_latency (ms) | W4A8 latency (ms) | W4A8 speedup (d/s) |
---|---|---|---|---|---|
1 | 8192 | 8192 | 166.912 | 73.728 | 2.26389 |
1 | 8192 | 10240 | 201.824 | 87.04 | 2.31875 |
1 | 8192 | 57344 | 1005.57 | 346.112 | 2.90533 |
1 | 28672 | 8192 | 519.136 | 199.68 | 2.59984 |
With the old timer, you can see the unusually high speedups in the first two rows. I think it's because W is cached in L2, so those apparent gains disappear once W becomes larger.
Lmk if it's ok to have this change. Thank you!
Sure, that's great!
For anyone else, I break it down as:
- Do you only care about kernel time? Then use: https://github.com/drisspg/transformer_nuggets/blob/0d1cbe6b3d980451f39174be64c98a143b9cebce/transformer_nuggets/utils/benchmark.py#L55. This is what I have found produces the closest results to NCU (roughly the idea sketched below).
- If the overhead does matter and you are more focused on what you would see in an isolated region of code, then the timer is pretty representative: https://github.com/drisspg/transformer_nuggets/blob/0d1cbe6b3d980451f39174be64c98a143b9cebce/transformer_nuggets/utils/benchmark.py#L44
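To make the first bullet concrete, here is a rough, generic sketch of kernel-only timing with an L2 flush between iterations (an illustration of the idea, not the linked transformer_nuggets helpers):
import torch

def bench_kernel_microseconds(f, *args, iters: int = 100) -> float:
    # Zero a buffer larger than L2 before each timed run so operands are evicted,
    # then time only the call itself with CUDA events.
    cache = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")
    times = []
    for _ in range(iters):
        cache.zero_()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        f(*args)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end) * 1e3)  # ms -> us
    times.sort()
    return times[len(times) // 2]  # median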
Thanks @gau-nernst, AQT integration code LGTM |
Would it be possible to apply this patch too: patch.txt? Just learned the hard way that this is how it's supposed to work... Also, I'm wondering whether we need to add some compile/run-time checks against anything that is less than SM80? |
@alexsamardzic Just merged your changes. I think initially I wrote it that way too, but didn't think it would make a big difference. For the runtime check, you already have one in place:
template<typename ElementA, typename ElementB, typename... Types>
static void select_config(
const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
const at::Tensor& tensor_c, at::Tensor& tensor_d) {
const auto dprops = at::cuda::getCurrentDeviceProperties();
const auto is_sm8x = dprops->major == 8;
if (is_sm8x) {
// early return
}
TORCH_CHECK(false,
__func__, " : Operator not supported on SM", dprops->major, ".",
dprops->minor, " for given operands");
}
For compile time, we officially only support >=sm80 (and only build wheels for >=sm80; also, we can't test for <sm80 anyway since CI doesn't have such GPUs):
ao/packaging/env_var_script_linux.sh Lines 14 to 19 in 3eb18e7
However, there have been some efforts to make sure torchao can still be compiled on sm75 e.g. #1147. My understanding is that, iirc:
I'm leaning towards adding
Btw, if it's not obvious, the PR is ready for review. I won't be making any more major changes (unless requested). Happy to receive feedback. |
The difference is when an epilogue tensor is of a different data type from the output tensor - then the initial variant will not work properly. I too was thinking all along that this initial variant, with a tile thread map for each of the tensors used by the epilogue, was the way to go - I just learned the hard way, while debugging my other PR (that extends the row-wise scaled MM implementation in PyTorch), that it isn't, and that there should be only one tile thread map instead. (FWIW, I found that this is the right way while looking into the code generated by the CUTLASS Python library, which always generates a single tile thread map.)
Yes, but I mentioned the compile-time check because I was actually looking yesterday into a build issue on SM75 reported by a user, and noticed that the error message prints
Oh, that's good to know - I was not aware. Still, I've recently encountered a nifty trick in vLLM: check the definition of
Let me know what you think on these two - if you agree, I can make a patch. Other than that, I'm approving the PR (as far as the CUDA code is concerned; I was not looking into the rest). This is awesome, let's have it merged ASAP. |
For the runtime check, either suggestion sounds good! For the compile-time check, iiuc, I will apply your patch when it is ready. You should have authored this PR - all major changes were done by you 😅 |
Here is the patch: patch.txt. |
@drisspg When you are back, can you take another look? Then I think we can merge. Thank you! |
Awesome work! One nit
Closes #1406
Thanks to #880, we now have a CUTLASS (3.6.0) copy in torchao. Adding W4A4 is pretty straightforward, similar to how W4A8 is done.
This is largely copied from my other repo, so I didn't exactly follow @alexsamardzic's style. Requesting a first round of review. Note: this is more for making experiments with W4A4 easier. Personally, I don't think it's too useful at the moment, since W4A4 accuracy is probably quite bad.
TODO:
A100 SXM4 (from vast.ai)
4090 (from vast.ai)
Right now kernel params are not tuned, hence bad performance especially at small batch sizes. Update: added some tuning.