Add CUTLASS-based W4A4 #1515

Merged · 33 commits · Feb 5, 2025
Conversation

@gau-nernst (Collaborator) commented Jan 7, 2025

Closes #1406

Thanks to #880, we now have a CUTLASS (3.6.0) copy in torchao. Adding W4A4 is pretty straightforward, similar to how W4A8 is done. This is largely copied from my other repo, so I didn't exactly follow @alexsamardzic's style. Requesting a first round of review.

  • Update: Thanks to @alexsamardzic's refactoring efforts, now W4A8 and W4A4 share most of the code, with template instantiation for each GEMM type.

Note: this is mainly to make experimenting with W4A4 easier. Personally I don't think it's too useful at the moment, since W4A4 accuracy is probably quite bad.

TODO:

  • Hook up to AQT
  • Add benchmark script and get benchmark results for A100 and 4090
  • (Maybe) Do some tuning + heuristics based on problem size
**A100 SXM4 (from vast.ai)**

| m | k | n | fp16_latency (ms) | W4A8 latency (ms) | W4A8 speedup (d/s) | W4A4 latency (ms) | W4A4 speedup (d/s) |
|---|---|---|---|---|---|---|---|
| 1 | 8192 | 8192 | 110.592 | 51.2 | 2.16 | 38.912 | 2.84211 |
| 1 | 8192 | 10240 | 123.904 | 56.32 | 2.2 | 45.056 | 2.75 |
| 1 | 8192 | 57344 | 569.344 | 206.848 | 2.75248 | 156.672 | 3.63399 |
| 1 | 28672 | 8192 | 299.008 | 107.52 | 2.78095 | 92.16 | 3.24444 |
| 2 | 8192 | 8192 | 108.544 | 50.176 | 2.16327 | 38.912 | 2.78947 |
| 2 | 8192 | 10240 | 129.024 | 57.344 | 2.25 | 45.056 | 2.86364 |
| 2 | 8192 | 57344 | 581.632 | 206.848 | 2.81188 | 157.696 | 3.68831 |
| 2 | 28672 | 8192 | 293.888 | 107.52 | 2.73333 | 92.16 | 3.18889 |
| 4 | 8192 | 8192 | 109.568 | 50.176 | 2.18367 | 38.912 | 2.81579 |
| 4 | 8192 | 10240 | 130.048 | 57.344 | 2.26786 | 45.056 | 2.88636 |
| 4 | 8192 | 57344 | 585.728 | 206.848 | 2.83168 | 158.72 | 3.69032 |
| 4 | 28672 | 8192 | 294.912 | 107.52 | 2.74286 | 92.16 | 3.2 |
| 8 | 8192 | 8192 | 109.568 | 50.176 | 2.18367 | 38.912 | 2.81579 |
| 8 | 8192 | 10240 | 131.072 | 57.344 | 2.28571 | 45.056 | 2.90909 |
| 8 | 8192 | 57344 | 591.872 | 206.848 | 2.86139 | 159.744 | 3.70513 |
| 8 | 28672 | 8192 | 294.912 | 107.52 | 2.74286 | 92.16 | 3.2 |
| 16 | 8192 | 8192 | 109.568 | 50.176 | 2.18367 | 38.912 | 2.81579 |
| 16 | 8192 | 10240 | 133.12 | 57.344 | 2.32143 | 46.08 | 2.88889 |
| 16 | 8192 | 57344 | 601.088 | 206.848 | 2.90594 | 161.792 | 3.71519 |
| 16 | 28672 | 8192 | 295.936 | 108.544 | 2.72642 | 92.16 | 3.21111 |
| 32 | 8192 | 8192 | 109.568 | 52.224 | 2.09804 | 39.936 | 2.74359 |
| 32 | 8192 | 10240 | 126.976 | 58.368 | 2.17544 | 47.104 | 2.69565 |
| 32 | 8192 | 57344 | 602.112 | 214.016 | 2.8134 | 166.912 | 3.60736 |
| 32 | 28672 | 8192 | 306.176 | 112.64 | 2.71818 | 95.232 | 3.21505 |
| 64 | 8192 | 8192 | 109.568 | 55.296 | 1.98148 | 41.984 | 2.60976 |
| 64 | 8192 | 10240 | 124.928 | 63.488 | 1.96774 | 49.152 | 2.54167 |
| 64 | 8192 | 57344 | 590.848 | 234.496 | 2.51965 | 179.2 | 3.29714 |
| 64 | 28672 | 8192 | 313.344 | 126.976 | 2.46774 | 98.304 | 3.1875 |
| 128 | 8192 | 8192 | 129.024 | 82.944 | 1.55556 | 56.32 | 2.29091 |
| 128 | 8192 | 10240 | 140.288 | 88.064 | 1.59302 | 57.344 | 2.44643 |
| 128 | 8192 | 57344 | 694.272 | 387.072 | 1.79365 | 209.92 | 3.30732 |
| 128 | 28672 | 8192 | 328.704 | 214.016 | 1.53589 | 137.216 | 2.39552 |
| 256 | 8192 | 8192 | 164.864 | 118.784 | 1.38793 | 65.536 | 2.51563 |
| 256 | 8192 | 10240 | 207.872 | 147.456 | 1.40972 | 69.632 | 2.98529 |
| 256 | 8192 | 57344 | 988.16 | 736.256 | 1.34214 | 301.056 | 3.28231 |
| 256 | 28672 | 8192 | 465.92 | 346.112 | 1.34615 | 188.416 | 2.47283 |
| 512 | 8192 | 8192 | 303.104 | 228.352 | 1.32735 | 119.808 | 2.52991 |
| 512 | 8192 | 10240 | 330.752 | 271.36 | 1.21887 | 119.808 | 2.76068 |
| 512 | 8192 | 57344 | 1839.1 | 1422.34 | 1.29302 | 512 | 3.592 |
| 512 | 28672 | 8192 | 912.384 | 720.896 | 1.26562 | 362.496 | 2.51695 |
**4090 (from vast.ai)**

| m | k | n | fp16_latency (ms) | W4A8 latency (ms) | W4A8 speedup (d/s) | W4A4 latency (ms) | W4A4 speedup (d/s) |
|---|---|---|---|---|---|---|---|
| 1 | 8192 | 8192 | 166.912 | 73.728 | 2.26389 | 55.296 | 3.01852 |
| 1 | 8192 | 10240 | 201.824 | 87.04 | 2.31875 | 67.584 | 2.98627 |
| 1 | 8192 | 57344 | 1005.57 | 346.112 | 2.90533 | 315.392 | 3.18831 |
| 1 | 28672 | 8192 | 519.136 | 199.68 | 2.59984 | 175.104 | 2.96473 |
| 2 | 8192 | 8192 | 205.824 | 73.728 | 2.79167 | 55.296 | 3.72222 |
| 2 | 8192 | 10240 | 243.712 | 87.04 | 2.8 | 67.584 | 3.60606 |
| 2 | 8192 | 57344 | 1067.01 | 346.112 | 3.08284 | 315.392 | 3.38312 |
| 2 | 28672 | 8192 | 555.008 | 200.704 | 2.76531 | 175.104 | 3.16959 |
| 4 | 8192 | 8192 | 206.72 | 73.728 | 2.80382 | 55.296 | 3.73843 |
| 4 | 8192 | 10240 | 243.712 | 87.04 | 2.8 | 67.584 | 3.60606 |
| 4 | 8192 | 57344 | 1069.18 | 347.136 | 3.08001 | 316.416 | 3.37905 |
| 4 | 28672 | 8192 | 556.032 | 200.704 | 2.77041 | 175.104 | 3.17544 |
| 8 | 8192 | 8192 | 206.848 | 73.728 | 2.80556 | 55.296 | 3.74074 |
| 8 | 8192 | 10240 | 243.712 | 87.168 | 2.79589 | 67.584 | 3.60606 |
| 8 | 8192 | 57344 | 1073.15 | 348.16 | 3.08235 | 317.44 | 3.38064 |
| 8 | 28672 | 8192 | 556.032 | 200.704 | 2.77041 | 175.104 | 3.17544 |
| 16 | 8192 | 8192 | 206.848 | 73.728 | 2.80556 | 55.296 | 3.74074 |
| 16 | 8192 | 10240 | 243.712 | 87.072 | 2.79897 | 68.32 | 3.56721 |
| 16 | 8192 | 57344 | 1079.3 | 349.184 | 3.09091 | 318.464 | 3.38907 |
| 16 | 28672 | 8192 | 559.104 | 200.704 | 2.78571 | 175.104 | 3.19298 |
| 32 | 8192 | 8192 | 207.872 | 74.752 | 2.78082 | 56.32 | 3.69091 |
| 32 | 8192 | 10240 | 263.168 | 90.048 | 2.92253 | 68.512 | 3.8412 |
| 32 | 8192 | 57344 | 1147.9 | 352.256 | 3.25872 | 319.488 | 3.59295 |
| 32 | 28672 | 8192 | 570.368 | 202.944 | 2.81047 | 178.176 | 3.20115 |
| 64 | 8192 | 8192 | 209.92 | 76.8 | 2.73333 | 57.344 | 3.66071 |
| 64 | 8192 | 10240 | 260.096 | 93.184 | 2.79121 | 69.632 | 3.73529 |
| 64 | 8192 | 57344 | 1210.37 | 362.336 | 3.34046 | 324.608 | 3.72871 |
| 64 | 28672 | 8192 | 603.136 | 207.872 | 2.90148 | 180.224 | 3.34659 |
| 128 | 8192 | 8192 | 230.4 | 82.944 | 2.77778 | 58.368 | 3.94737 |
| 128 | 8192 | 10240 | 280.576 | 98.304 | 2.85417 | 74.752 | 3.75342 |
| 128 | 8192 | 57344 | 1269.76 | 373.76 | 3.39726 | 330.752 | 3.83901 |
| 128 | 28672 | 8192 | 651.264 | 225.056 | 2.89379 | 183.04 | 3.55804 |
| 256 | 8192 | 8192 | 278.528 | 109.568 | 2.54206 | 66.56 | 4.18462 |
| 256 | 8192 | 10240 | 336.896 | 137.216 | 2.45522 | 81.92 | 4.1125 |
| 256 | 8192 | 57344 | 1600.51 | 609.28 | 2.62689 | 362.496 | 4.41525 |
| 256 | 28672 | 8192 | 816.128 | 308.224 | 2.64784 | 211.776 | 3.85373 |
| 512 | 8192 | 8192 | 432.128 | 191.488 | 2.25668 | 71.68 | 6.02857 |
| 512 | 8192 | 10240 | 549.888 | 234.496 | 2.34498 | 121.856 | 4.51261 |
| 512 | 8192 | 57344 | 2992.13 | 1170.46 | 2.55636 | 441.344 | 6.77958 |
| 512 | 28672 | 8192 | 1542.14 | 565.248 | 2.72826 | 223.456 | 6.90133 |

Right now the kernel params are not tuned, hence the poor performance, especially at small batch sizes. Update: added some tuning.
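For reference, a rough sketch of how the op ends up being exercised (it mirrors the int4 packing used in the tests added later in this thread; the op name `rowwise_scaled_linear_cutlass_s4s4` is the one used after the refactoring discussed below, and the random int8 inputs merely stand in for real quantized values):

```python
import torch
from torchao.ops import rowwise_scaled_linear_cutlass_s4s4

def pack_int4(x_s8: torch.Tensor) -> torch.Tensor:
    # Pack two signed int4 values (stored one per int8) into a single int8:
    # odd-indexed elements go to the high nibble, even-indexed to the low nibble.
    return ((x_s8[..., 1::2] & 0xF) << 4) | (x_s8[..., 0::2] & 0xF)

M, N, K = 32, 8192, 8192
# Stand-ins for the outputs of a symmetric row-wise int4 quantizer.
x_s8 = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")
w_s8 = torch.randint(-8, 8, (N, K), dtype=torch.int8, device="cuda")
x_scale = torch.rand(M, dtype=torch.bfloat16, device="cuda")
w_scale = torch.rand(N, dtype=torch.bfloat16, device="cuda")
bias = None

# Output is (M, N) in the scale dtype.
out = rowwise_scaled_linear_cutlass_s4s4(
    pack_int4(x_s8), x_scale, pack_int4(w_s8), w_scale, bias
)
```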

pytorch-bot (bot) commented Jan 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1515

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 4613503 with merge base 7e54629.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jan 7, 2025
@gau-nernst added the topic: new feature label on Jan 7, 2025
@alexsamardzic (Collaborator)

The CUDA code looks fine; of course, there are still lots of dots to connect on the Python side.

The difference from #880 is that this is not a mixed-data-types GEMM, but a regular GEMM. In that regard, this operator may be easier to make much more generic, to support other integer and maybe even some floating-point input data types. I'm currently making some minor changes to this PyTorch operator, and would strongly recommend modelling the CUDA code in a similar way, as it simply looks nice, makes extending the kernel to other data types much easier, has extensive checks on operands, etc. Moreover, I think it would make sense at this point to discuss having a single CUTLASS-based kernel for GEMMs with both weights and activations scaled, placed in a single source file and handling both same- and mixed-data-type GEMMs, at least for SM 8.x archs - that would provide for minimum code duplication and easier maintenance in the future.

As far as configurations (tile sizes, number of stages, etc.) are concerned, I'd suggest looking here (in the unit tests) instead, and also comparing performance against the results reported by the CUTLASS profiler for the given combination of data types. I believe some sort of tuning of the configuration based on the input shapes is a must in order to achieve decent performance; but I have to admit that in #880 the tuning is mostly ad hoc (for comparison, I find this approach more elaborate and meaningful). Thus, I think that coming up with some kind of systematic approach in that regard would be the most beneficial contribution regarding eventual future use of CUTLASS-based kernels in torchao.
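(For illustration only: the kind of ad-hoc, shape-based heuristic currently used for the W4A8 kernel, as it appears in the refactored select_config() later in this thread, transcribed to Python. The tile shapes and stage counts are the ones hard-coded there; the helper name is hypothetical.)

```python
def select_w4a8_config(m: int) -> dict:
    # Mirrors the m-based thresholds hard-coded in select_config() for the
    # s8/s4 case: threadblock shape, warp shape, and number of pipeline stages.
    if m <= 16:
        return {"threadblock": (16, 128, 128), "warp": (16, 32, 128), "stages": 6}
    if m <= 32:
        return {"threadblock": (32, 128, 128), "warp": (32, 32, 128), "stages": 5}
    return {"threadblock": (64, 128, 128), "warp": (64, 32, 128), "stages": 4}
```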

(@drisspg: Your comments welcome here.)

@drisspg (Contributor) commented Jan 7, 2025

One thing on finding optimal params: @yifuwang was recently working on finding better configs for an AsyncMM. He did some manual elimination of configs that never seemed to be performant, and then fit a simple decision tree on a big sweep over MKN shapes that could be easily modeled in C++. This is similar to what is done in the row-wise scaling. I think a little flow for this would be helpful; I can make an issue to track it.
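(A minimal sketch of that kind of flow, for illustration only: the sweep data and config labels are hypothetical, nothing here is part of this PR. The idea is to fit a shallow decision tree over (m, k, n) and transcribe its thresholds into C++.)

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical sweep results: rows of (m, k, n, index_of_fastest_config),
# e.g. collected with the benchmark script from this PR.
sweep_results = np.array([
    [1,    8192,  8192, 0],
    [16,   8192,  8192, 0],
    [64,   8192, 57344, 1],
    [128,  8192,  8192, 1],
    [512, 28672,  8192, 2],
])
X, y = sweep_results[:, :3], sweep_results[:, 3]

# A shallow tree keeps the learned thresholds easy to port to C++ if/else chains.
tree = DecisionTreeClassifier(max_depth=3).fit(X, y)
print(export_text(tree, feature_names=["m", "k", "n"]))
```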

No major comments

@gau-nernst (Collaborator, Author)

Thank you for the feedback.

> In that regard, this operator may be easier to make much more generic, to support other integer and maybe even some floating-point input data types.

Though this is nice on paper, I think Triton is the better alternative for other data types (INT8, FP8, ...): it's more flexible, and the autotuner also saves us some headache. It's only because of the lack of INT4 support in Triton that we have to use CUTLASS, especially for INT4 tensor cores - unless we can show that there are cases where Triton cannot reach the perf of CUTLASS (in the context of this PR, I'm only thinking about INT8 for SM8x, and additionally FP8 for SM89).

Having said that, I'm ok with following a certain style/structure. Just point me to which one it should be, and I will make the modifications accordingly.

torchao/ops.py Outdated
Returns:
output: result tensor, in row-major layout.
"""
assert A.dtype == B.dtype == torch.int8
Contributor

We should add the alignment constraints as well, right?

@gau-nernst (Collaborator, Author) commented Jan 8, 2025

How should I check for data alignment from Python? I guess in C++, I can check by testing divisibility of the memory address? (or perhaps there is a util function somewhere that I'm not aware of...)
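(For what it's worth, one possible way to express such checks from Python - just a sketch with a hypothetical helper name; `Tensor.data_ptr()` exposes the storage address, and the K constraint matches the 128-bit/32-element alignment discussed below.)

```python
import torch

def _check_int4_mm_inputs(A: torch.Tensor, B: torch.Tensor) -> None:
    # A and B hold two int4 values per int8 byte, so K = 2 * A.shape[1].
    assert A.dtype == B.dtype == torch.int8
    K = A.shape[1] * 2
    # CUTLASS uses 128-bit vectorized loads: 32 int4 elements, i.e. 16 packed bytes.
    assert K % 32 == 0, f"K={K} must be a multiple of 32"
    # The underlying storage should also start on a 16-byte boundary;
    # divisibility of the address returned by data_ptr() can be tested directly.
    assert A.data_ptr() % 16 == 0 and B.data_ptr() % 16 == 0
```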

Contributor

Hmm, I think there is a restriction that k needs to be a multiple of 32, right?

Contributor

Or at least 16 packed int4s.

using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
// static int const kStages = 3;
using ElementC = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
Contributor

Do you know if the universal GEMM API can be used?

Collaborator (Author)

Will look into it. I wrote this quite some time ago...

@alexsamardzic (Collaborator)

@gau-nernst

Attached is a minor patch that changes the s8s4_linear_cutlass operator to do W4A4. I did it in a quick-and-dirty way, so maybe I made some kind of error (the tests run, but won't pass), but the point is that the main differences are in: checking arguments (I just commented these out), dispatching depending on input data types (I just changed the input data type in this patch instead of having full dispatching), and selecting the configuration (for W4A4 I just put in the same configuration that you use in your patch). But the CUTLASS boilerplate code is completely the same, except for using OpMultiplyAddMixedInputUpcast as the operator for the mixed-input-data-types case. So my point is that, for maintenance reasons, I think we'd better have a single CUTLASS-based kernel for W4A8 and W4A4 (and all alike) instead of duplicating lots of code.

The structure of CUTLASS-based kernels is typically always the same (see also rowwise scaled MM in PyTorch, mentioned in my previous comment, as well as my CUTLASS-based mixed-data-types and 2:4 sparsity kernels in PyTorch): from the bottom up, there is always an operator implementation function that checks the inputs and then starts a dispatching chain (where run-time data types etc. are translated to compile-time template arguments), which ends in a typical CUTLASS-based GEMM kernel (that is boilerplate). Also as mentioned in my previous comment, while rowwise scaled MM is very similar in structure, I like how it looks the most - because of the clever use of variadic template arguments to decrease the clutter, and because of the clear extraction of input checks and configuration selection into separate functions, etc. So I'd suggest we have your C++ code integrated in the way sketched by the attached diff, and then also make minor changes to the C++ code so that it looks closer to the rowwise scaled MM implementation. (Of course, the operator name and some other stuff on the Python side will have to be changed too.)

diff.txt

@alexsamardzic (Collaborator) commented Jan 8, 2025

As far as performance between the various implementations is concerned: I'd say that in general there are three ways to implement kernels: Triton-based, CUTLASS-based, and custom, i.e. from scratch (like Marlin-based kernels). In my experience so far (all on the Ampere arch), CUTLASS-based kernels are oftentimes somewhat faster than Triton-based kernels, while for some corner-case input tensor sizes, custom kernels (well, Marlin-based at least) could be significantly faster than CUTLASS-based ones. Furthermore, with Triton there is the least amount of flexibility regarding upstream changes (they just don't support some input data types, they don't support 2:4 sparsity, etc.), with CUTLASS it's somewhat easier to get the changes we may need accepted, while for custom kernels this is obviously not an issue at all. However, Triton kills it when it comes to compilation, in particular regarding fusing GEMM with other kernels; CUTLASS has some support for compilation, but doing fusion is rather cumbersome at the moment, while there is obviously no compilation support at all for custom kernels. Then, doing custom kernels would probably lead to lots of code duplication; with CUTLASS this may also be an issue, even if to a smaller extent. Etc. - so it's all a matter of trade-offs. Still, having in mind auto-tuning and auto-quantization, I believe it may still be good to have as many different kernels in torchao as possible, so I'd expect more CUTLASS-based kernels to be written besides these W4A8 and W4A4 kernels - and this is the exact reason that, as discussed above, I'd prefer to have as much code as possible shared between these kernels.

@supriyar (Contributor) commented Jan 9, 2025

> since W4A4 accuracy is probably quite bad.

Might be interesting to try out QAT with this setting. cc @andrewor14

@alexsamardzic (Collaborator)

> The structure of CUTLASS-based kernels is typically always the same [...] So I'd suggest we have your C++ code integrated in the way sketched by the attached diff, and then also make minor changes to the C++ code so that it looks closer to the rowwise scaled MM implementation.

I've made these changes to the existing CUTLASS-based W4A8 kernel in #1545, so it should now be easier to eventually include the W4A4 functionality there.

@gau-nernst (Collaborator, Author)

@alexsamardzic Thank you for your prompt feedback. I was going to work on this a little bit more before asking for another round of review.

I wasn't sure how much code-sharing you intended to have. I was trying to work with a single file, but the biggest blocker is this one:

  if (tensor_a.scalar_type() == at::ScalarType::Char) {
    if (tensor_b.scalar_type() == at::ScalarType::Char) {
      if (tensor_a.size(1) == 2 * tensor_b.size(1)) {

In dispatch_on_tensor_a_and_tensor_b(), you rely on tensor_a.size(1) == 2 * tensor_b.size(1) to detect W4A8. It means that there is no meaningful way to distinguish W4A4 from W8A8, unless there is an extra template argument, which kind of defeats the purpose of dispatch_on_tensor_a_and_tensor_b(), since we could then just call select_config() directly. I'd appreciate your advice on this.
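(To illustrate the ambiguity - a toy sketch, not code from this PR: since PyTorch has no sub-byte dtype, a packed-int4 operand is indistinguishable from a genuine int8 operand by dtype and shape alone.)

```python
import torch

k = 128
# W8A8: genuine int8 activations and weights.
a_s8 = torch.empty(16, k, dtype=torch.int8)
b_s8 = torch.empty(256, k, dtype=torch.int8)
# W4A4: int4 values packed two per byte, stored as int8 with k // 2 columns.
a_s4 = torch.empty(16, k // 2, dtype=torch.int8)
b_s4 = torch.empty(256, k // 2, dtype=torch.int8)

# tensor_a.size(1) == 2 * tensor_b.size(1) identifies W4A8, but:
assert a_s8.size(1) == b_s8.size(1)  # W8A8
assert a_s4.size(1) == b_s4.size(1)  # W4A4 looks exactly the same
# -> dtypes and shapes alone cannot tell these apart; the operator name
#    (or an extra argument) has to carry that information.
```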

That's why I decided to move the inner-most template to a separate header file for sharing. I wanted to move select_config() too, but since it's about hard-coding kernel params for a particular GEMM, I felt it should stay in the GEMM-specific file.

I have another question. Currently you have dispatch_on_tensor_a_scale_and_tensor_b_scale() and dispatch_on_tensor_c(), which seems to suggest that you allow bias dtype != tensor scale dtype, though currently only FP16 and BF16 permutations are supported. I don't know if BF16 tensor scales + FP16 bias is useful/necessary (and vice versa for FP16 tensor scales and BF16 bias). Furthermore, torch._scaled_mm() uses FP32 scales and allows specifying the output dtype. I'm not exactly sure if FP32 scales are necessary, but just pointing out my observation here (we might want to support FP32 scales also? Then perhaps it makes sense to support bias dtype != tensor scale dtype, where scale dtype = FP32 and bias dtype = FP16/BF16).

check_inputs() will also require some way to tell that either A or B (or both) are int4, if you want to keep it in a single file.

Side comment: I don't mind code duplication. Realistically speaking, there are not that many useful combinations, so code duplication is not too bad. And as you probably already know, with the CUTLASS device-level API not all combinations work (and when they don't, the errors are mysterious 😅).

@alexsamardzic (Collaborator)

I realized, while making changes, that your point about dispatching on input/weight tensor data types is right. Namely, PyTorch doesn't provide sub-byte data types, so the only way to differentiate between data type combinations like S4/S4 and S8/S8 is to have this information encoded some other way. Your approach was to encode it in the name of the operator, and this kind of approach could also solve the issue that I was worried about - that not all templates get instantiated from the single source file.

I made some further changes to maximize C++ code reuse. I don't mind code duplication either, but in this case it's (in my opinion) really cumbersome boilerplate code that, for sanity, I'd much prefer to keep in a single place. So I came up with the following patch, to be applied on top of the current state of this branch:

Patch
diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
similarity index 73%
rename from benchmarks/benchmark_s8s4_cutlass.py
rename to benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
index fbf07eb..00bcb0a 100644
--- a/benchmarks/benchmark_s8s4_cutlass.py
+++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -2,7 +2,7 @@ import pandas as pd
 import torch
 from tqdm import tqdm
 
-from torchao.ops import s8s4_linear_cutlass
+from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
 from torchao.utils import benchmark_torch_function_in_microseconds
 
 
@@ -24,8 +24,8 @@ def benchmark(m: int, k: int, n: int):
     A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)
 
     fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
-    s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
-        s8s4_linear_cutlass, A, A_scale, B, B_scale, C
+    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_torch_function_in_microseconds(
+        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
     )
 
     return {
@@ -33,8 +33,8 @@ def benchmark(m: int, k: int, n: int):
         "k": k,
         "n": n,
         "fp16_latency (ms)": fp16_time,
-        "s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time,
-        "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time,
+        "rowwise_scaled_linear_cutlass latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time,
+        "speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
     }
 
 
@@ -48,5 +48,5 @@ if __name__ == "__main__":
             results.append(benchmark(m, k, n))
 
     df = pd.DataFrame(results)
-    df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
+    df.to_csv("rowwise_scaled_linear_cutlass_s8s4_time_results.csv", index=False)
     print(df.to_markdown(index=False))
diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py
new file mode 100644
index 0000000..422b100
--- /dev/null
+++ b/test/test_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,128 @@
+import itertools
+
+import pytest
+import torch
+
+from torchao.ops import (
+    rowwise_scaled_linear_cutlass_s4s4,
+    rowwise_scaled_linear_cutlass_s8s4,
+)
+from torchao.quantization.utils import group_quantize_tensor_symmetric
+from torchao.utils import compute_max_diff
+
+ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
+ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
+ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
+    (2, 512, 128),
+    (3, 2048, 2048),
+    (4, 3584, 640),
+    (13, 8704, 8576),
+    (26, 18944, 1664),
+    (67, 6656, 1408),
+]
+ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
+ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
+    itertools.product(
+        ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
+        ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
+        ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
+        ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
+    )
+)
+
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
+    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
+)
+def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
+    size_m, size_n, size_k = size_mnk
+
+    input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+    weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+    input_2d = input.view(-1, input.shape[-1])
+    input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
+        input_2d, 4, size_k, dtype
+    )
+    assert torch.all(input_2d_zeros == 0)
+    input_s8 = input_2d_s8.reshape(input.shape)
+    input_s4 = ((input_s8[:, :, 1::2] & 0xF) << 4) | (input_s8[:, :, 0::2] & 0xF)
+    input_scales = input_2d_scales.reshape(input.shape[:-1])
+
+    weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
+        weight, 4, size_n, dtype
+    )
+    assert torch.all(weight_zeros == 0)
+    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
+
+    # If torch.nn.functional.linear(input, weight, bias) used as
+    # reference, the error would be too big.  The calculation below is
+    # approximately what rowwise_scaled_linear_cutlass kernel is doing
+    # (except that matrix multiplication is over integers there)).
+    size_m_2d = input_2d.shape[0]
+    output_ref = (
+        (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
+        * input_2d_scales.view(size_m_2d, 1)
+        * weight_scales.view(1, size_n)
+    )
+    if bias is not None:
+        output_ref += bias
+    output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
+
+    fn_inputs = (input_s4, input_scales, weight_s4, weight_scales, bias)
+    try:
+        output = rowwise_scaled_linear_cutlass_s4s4(*fn_inputs)
+    except NotImplementedError:
+        pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented")
+
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 1e-3
+
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
+    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
+)
+def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
+    size_m, size_n, size_k = size_mnk
+
+    input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+    weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+    input_2d = input.view(-1, input.shape[-1])
+    input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
+        input_2d, 8, size_k, dtype
+    )
+    assert torch.all(input_2d_zeros == 0)
+    input_s8 = input_2d_s8.reshape(input.shape)
+    input_scales = input_2d_scales.reshape(input.shape[:-1])
+
+    weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
+        weight, 4, size_n, dtype
+    )
+    assert torch.all(weight_zeros == 0)
+    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
+
+    # If torch.nn.functional.linear(input, weight, bias) used as
+    # reference, the error would be too big.  The calculation below is
+    # approximately what rowwise_scaled_linear_cutlass kernel is doing
+    # (except that matrix multiplication is over integers there)).
+    size_m_2d = input_2d.shape[0]
+    output_ref = (
+        (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
+        * input_2d_scales.view(size_m_2d, 1)
+        * weight_scales.view(1, size_n)
+    )
+    if bias is not None:
+        output_ref += bias
+    output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
+
+    fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
+    try:
+        output = rowwise_scaled_linear_cutlass_s8s4(*fn_inputs)
+    except NotImplementedError:
+        pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented")
+
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 5e-3
diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py
deleted file mode 100644
index 6510ada..0000000
--- a/test/test_s8s4_linear_cutlass.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import itertools
-
-import pytest
-import torch
-
-from torchao.ops import s8s4_linear_cutlass
-from torchao.quantization.utils import group_quantize_tensor_symmetric
-from torchao.utils import compute_max_diff
-
-S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
-S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
-S8S4_LINEAR_CUTLASS_SIZE_MNK = [
-    (2, 512, 128),
-    (3, 2048, 2048),
-    (4, 3584, 640),
-    (13, 8704, 8576),
-    (26, 18944, 1664),
-    (67, 6656, 1408),
-]
-S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
-S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
-    itertools.product(
-        S8S4_LINEAR_CUTLASS_DTYPE,
-        S8S4_LINEAR_CUTLASS_BATCH_SIZE,
-        S8S4_LINEAR_CUTLASS_SIZE_MNK,
-        S8S4_LINEAR_CUTLASS_USE_BIAS,
-    )
-)
-
-
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
-    "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
-)
-def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
-    size_m, size_n, size_k = size_mnk
-
-    input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
-    weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
-    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
-
-    input_2d = input.view(-1, input.shape[-1])
-    input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
-        input_2d, 8, size_k, dtype
-    )
-    assert torch.all(input_2d_zeros == 0)
-    input_s8 = input_2d_s8.reshape(input.shape)
-    input_scales = input_2d_scales.reshape(input.shape[:-1])
-
-    weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
-        weight, 4, size_n, dtype
-    )
-    assert torch.all(weight_zeros == 0)
-    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
-
-    # If torch.nn.functional.linear(input, weight, bias) used as
-    # reference, the error would be too big.  The calculation below is
-    # approximately what s8s4_linear_cutlass kernel is doing (except
-    # that matrrix multiplication is over integers there)).
-    size_m_2d = input_2d.shape[0]
-    output_ref = (
-        (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
-        * input_2d_scales.view(size_m_2d, 1)
-        * weight_scales.view(1, size_n)
-    )
-    if bias is not None:
-        output_ref += bias
-    output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
-
-    fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
-    try:
-        output = s8s4_linear_cutlass(*fn_inputs)
-    except NotImplementedError:
-        pytest.xfail("s8s4_linear_cutlass() op not implemented")
-
-    max_diff = compute_max_diff(output, output_ref)
-    assert max_diff < 5e-3
diff --git a/torchao/csrc/cuda/int4_cutlass.cu b/torchao/csrc/cuda/int4_cutlass.cu
deleted file mode 100644
index 452abcc..0000000
--- a/torchao/csrc/cuda/int4_cutlass.cu
+++ /dev/null
@@ -1,231 +0,0 @@
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-
-// copied from s8s4_linear_cutlass.cu
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) &&                   \
-    defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_INT4_MM_CUTLASS
-#endif
-
-#if defined(BUILD_INT4_MM_CUTLASS)
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/device/gemm_universal.h"
-#include "cutlass/gemm/device/gemm.h"
-#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
-#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-
-#define CUTLASS_STATUS_CHECK(status)                                      \
-  {                                                                       \
-    TORCH_CHECK(status == cutlass::Status::kSuccess,                      \
-                __func__, " : Got CUTLASS error: ",                       \
-                cutlassGetStatusString(status));                          \
-  }
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_INT4_MM_CUTLASS)
-// define common params
-using ElementA           = cutlass::int4b_t;
-using ElementB           = cutlass::int4b_t;
-using ElementAccumulator = int32_t;
-using OpClass            = cutlass::arch::OpClassTensorOp;
-using ArchTag            = cutlass::arch::Sm80;
-
-// how many elements to load at a time -> load 128-bit = 32 x 4-bit
-constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
-constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
-#endif
-
-// we will do input checks in python. A and B are stored as int8
-torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) {
-#if defined(BUILD_INT4_MM_CUTLASS)
-  int M = A.size(0);
-  int K = A.size(1) * 2;
-  int N = B.size(1);
-  torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32));
-
-  // some configs for int4 mma
-  // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
-  // using default config. this can be tuned.
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
-  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 128>;
-  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
-  // static int const kStages = 3;
-  using ElementC = int32_t;
-  using Gemm = cutlass::gemm::device::Gemm<
-    ElementA, cutlass::layout::RowMajor,    // A matrix
-    ElementB, cutlass::layout::ColumnMajor, // B matrix
-    ElementC, cutlass::layout::RowMajor,    // C matrix
-    ElementAccumulator, OpClass, ArchTag,
-    ThreadblockShape, WarpShape, InstructionShape
-  >;
-  Gemm::Arguments args {
-    {M, N, K},
-    {reinterpret_cast<ElementA *>(A.data_ptr<int8_t>()), K},
-    {reinterpret_cast<ElementB *>(B.data_ptr<int8_t>()), K},
-    {C.data_ptr<ElementC>(), N},
-    {C.data_ptr<ElementC>(), N},
-    {1, 0}  // epilogue
-  };
-  Gemm gemm_op;
-  CUTLASS_STATUS_CHECK(gemm_op(args));
-  return C;
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
-template<
-  typename ElementC,
-  typename ThreadblockShape,
-  typename WarpShape,
-  typename InstructionShape,
-  int numStages>
-void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) {
-  // problem shape
-  int M = A.size(0);
-  int K = A.size(1) * 2;
-  int N = B.size(1);
-
-  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // 8 for BF16/FP16
-  using ElementEpilogue = float;
-  constexpr int numEpilogueStages = 1;
-
-  // build epilogue visitor tree
-  using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
-    ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages
-  >;
-
-  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
-  constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest;
-  using Multiply = cutlass::epilogue::threadblock::VisitorCompute<
-    cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode
-  >;
-
-  // (1, N)
-  using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
-    OutputTileThreadMap, ElementC,
-    cute::Stride<cute::_0, cute::_1, int32_t>  // MNL
-  >;
-  using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, Accum, ColScale>;
-
-  // (M, 1)
-  using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
-    OutputTileThreadMap, ElementC,
-    cute::Stride<cute::_1, cute::_0, int32_t>  // MNL
-  >;
-  using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, EVTCompute0, RowScale>;
-
-  using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
-    OutputTileThreadMap, ElementC, RoundMode,
-    cute::Stride<int64_t, cute::_1, int64_t>  // MNL
-  >;
-  using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<Output, EVTCompute1>;
-
-  using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-    ElementA, cutlass::layout::RowMajor,    cutlass::ComplexTransform::kNone, AlignmentA,
-    ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB,
-    ElementC, cutlass::layout::RowMajor,                                      AlignmentC,
-    ElementAccumulator, ElementEpilogue, OpClass, ArchTag,
-    ThreadblockShape, WarpShape, InstructionShape,
-    EVTOutput,
-    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
-    numStages,
-    cutlass::arch::OpMultiplyAddSaturate,  // OpMultiplyAdd does not work
-    numEpilogueStages
-  >::GemmKernel;
-  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
-
-  // col_scale, row_scale, and C must have the same dtype
-  const ElementA *A_ptr         = reinterpret_cast<ElementA *>(A.data_ptr<int8_t>());
-  const ElementB *B_ptr         = reinterpret_cast<ElementB *>(B.data_ptr<int8_t>());
-  const ElementC *col_scale_ptr = reinterpret_cast<ElementC *>(col_scale.data_ptr());
-  const ElementC *row_scale_ptr = reinterpret_cast<ElementC *>(row_scale.data_ptr());
-  ElementC *C_ptr               = reinterpret_cast<ElementC *>(C.data_ptr());
-
-  typename EVTOutput::Arguments callback_args{
-    {
-      {
-        {},                                                                  // Accum
-        {col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}},  // ColScale
-        {}                                                                   // Multiply
-      },                                                                     // EVTCompute0
-      {row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}},    // RowScale
-      {}                                                                     // Multiply
-    },                                                                       // EVTCompute1
-    {C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}}                          // EVTOutput
-  };
-
-  typename DeviceGemm::Arguments args(
-    cutlass::gemm::GemmUniversalMode::kGemm,
-    cutlass::gemm::GemmCoord{M, N, K},
-    1,                              // batch_split
-    callback_args,
-    A_ptr, B_ptr, nullptr, nullptr, // unsued C_ptr and D_ptr
-    M * K, N * K, 0, 0,             // batch_stride A, B, C, D
-    K, K, 0, 0                      // stride A, B, C, D
-  );
-
-  DeviceGemm gemm_op;
-  auto stream = at::cuda::getCurrentCUDAStream();
-  CUTLASS_STATUS_CHECK(gemm_op.can_implement(args));
-  CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream));
-}
-
-// we will do input checks in python. A and B are stored as int8
-// this function is based on the following cutlass example
-// https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu
-// also with the help of emitted code from cutlass Python
-torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) {
-#if defined(BUILD_INT4_MM_CUTLASS)
-  int M = A.size(0);
-  int N = B.size(1);
-  torch::Tensor C = torch::empty({M, N}, row_scale.options());
-
-  // some configs for int4 mma
-  // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
-  // using default config. this can be tuned.
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
-  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 128>;
-  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
-  constexpr int numStages = 3;
-
-  AT_DISPATCH_SWITCH(
-    row_scale.scalar_type(),
-    "scaled_int4_mm_cutlass",
-    AT_DISPATCH_CASE(
-      torch::ScalarType::Half,
-      [&]() {
-        using ElementC = cutlass::half_t;
-        scaled_int4_mm_cutlass_dispatch<
-          ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
-            A, B, row_scale, col_scale, C);
-      }
-    )
-    AT_DISPATCH_CASE(
-      torch::ScalarType::BFloat16,
-      [&]() {
-        using ElementC = cutlass::bfloat16_t;
-        scaled_int4_mm_cutlass_dispatch<
-          ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
-            A, B, row_scale, col_scale, C);
-      }
-    )
-  );
-
-  return C;
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass);
-  m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass);
-}
-
-}  // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh
new file mode 100644
index 0000000..d1969bc
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh
@@ -0,0 +1,578 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <c10/util/Exception.h>
+
+#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) &&                   \
+    defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
+#define BUILD_ROWWISE_SCALED_LINEAR_CUTLASS
+#endif
+
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+#include <cuda_runtime.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm_universal.h>
+#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
+#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
+#include <cutlass/gemm/device/gemm_universal_adapter.h>
+
+#define CUTLASS_STATUS_CHECK(status)                                      \
+  {                                                                       \
+    TORCH_CHECK(status == cutlass::Status::kSuccess,                      \
+                __func__, " : Got CUTLASS error: ",                       \
+                cutlassGetStatusString(status));                          \
+  }
+#endif
+
+namespace torchao {
+
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+template<
+    typename ThreadblockShape,
+    typename WarpShape,
+    typename InstructionShape,
+    typename ThreadblockSwizzle,
+    int NumStages,
+    typename ElementA,
+    typename ElementB,
+    typename ElementOutput,
+    typename ElementC,
+    typename UseTensorC,
+    typename ElementAScale,
+    typename ElementBScale>
+void rowwise_scaled_linear_kernel_cutlass_sm8x(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  static_assert((cutlass::sizeof_bits<ElementA>::value >= 8 ||
+                 8 % cutlass::sizeof_bits<ElementA>::value == 0) &&
+                (cutlass::sizeof_bits<ElementB>::value >= 8 ||
+                 8 % cutlass::sizeof_bits<ElementB>::value == 0));
+
+  using SmArch = cutlass::arch::Sm80;
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutOutput = cutlass::layout::RowMajor;
+
+  using ElementAccumulator = int32_t;
+  using Operator =
+      std::conditional_t<std::is_same<ElementA, ElementB>::value,
+                         cutlass::arch::OpMultiplyAddSaturate,
+                         cutlass::arch::OpMultiplyAddMixedInputUpcast>;
+
+  using ElementEpilogue = float;
+
+  constexpr auto NumEVTEpilogueStages = 1;
+
+  const int m = tensor_a.size(0);
+  const int n = tensor_b.size(0);
+  int k = tensor_a.size(1);
+  if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
+    k *= 8 % cutlass::sizeof_bits<ElementA>::value;
+  }
+
+  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+  constexpr int AlignmentAScale =
+      128 / cutlass::sizeof_bits<ElementAScale>::value;
+  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+  constexpr int AlignmentBScale =
+      128 / cutlass::sizeof_bits<ElementBScale>::value;
+  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  constexpr int AlignmentOutput =
+      128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+  // Check for current CUTLASS limitations w.r.t. alignments.
+  TORCH_CHECK(k % AlignmentA == 0,
+              __func__, " : Number of columns of tensor A must be divisible ",
+              "by ", AlignmentA);
+  TORCH_CHECK(k % AlignmentB == 0,
+              __func__, " : Number of columns of tensor B must be divisible ",
+              "by ", AlignmentB);
+  TORCH_CHECK(n % AlignmentC == 0,
+              __func__, " : Number of columns of tensor C must be divisible ",
+              "by ", AlignmentC);
+
+  using TensorAScaleTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          ThreadblockShape,
+          WarpShape,
+          ElementAScale,
+          AlignmentAScale,
+          NumEVTEpilogueStages>;
+  using TensorBScaleTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          ThreadblockShape,
+          WarpShape,
+          ElementBScale,
+          AlignmentBScale,
+          NumEVTEpilogueStages>;
+  using TensorCTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          ThreadblockShape,
+          WarpShape,
+          ElementC,
+          AlignmentC,
+          NumEVTEpilogueStages>;
+  using OutputTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          ThreadblockShape,
+          WarpShape,
+          ElementOutput,
+          AlignmentOutput,
+          NumEVTEpilogueStages>;
+
+  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+  using TensorAScale =
+      cutlass::epilogue::threadblock::VisitorColBroadcast<
+          TensorAScaleTileThreadMap,
+          ElementAScale,
+          cute::Stride<cute::_1, cute::_0, int64_t>>;
+  using TensorAScaleArguments = typename TensorAScale::Arguments;
+
+  using TensorBScale =
+      cutlass::epilogue::threadblock::VisitorRowBroadcast<
+          TensorBScaleTileThreadMap,
+          ElementBScale,
+          cute::Stride<cute::_0, cute::_1, int64_t>>;
+  using TensorBScaleArguments = typename TensorBScale::Arguments;
+
+  using TensorCScalar =
+      cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
+  using TensorCTensor =
+      cutlass::epilogue::threadblock::VisitorRowBroadcast<
+          TensorCTileThreadMap,
+          ElementC,
+          cute::Stride<cute::_0, cute::_1, int64_t>>;
+  using TensorC =
+      std::conditional_t<UseTensorC::value, TensorCTensor, TensorCScalar>;
+  using TensorCArguments = typename TensorC::Arguments;
+
+  using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplyAScale,
+      Accum,
+      TensorAScale>;
+
+  using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplyBScale,
+      EVTApplyAScale,
+      TensorBScale>;
+
+  using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::plus, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplySum,
+      EVTApplyBScale,
+      TensorC>;
+
+  using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
+      OutputTileThreadMap, ElementOutput,
+      cutlass::FloatRoundStyle::round_to_nearest,
+      cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
+  >;
+
+  using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
+      Output,
+      EVTApplySum>;
+
+  using EVTKernel =
+      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+      ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
+      ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
+      ElementOutput, LayoutOutput, AlignmentOutput,
+      ElementAccumulator,
+      ElementEpilogue,
+      cutlass::arch::OpClassTensorOp,
+      SmArch,
+      ThreadblockShape,
+      WarpShape,
+      InstructionShape,
+      EVTOutput,
+      ThreadblockSwizzle,
+      NumStages,
+      Operator,
+      NumEVTEpilogueStages
+  >::GemmKernel;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
+
+  cutlass::gemm::GemmCoord problem_size(m, n, k);
+  constexpr auto SplitKFactor = 1;
+
+  TensorAScaleArguments tensor_a_scale_arguments{
+    (ElementAScale*)tensor_a_scale.data_ptr(),
+    ElementAScale(1),
+    {cute::_1{}, cute::_0{}, problem_size.m()}
+  };
+  TensorBScaleArguments tensor_b_scale_arguments{
+    (ElementBScale*)tensor_b_scale.data_ptr(),
+    ElementBScale(1),
+    {cute::_0{}, cute::_1{}, problem_size.n()}
+  };
+  TensorCArguments tensor_c_arguments{
+    [&]() -> TensorCArguments {
+      if constexpr (UseTensorC::value) {
+        return {(ElementC*)tensor_c.data_ptr(),
+                ElementC(0),
+                {cute::_0{}, cute::_1{}, problem_size.n()}};
+      } else {
+        return {ElementC(0)};
+      }
+    }()
+  };
+  typename Output::Arguments output_arguments{
+    (ElementOutput*)tensor_d.data_ptr(),
+    {problem_size.n(), cute::_1{}, problem_size.mn().product()}
+  };
+  typename EVTOutput::Arguments callback_arguments{
+    {
+      {
+        {
+          {},                        // Accum
+          tensor_a_scale_arguments,  // TensorAScale
+          {}                         // ApplyAScale
+        },                           // EVTApplyAScale
+        tensor_b_scale_arguments,    // TensorBScale
+        {},                          // ApplyBScale
+      },                             // EVTApplyBScale
+      tensor_c_arguments,            // TensorC
+      {}                             // ApplySum
+    },                               // EVTApplySum
+    output_arguments                 // Output
+  };                                 // EVTOutput
+
+  typename Gemm::Arguments arguments(
+    cutlass::gemm::GemmUniversalMode::kGemm,
+    problem_size,
+    SplitKFactor,
+    callback_arguments,              // arguments of EVT callbacks
+    (ElementA*)tensor_a.data_ptr(),
+    (ElementB*)tensor_b.data_ptr(),
+    nullptr,                         // ptr C (unused)
+    nullptr,                         // ptr D (unused)
+    problem_size.mk().product(),     // batch stride A
+    problem_size.nk().product(),     // batch stride B
+    0,                               // batch stride C (unused)
+    0,                               // batch stride D (unused)
+    problem_size.k(),                // stride A
+    problem_size.k(),                // stride B
+    0,                               // stride C (unused)
+    0                                // stride D (unused)
+  );
+
+  Gemm gemm_op;
+
+  cutlass::Status status;
+
+  // Verify that GEMM operation with given arguments can be performed
+  // by CUTLASS.
+  status = gemm_op.can_implement(arguments);
+  CUTLASS_STATUS_CHECK(status);
+
+  // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
+  const auto workspace_size = Gemm::get_workspace_size(arguments);
+  auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
+                                      at::TensorOptions().dtype(at::kByte));
+
+  // Initialize CUTLASS mixed datatypes GEMM object.
+  status = gemm_op.initialize(arguments, workspace.data_ptr(),
+                              at::cuda::getCurrentCUDAStream());
+  CUTLASS_STATUS_CHECK(status);
+
+  // Perform mixed datatypes GEMM operation.
+  status = gemm_op.run(at::cuda::getCurrentCUDAStream());
+  CUTLASS_STATUS_CHECK(status);
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename ElementA, typename ElementB, typename... Types>
+static void select_config(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) { 
+  const auto dprops = at::cuda::getCurrentDeviceProperties();
+  const auto is_sm8x = dprops->major == 8;
+
+  if (is_sm8x) {
+    if constexpr (std::is_same<ElementA, cutlass::int4b_t>::value &&
+                  std::is_same<ElementB, cutlass::int4b_t>::value) {
+      using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
+      using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
+      using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
+      using ThreadblockSwizzle =
+          cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
+      constexpr auto NumStages = 3;
+      rowwise_scaled_linear_kernel_cutlass_sm8x<
+        ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+          NumStages, ElementA, ElementB, Types...>(
+            tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+            tensor_d);
+      return;
+    } else if constexpr (std::is_same<ElementA, int8_t>::value &&
+                  std::is_same<ElementB, cutlass::int4b_t>::value) {
+      using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+      using ThreadblockSwizzle =
+        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
+
+      // A minimal heuristic to improve performance for small number
+      // of inputs cases.
+      if (tensor_a.size(0) <= 16) {
+        using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
+        constexpr auto NumStages = 6;
+        rowwise_scaled_linear_kernel_cutlass_sm8x<
+          ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+          NumStages, ElementA, ElementB, Types...>(
+              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+              tensor_d);
+      } else if (tensor_a.size(0) <= 32) {
+        using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
+        constexpr auto NumStages = 5;
+        rowwise_scaled_linear_kernel_cutlass_sm8x<
+          ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+          NumStages, ElementA, ElementB, Types...>(
+              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+              tensor_d);
+      } else {
+        using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
+        constexpr auto NumStages = 4;
+        rowwise_scaled_linear_kernel_cutlass_sm8x<
+          ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+          NumStages, ElementA, ElementB, Types...>(
+              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+              tensor_d);
+      }
+      return;
+    }
+  }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported on SM", dprops->major, ".",
+              dprops->minor, " for given operands");
+}
+
+template<
+   typename ElementA,
+   typename ElementB,
+   typename ElementOutput,
+   typename... Types>
+static void
+dispatch_on_tensor_c(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  if (tensor_c.numel() == 0) {
+    using ElementC = ElementOutput;
+    using UseTensorC = std::false_type;
+    select_config<
+      ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+          tensor_d);
+    return;
+  }
+
+  using UseTensorC = std::true_type;
+  if (tensor_c.scalar_type() == at::ScalarType::Half) {
+    using ElementC = cutlass::half_t;
+    select_config<
+      ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+          tensor_d);
+    return;
+  } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
+    using ElementC = cutlass::bfloat16_t;
+    select_config<
+      ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+          tensor_d);
+    return;
+  }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported for datatype ",
+                tensor_c.scalar_type(), " for addend");
+}
+
+template<typename ElementA, typename ElementB>
+static void
+dispatch_on_tensor_a_scale_and_tensor_b_scale(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
+              __func__, " : Operator not supported for output datatype ",
+              tensor_d.scalar_type(), " as it's different from the first ",
+              " operand scale datatype ", tensor_a_scale.scalar_type());
+
+  if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
+      tensor_b_scale.scalar_type() == at::ScalarType::Half) {
+    using ElementAScale = cutlass::half_t;
+    using ElementBScale = cutlass::half_t;
+    using ElementOutput = cutlass::half_t;
+    dispatch_on_tensor_c<ElementA, ElementB, ElementOutput, ElementAScale,
+      ElementBScale>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+    return;
+  } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
+             tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
+    using ElementAScale = cutlass::bfloat16_t;
+    using ElementBScale = cutlass::bfloat16_t;
+    using ElementOutput = cutlass::bfloat16_t;
+    dispatch_on_tensor_c<ElementA, ElementB, ElementOutput, ElementAScale,
+      ElementBScale>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+    return;
+  }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported for combination of data ",
+              "types ", tensor_a_scale.scalar_type(),
+              " for first operand scale and ", tensor_b_scale.scalar_type(),
+                " for second operand scale");
+}
+
+template<typename ElementA, typename ElementB>
+void
+rowwise_scaled_linear_cutlass_check_inputs(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+  // Validate layouts of arguments.
+  TORCH_CHECK(xq.dim() >= 2,
+              __func__, " : Expected xq argument to be 2D or "
+              "higher-dimensional tensor, got ", xq.dim(), " dims");
+  TORCH_CHECK(xq.layout() == at::Layout::Strided,
+              __func__, " : Expected xq argument to be strided, got layout ",
+              xq.layout());
+  TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
+              __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
+              "D tensor, got ", x_scale.dim(), " dims");
+  TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
+              __func__, " : Expected xq scale argument to be strided, got "
+              "layout ", x_scale.layout());
+  TORCH_CHECK(wq.dim() == 2,
+              __func__, " : Expected wq argument to be 2D tensor, got ",
+              wq.dim(), " dims");
+  TORCH_CHECK(wq.layout() == at::Layout::Strided,
+              __func__, " : Expected wq argument to be strided, got layout ",
+              wq.layout());
+  TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
+              __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
+              "got ", w_scale.dim(), " dims");
+  TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
+              __func__, " : Expected wq scale argument to be strided, got "
+              "layout ", w_scale.layout());
+  if (bias.numel() > 0) {
+    TORCH_CHECK(bias.dim() == 1,
+                __func__, " : Expected bias argument to be 1D tensor, got ",
+                bias.dim(), " dims");
+    TORCH_CHECK(bias.layout() == at::Layout::Strided,
+                __func__, " : Expected bias argument to be strided, got ",
+                "layout ", bias.layout());
+  }
+
+  // Validate sizes of arguments.
+  const auto xq_sizes = xq.sizes().vec();
+  TORCH_CHECK(xq_sizes.back() == wq.size(1) ||
+              xq_sizes.back() == 2 * wq.size(1),
+              __func__, " : Expected xq argument to have ", wq.size(1), " or ",
+              2 * wq.size(1), " columns, but got ", xq_sizes.back());
+  const auto x_scale_sizes = x_scale.sizes().vec();
+  for (auto i = 0; i < x_scale_sizes.size(); ++i)
+    TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
+                __func__, " : Expected xq scale argument size at position ",
+                i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
+  TORCH_CHECK(w_scale.numel() == wq.size(0),
+              __func__, " : Expected wq scale argument to have ", wq.size(0),
+              " elements, got ", w_scale.numel(), " elements");
+  if (bias.numel() > 0) {
+    TORCH_CHECK(bias.numel() == wq.size(0),
+                __func__, " : Expected bias argument to have ", wq.size(0),
+                " elements, got ", bias.numel(), " elements");
+  }
+
+  // Validate strides of arguments.
+  const auto xq_strides = xq.strides();
+  TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
+              __func__, " : Expected xq argument in row-major layout");
+  auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
+  for (int i = xq_strides.size() - 3; i >= 0; --i) {
+    xq_stride_expected *= xq_sizes[i + 1];
+    TORCH_CHECK(xq_strides[i] == xq_stride_expected,
+                __func__, " : Expected xq argument in row-major layout");
+  }
+  TORCH_CHECK(x_scale.is_contiguous(),
+              __func__, " : Expected xq scale argument to be contiguous");
+  const auto wq_strides = wq.strides();
+  TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
+              __func__, " : Expected wq argument in row-major layout");
+  TORCH_CHECK(w_scale.is_contiguous(),
+              __func__, " : Expected wq scale argument to be contiguous");
+  if (bias.numel() > 0) {
+    const auto bias_strides = bias.strides();
+    TORCH_CHECK(bias_strides[0] == 1,
+                __func__, " : Expected bias argument to be contiguous");
+  }
+}
+#endif
+
+// Perform linear operation, using corresponding CUTLASS datatypes
+// GEMM kernel, to given arguments - result produced is:
+// (tensor_a * tensor_a_scale) @ (tensor_b * tensor_b_scale).T + tensor_c
+//
+// Notes: The "tensor_a" and "tensor_b" are expected to be 2D tensors.
+// The "tensor_a_scale" tensor is expected to be a vector, of size
+// equal to number of rows of "tensor_a" tensor.  The "tensor_b_scale"
+// tensor is expected to be a vector, of size equal to number of rows
+// of "tensor_b" tensor. The "tensor_c" tensor is expected to be a
+// vector, of size equal to number of rows of "tensor_b" tensor.
+template <typename ElementA, typename ElementB>
+at::Tensor
+rowwise_scaled_linear_cutlass(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+  // Check inputs.
+  rowwise_scaled_linear_cutlass_check_inputs<ElementA, ElementB>(
+      xq, x_scale, wq, w_scale, bias);
+
+  // Squash the input tensors as appropriate.
+  const auto xq_sizes = xq.sizes().vec();
+  const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
+  const auto x_scale_1d = x_scale.reshape({-1});
+  const auto w_scale_1d = w_scale.reshape({-1});
+
+  // Create result tensor.
+  at::Tensor result =
+      x_scale.new_empty({xq_2d.size(0), wq.size(0)});
+
+  // Dispatch to appropriate kernel template.
+  dispatch_on_tensor_a_scale_and_tensor_b_scale<ElementA, ElementB>(
+      xq_2d, x_scale_1d, wq, w_scale_1d, bias, result);
+
+  // Reshape and return result tensor.
+  auto result_sizes = xq_sizes;
+  result_sizes.back() = wq.size(0);
+  return result.reshape(result_sizes);
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
+}  // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
new file mode 100644
index 0000000..9a64b2b
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#include "rowwise_scaled_linear_cutlass.cuh"
+
+namespace torchao {
+
+at::Tensor
+rowwise_scaled_linear_cutlass_s4s4(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+  // Validate input datatypes.
+  TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar,
+              __func__, " : The input datatypes combination ", xq.dtype(),
+              " for xq and ", wq.dtype(), " for wq is not supported");
+
+  // Dispatch to appropriate kernel template.
+  using ElementA = cutlass::int4b_t;
+  using ElementB = cutlass::int4b_t;
+  return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+      xq, x_scale, wq, w_scale, bias);
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::rowwise_scaled_linear_cutlass_s4s4",
+         &rowwise_scaled_linear_cutlass_s4s4);
+}
+
+}  // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
new file mode 100644
index 0000000..752c557
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#include "rowwise_scaled_linear_cutlass.cuh"
+
+namespace torchao {
+
+at::Tensor
+rowwise_scaled_linear_cutlass_s8s4(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+  // Validate input datatypes.
+  TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar,
+              __func__, " : The input datatypes combination ", xq.dtype(),
+              " for xq and ", wq.dtype(), " for wq is not supported");
+
+  // Dispatch to appropriate kernel template.
+  using ElementA = int8_t;
+  using ElementB = cutlass::int4b_t;
+  return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+      xq, x_scale, wq, w_scale, bias);
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::rowwise_scaled_linear_cutlass_s8s4",
+         &rowwise_scaled_linear_cutlass_s8s4);
+}
+
+}  // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu
deleted file mode 100644
index 8faf13c..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu
+++ /dev/null
@@ -1,268 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) &&                   \
-    defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_S4S4_LINEAR_CUTLASS
-#endif
-
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-#include "scaled_linear.h"
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-
-template<typename... Types>
-static void select_config(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  const auto dprops = at::cuda::getCurrentDeviceProperties();
-  const auto is_sm8x = dprops->major >= 8;
-
-  if (is_sm8x) {
-    using ElementA = cutlass::int4b_t;
-    using ElementB = cutlass::int4b_t;
-    using ElementAccumulator = int32_t;
-
-    using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
-    using WarpShape        = cutlass::gemm::GemmShape<64, 64, 128>;
-    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
-    constexpr auto NumStages = 3;
-    using Operator = cutlass::arch::OpMultiplyAddSaturate;
-    // using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast;  // this does not work
-    using ThreadblockSwizzle =
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
-
-    scaled_linear_kernel_cutlass_sm8x<
-      ThreadblockShape, WarpShape, InstructionShape, NumStages,
-      ThreadblockSwizzle, ElementA, ElementB, ElementAccumulator, Operator,
-      Types...>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported on SM", dprops->major, ".",
-              dprops->minor, " for given operands");
-}
-
-template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
-static void
-dispatch_on_tensor_c(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  if (tensor_c.numel() == 0) {
-    using ElementC = ElementOutput;
-    using UseTensorC = std::false_type;
-    select_config<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  using UseTensorC = std::true_type;
-  if (tensor_c.scalar_type() == at::ScalarType::Half) {
-    using ElementC = cutlass::half_t;
-    select_config<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
-    using ElementC = cutlass::bfloat16_t;
-    select_config<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for datatype ",
-                tensor_c.scalar_type(), " for addend");
-}
-
-static void
-dispatch_on_tensor_a_scale_and_tensor_b_scale(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
-              __func__, " : Operator not supported for output datatype ",
-              tensor_d.scalar_type(), " as it's different from the first ",
-              " operand scale datatype ", tensor_a_scale.scalar_type());
-
-  if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
-      tensor_b_scale.scalar_type() == at::ScalarType::Half) {
-    using ElementAScale = cutlass::half_t;
-    using ElementBScale = cutlass::half_t;
-    using ElementOutput = cutlass::half_t;
-    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-    return;
-  } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
-             tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
-    using ElementAScale = cutlass::bfloat16_t;
-    using ElementBScale = cutlass::bfloat16_t;
-    using ElementOutput = cutlass::bfloat16_t;
-    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for combination of data ",
-              "types ", tensor_a_scale.scalar_type(),
-              " for first operand scale and ", tensor_b_scale.scalar_type(),
-                " for second operand scale");
-}
-
-static void
-check_inputs(
-    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
-    const at::Tensor& w_scale, const at::Tensor& bias) {
-  // Validate layouts of arguments.
-  TORCH_CHECK(xq.dim() >= 2,
-              __func__, " : Expected xq argument to be 2D or "
-              "higher-dimensional tensor, got ", xq.dim(), " dims");
-  TORCH_CHECK(xq.layout() == at::Layout::Strided,
-              __func__, " : Expected xq argument to be strided, got layout ",
-              xq.layout());
-  TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
-              __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
-              "D tensor, got ", x_scale.dim(), " dims");
-  TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected xq scale argument to be strided, got "
-              "layout ", x_scale.layout());
-  TORCH_CHECK(wq.dim() == 2,
-              __func__, " : Expected wq argument to be 2D tensor, got ",
-              wq.dim(), " dims");
-  TORCH_CHECK(wq.layout() == at::Layout::Strided,
-              __func__, " : Expected wq argument to be strided, got layout ",
-              wq.layout());
-  TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
-              __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
-              "got ", w_scale.dim(), " dims");
-  TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected wq scale argument to be strided, got "
-              "layout ", w_scale.layout());
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.dim() == 1,
-                __func__, " : Expected bias argument to be 1D tensor, got ",
-                bias.dim(), " dims");
-    TORCH_CHECK(bias.layout() == at::Layout::Strided,
-                __func__, " : Expected bias argument to be strided, got ",
-                "layout ", bias.layout());
-  }
-
-  // Validate sizes of arguments.
-  const auto xq_sizes = xq.sizes().vec();
-  TORCH_CHECK(xq_sizes.back() == wq.size(1),
-              __func__, " : Expected xq argument to have ", wq.size(1),
-              " columns, but got ", xq_sizes.back());
-  const auto x_scale_sizes = x_scale.sizes().vec();
-  for (auto i = 0; i < x_scale_sizes.size(); ++i)
-    TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
-                __func__, " : Expected xq scale argument size at position ",
-                i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
-  TORCH_CHECK(w_scale.numel() == wq.size(0),
-              __func__, " : Expected wq scale argument to have ", wq.size(0),
-              " elements, got ", w_scale.numel(), " elements");
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.numel() == wq.size(0),
-                __func__, " : Expected bias argument to have ", wq.size(0),
-                " elements, got ", bias.numel(), " elements");
-  }
-
-  // Validate strides of arguments.
-  const auto xq_strides = xq.strides();
-  TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
-              __func__, " : Expected xq argument in row-major layout");
-  auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
-  for (int i = xq_strides.size() - 3; i >= 0; --i) {
-    xq_stride_expected *= xq_sizes[i + 1];
-    TORCH_CHECK(xq_strides[i] == xq_stride_expected,
-                __func__, " : Expected xq argument in row-major layout");
-  }
-  TORCH_CHECK(x_scale.is_contiguous(),
-              __func__, " : Expected xq scale argument to be contiguous");
-  const auto wq_strides = wq.strides();
-  TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
-              __func__, " : Expected wq argument in row-major layout");
-  TORCH_CHECK(w_scale.is_contiguous(),
-              __func__, " : Expected wq scale argument to be contiguous");
-  if (bias.numel() > 0) {
-    const auto bias_strides = bias.strides();
-    TORCH_CHECK(bias_strides[0] == 1,
-                __func__, " : Expected bias argument to be contiguous");
-  }
-}
-#endif
-
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-//   result = (xq * x_scale) @ (wq * w_scale).T + bias
-// Notes: The "x_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "xq" tensor.  The "w_scale" tensor is
-// expected to be a vector, of size equal to number of rows of "wq"
-// tensor. The "bias" tensor is expected to be a vector, of size equal
-// to number of rows of "wq" tensor.
-at::Tensor
-s4s4_linear_cutlass(
-    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
-    const at::Tensor& w_scale, const at::Tensor& bias) {
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-  // Check inputs.
-  check_inputs(xq, x_scale, wq, w_scale, bias);
-
-  // Squash the input tensors as appropriate.
-  const auto xq_sizes = xq.sizes().vec();
-  const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
-  const auto x_scale_sizes = x_scale.sizes().vec();
-  const auto x_scale_1d = x_scale.reshape({-1});
-  const auto w_scale_1d = w_scale.reshape({-1});
-
-  // Introduce alias names for arguments, according to the CUTLASS
-  // naming conventions.
-  const auto& tensor_a = xq_2d;
-  const auto& tensor_a_scale = x_scale_1d;
-  const auto& tensor_b = wq;
-  const auto& tensor_b_scale = w_scale_1d;
-  const auto& tensor_c = bias;
-
-  // Create output tensor.
-  at::Tensor tensor_d =
-      tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
-
-  // Dispatch to appropriate kernel template.
-  dispatch_on_tensor_a_scale_and_tensor_b_scale(
-      tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-
-  // Reshape and return output tensor.
-  auto tensor_d_sizes = xq_sizes;
-  tensor_d_sizes.back() = wq.size(0);
-  return tensor_d.reshape(tensor_d_sizes);
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::s4s4_linear_cutlass", &s4s4_linear_cutlass);
-}
-
-}  // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
deleted file mode 100644
index 53eaf53..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
+++ /dev/null
@@ -1,315 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) &&                   \
-    defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_S8S4_LINEAR_CUTLASS
-#endif
-
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-#include "scaled_linear.h"
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-
-template<typename ElementA, typename ElementB, typename... Types>
-static void select_config(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  const auto dprops = at::cuda::getCurrentDeviceProperties();
-  const auto is_sm8x = dprops->major == 8;
-
-  if (is_sm8x) {
-    if constexpr (std::is_same<ElementA, int8_t>::value &&
-                  std::is_same<ElementB, cutlass::int4b_t>::value) {
-      using ThreadblockSwizzle =
-        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
-      using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-
-      // A minimal heuristic to improve performance for small number
-      // of inputs cases.
-      if (tensor_a.size(0) <= 16) {
-        using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
-        using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
-        constexpr auto NumStages = 6;
-        scaled_linear_kernel_cutlass_sm8x<
-          ThreadblockShape, WarpShape, InstructionShape, NumStages,
-          ThreadblockSwizzle, ElementA, ElementB, Types...>(
-              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-              tensor_d);
-      } else if (tensor_a.size(0) <= 32) {
-        using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
-        using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
-        constexpr auto NumStages = 5;
-        scaled_linear_kernel_cutlass_sm8x<
-          ThreadblockShape, WarpShape, InstructionShape, NumStages,
-          ThreadblockSwizzle, ElementA, ElementB, Types...>(
-              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-              tensor_d);
-      } else {
-        using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
-        using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
-        constexpr auto NumStages = 4;
-        scaled_linear_kernel_cutlass_sm8x<
-          ThreadblockShape, WarpShape, InstructionShape, NumStages,
-          ThreadblockSwizzle, ElementA, ElementB, Types...>(
-              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-              tensor_d);
-      }
-      return;
-    }
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported on SM", dprops->major, ".",
-              dprops->minor, " for given operands");
-}
-
-template<typename... Types>
-static void
-dispatch_on_tensor_a_and_tensor_b(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  if (tensor_a.scalar_type() == at::ScalarType::Char) {
-    if (tensor_b.scalar_type() == at::ScalarType::Char) {
-      if (tensor_a.size(1) == 2 * tensor_b.size(1)) {
-        using ElementA = int8_t;
-        using ElementB = cutlass::int4b_t;
-        using ElementAccumulator = int32_t;
-        using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast;
-        select_config<
-          ElementA, ElementB, ElementAccumulator, Operator, Types...>(
-            tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-            tensor_d);
-      }
-      return;
-    }
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for combination of data ",
-              "types ", tensor_a.scalar_type(), " for first operand and ",
-              tensor_b.scalar_type(), " for second operand");
-}
-
-
-template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
-static void
-dispatch_on_tensor_c(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  if (tensor_c.numel() == 0) {
-    using ElementC = ElementOutput;
-    using UseTensorC = std::false_type;
-    dispatch_on_tensor_a_and_tensor_b<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  using UseTensorC = std::true_type;
-  if (tensor_c.scalar_type() == at::ScalarType::Half) {
-    using ElementC = cutlass::half_t;
-    dispatch_on_tensor_a_and_tensor_b<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
-    using ElementC = cutlass::bfloat16_t;
-    dispatch_on_tensor_a_and_tensor_b<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for datatype ",
-                tensor_c.scalar_type(), " for addend");
-}
-
-static void
-dispatch_on_tensor_a_scale_and_tensor_b_scale(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
-              __func__, " : Operator not supported for output datatype ",
-              tensor_d.scalar_type(), " as it's different from the first ",
-              " operand scale datatype ", tensor_a_scale.scalar_type());
-
-  if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
-      tensor_b_scale.scalar_type() == at::ScalarType::Half) {
-    using ElementAScale = cutlass::half_t;
-    using ElementBScale = cutlass::half_t;
-    using ElementOutput = cutlass::half_t;
-    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-    return;
-  } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
-             tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
-    using ElementAScale = cutlass::bfloat16_t;
-    using ElementBScale = cutlass::bfloat16_t;
-    using ElementOutput = cutlass::bfloat16_t;
-    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for combination of data ",
-              "types ", tensor_a_scale.scalar_type(),
-              " for first operand scale and ", tensor_b_scale.scalar_type(),
-                " for second operand scale");
-}
-
-static void
-check_inputs(
-    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
-    const at::Tensor& w_scale, const at::Tensor& bias) {
-  // Validate layouts of arguments.
-  TORCH_CHECK(xq.dim() >= 2,
-              __func__, " : Expected xq argument to be 2D or "
-              "higher-dimensional tensor, got ", xq.dim(), " dims");
-  TORCH_CHECK(xq.layout() == at::Layout::Strided,
-              __func__, " : Expected xq argument to be strided, got layout ",
-              xq.layout());
-  TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
-              __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
-              "D tensor, got ", x_scale.dim(), " dims");
-  TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected xq scale argument to be strided, got "
-              "layout ", x_scale.layout());
-  TORCH_CHECK(wq.dim() == 2,
-              __func__, " : Expected wq argument to be 2D tensor, got ",
-              wq.dim(), " dims");
-  TORCH_CHECK(wq.layout() == at::Layout::Strided,
-              __func__, " : Expected wq argument to be strided, got layout ",
-              wq.layout());
-  TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
-              __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
-              "got ", w_scale.dim(), " dims");
-  TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected wq scale argument to be strided, got "
-              "layout ", w_scale.layout());
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.dim() == 1,
-                __func__, " : Expected bias argument to be 1D tensor, got ",
-                bias.dim(), " dims");
-    TORCH_CHECK(bias.layout() == at::Layout::Strided,
-                __func__, " : Expected bias argument to be strided, got ",
-                "layout ", bias.layout());
-  }
-
-  // Validate sizes of arguments.
-  const auto xq_sizes = xq.sizes().vec();
-  TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1),
-              __func__, " : Expected xq argument to have ", 2 * wq.size(1),
-              " columns, but got ", xq_sizes.back());
-  const auto x_scale_sizes = x_scale.sizes().vec();
-  for (auto i = 0; i < x_scale_sizes.size(); ++i)
-    TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
-                __func__, " : Expected xq scale argument size at position ",
-                i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
-  TORCH_CHECK(w_scale.numel() == wq.size(0),
-              __func__, " : Expected wq scale argument to have ", wq.size(0),
-              " elements, got ", w_scale.numel(), " elements");
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.numel() == wq.size(0),
-                __func__, " : Expected bias argument to have ", wq.size(0),
-                " elements, got ", bias.numel(), " elements");
-  }
-
-  // Validate strides of arguments.
-  const auto xq_strides = xq.strides();
-  TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
-              __func__, " : Expected xq argument in row-major layout");
-  auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
-  for (int i = xq_strides.size() - 3; i >= 0; --i) {
-    xq_stride_expected *= xq_sizes[i + 1];
-    TORCH_CHECK(xq_strides[i] == xq_stride_expected,
-                __func__, " : Expected xq argument in row-major layout");
-  }
-  TORCH_CHECK(x_scale.is_contiguous(),
-              __func__, " : Expected xq scale argument to be contiguous");
-  const auto wq_strides = wq.strides();
-  TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
-              __func__, " : Expected wq argument in row-major layout");
-  TORCH_CHECK(w_scale.is_contiguous(),
-              __func__, " : Expected wq scale argument to be contiguous");
-  if (bias.numel() > 0) {
-    const auto bias_strides = bias.strides();
-    TORCH_CHECK(bias_strides[0] == 1,
-                __func__, " : Expected bias argument to be contiguous");
-  }
-}
-#endif
-
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-//   result = (xq * x_scale) @ (wq * w_scale).T + bias
-// Notes: The "x_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "xq" tensor.  The "w_scale" tensor is
-// expected to be a vector, of size equal to number of rows of "wq"
-// tensor. The "bias" tensor is expected to be a vector, of size equal
-// to number of rows of "wq" tensor.
-at::Tensor
-s8s4_linear_cutlass(
-    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
-    const at::Tensor& w_scale, const at::Tensor& bias) {
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-  // Check inputs.
-  check_inputs(xq, x_scale, wq, w_scale, bias);
-
-  // Squash the input tensors as appropriate.
-  const auto xq_sizes = xq.sizes().vec();
-  const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
-  const auto x_scale_sizes = x_scale.sizes().vec();
-  const auto x_scale_1d = x_scale.reshape({-1});
-  const auto w_scale_1d = w_scale.reshape({-1});
-
-  // Introduce alias names for arguments, according to the CUTLASS
-  // naming conventions.
-  const auto& tensor_a = xq_2d;
-  const auto& tensor_a_scale = x_scale_1d;
-  const auto& tensor_b = wq;
-  const auto& tensor_b_scale = w_scale_1d;
-  const auto& tensor_c = bias;
-
-  // Create output tensor.
-  at::Tensor tensor_d =
-      tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
-
-  // Dispatch to appropriate kernel template.
-  dispatch_on_tensor_a_scale_and_tensor_b_scale(
-      tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-
-  // Reshape and return output tensor.
-  auto tensor_d_sizes = xq_sizes;
-  tensor_d_sizes.back() = wq.size(0);
-  return tensor_d.reshape(tensor_d_sizes);
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass);
-}
-
-}  // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h b/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h
deleted file mode 100644
index 991384b..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h
+++ /dev/null
@@ -1,288 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
-#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-
-#define CUTLASS_STATUS_CHECK(status)                                      \
-  {                                                                       \
-    TORCH_CHECK(status == cutlass::Status::kSuccess,                      \
-                __func__, " : Got CUTLASS error: ",                       \
-                cutlassGetStatusString(status));                          \
-  }
-
-namespace torchao {
-
-template<
-    typename ThreadblockShape,
-    typename WarpShape,
-    typename InstructionShape,
-    int NumStages,
-    typename ThreadblockSwizzle,
-    typename ElementA,
-    typename ElementB,
-    typename ElementAccumulator,
-    typename Operator,
-    typename ElementAScale,
-    typename ElementBScale,
-    typename ElementC,
-    typename UseTensorC,
-    typename ElementOutput>
-void scaled_linear_kernel_cutlass_sm8x(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  using SmArch = cutlass::arch::Sm80;
-
-  using LayoutA = cutlass::layout::RowMajor;
-  using LayoutB = cutlass::layout::ColumnMajor;
-  using LayoutOutput = cutlass::layout::RowMajor;
-
-  using ElementEpilogue = float;
-  constexpr auto NumEVTEpilogueStages = 1;
-
-  const int m = tensor_a.size(0);
-  const int n = tensor_b.size(0);
-  const int k = std::is_same<ElementA, cutlass::int4b_t>::value ?
-                tensor_a.size(1) * 2 :
-                tensor_a.size(1);
-
-  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
-  constexpr int AlignmentAScale =
-      128 / cutlass::sizeof_bits<ElementAScale>::value;
-  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
-  constexpr int AlignmentBScale =
-      128 / cutlass::sizeof_bits<ElementBScale>::value;
-  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
-  constexpr int AlignmentOutput =
-      128 / cutlass::sizeof_bits<ElementOutput>::value;
-
-  // Check for current CUTLASS limitations w.r.t. alignments.
-  TORCH_CHECK(k % AlignmentA == 0,
-              __func__, " : Number of columns of tensor A must be divisible ",
-              "by ", AlignmentA);
-  TORCH_CHECK(k % AlignmentB == 0,
-              __func__, " : Number of columns of tensor B must be divisible ",
-              "by ", AlignmentB);
-  TORCH_CHECK(n % AlignmentC == 0,
-              __func__, " : Number of columns of tensor C must be divisible ",
-              "by ", AlignmentC);
-
-  using TensorAScaleTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          ElementAScale,
-          AlignmentAScale,
-          NumEVTEpilogueStages>;
-  using TensorBScaleTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          ElementBScale,
-          AlignmentBScale,
-          NumEVTEpilogueStages>;
-  using TensorCTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          ElementC,
-          AlignmentC,
-          NumEVTEpilogueStages>;
-  using OutputTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          ElementOutput,
-          AlignmentOutput,
-          NumEVTEpilogueStages>;
-
-  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
-
-  using TensorAScale =
-      cutlass::epilogue::threadblock::VisitorColBroadcast<
-          TensorAScaleTileThreadMap,
-          ElementAScale,
-          cute::Stride<cute::_1, cute::_0, int64_t>>;
-  using TensorAScaleArguments = typename TensorAScale::Arguments;
-
-  using TensorBScale =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          TensorBScaleTileThreadMap,
-          ElementBScale,
-          cute::Stride<cute::_0, cute::_1, int64_t>>;
-  using TensorBScaleArguments = typename TensorBScale::Arguments;
-
-  using TensorCScalar =
-      cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
-  using TensorCTensor =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          TensorCTileThreadMap,
-          ElementC,
-          cute::Stride<cute::_0, cute::_1, int64_t>>;
-  using TensorC =
-      std::conditional_t<UseTensorC::value, TensorCTensor, TensorCScalar>;
-  using TensorCArguments = typename TensorC::Arguments;
-
-  using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
-      cutlass::FloatRoundStyle::round_to_nearest
-  >;
-  using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
-      ApplyAScale,
-      Accum,
-      TensorAScale>;
-
-  using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
-      cutlass::FloatRoundStyle::round_to_nearest
-  >;
-  using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
-      ApplyBScale,
-      EVTApplyAScale,
-      TensorBScale>;
-
-  using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::plus, ElementEpilogue, ElementEpilogue,
-      cutlass::FloatRoundStyle::round_to_nearest
-  >;
-  using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
-      ApplySum,
-      EVTApplyBScale,
-      TensorC>;
-
-  using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
-      OutputTileThreadMap, ElementOutput,
-      cutlass::FloatRoundStyle::round_to_nearest,
-      cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
-  >;
-
-  using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
-      Output,
-      EVTApplySum>;
-
-  using EVTKernel =
-      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
-      ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
-      ElementOutput, LayoutOutput, AlignmentOutput,
-      ElementAccumulator,
-      ElementEpilogue,
-      cutlass::arch::OpClassTensorOp,
-      SmArch,
-      ThreadblockShape,
-      WarpShape,
-      InstructionShape,
-      EVTOutput,
-      ThreadblockSwizzle,
-      NumStages,
-      Operator,
-      NumEVTEpilogueStages
-  >::GemmKernel;
-
-  // GemmUniversalBase doesn't work with W4A4
-  // using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
-  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
-
-  cutlass::gemm::GemmCoord problem_size(m, n, k);
-  constexpr auto SplitKFactor = 1;
-
-  TensorAScaleArguments tensor_a_scale_arguments{
-    (ElementAScale*)tensor_a_scale.data_ptr(),
-    ElementAScale(1),
-    {cute::_1{}, cute::_0{}, problem_size.m()}
-  };
-  TensorBScaleArguments tensor_b_scale_arguments{
-    (ElementBScale*)tensor_b_scale.data_ptr(),
-    ElementBScale(1),
-    {cute::_0{}, cute::_1{}, problem_size.n()}
-  };
-  TensorCArguments tensor_c_arguments{
-    [&]() -> TensorCArguments {
-      if constexpr (UseTensorC::value) {
-        return {(ElementC*)tensor_c.data_ptr(),
-                ElementC(0),
-                {cute::_0{}, cute::_1{}, problem_size.n()}};
-      } else {
-        return {ElementC(0)};
-      }
-    }()
-  };
-  typename Output::Arguments output_arguments{
-    (ElementOutput*)tensor_d.data_ptr(),
-    {problem_size.n(), cute::_1{}, problem_size.mn().product()}
-  };
-  typename EVTOutput::Arguments callback_arguments{
-    {
-      {
-        {
-          {},                        // Accum
-          tensor_a_scale_arguments,  // TensorAScale
-          {}                         // ApplyAScale
-        },                           // EVTApplyAScale
-        tensor_b_scale_arguments,    // TensorBScale
-        {},                          // ApplyBScale
-      },                             // EVTApplyBScale
-      tensor_c_arguments,            // TensorC
-      {}                             // ApplySum
-    },                               // EVTApplySum
-    output_arguments                 // Output
-  };                                 // EVTOutput
-  // constexpr auto AvailSms = -1;
-
-  typename Gemm::Arguments arguments(
-    cutlass::gemm::GemmUniversalMode::kGemm,
-    problem_size,
-    SplitKFactor,
-    callback_arguments,              // arguments of EVT callbacks
-    (ElementA*)tensor_a.data_ptr(),
-    (ElementB*)tensor_b.data_ptr(),
-    nullptr,                         // ptr C (unused)
-    nullptr,                         // ptr D (unused)
-    problem_size.mk().product(),     // batch stride A
-    problem_size.nk().product(),     // batch stride B
-    0,                               // batch stride C (unused)
-    0,                               // batch stride D (unused)
-    problem_size.k(),                // stride A
-    problem_size.k(),                // stride B
-    0,                               // stride C (unused)
-    0
-    // ,                               // stride D (unused)
-    // AvailSms  // GemmUniversalBase requires passing AvailSms, but GemmUniversalAdapter doesn't
-    );
-
-  Gemm gemm_op;
-
-  cutlass::Status status;
-
-  // Verify that GEMM operation with given arguments can be performed
-  // by CUTLASS.
-  status = gemm_op.can_implement(arguments);
-  CUTLASS_STATUS_CHECK(status);
-
-  // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
-  const auto workspace_size = Gemm::get_workspace_size(arguments);
-  auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
-                                      at::TensorOptions().dtype(at::kByte));
-
-  // Initialize CUTLASS mixed datatypes GEMM object.
-  status = gemm_op.initialize(arguments, workspace.data_ptr(),
-                              at::cuda::getCurrentCUDAStream());
-  CUTLASS_STATUS_CHECK(status);
-
-  // Perform mixed datatypes GEMM operation.
-  status = gemm_op.run(at::cuda::getCurrentCUDAStream());
-  CUTLASS_STATUS_CHECK(status);
-
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-} // namespace torchao
diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
index d7374c8..bccbf80 100644
--- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
+++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -144,14 +144,14 @@ def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias
 
 
 def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
-    from torchao.ops import s8s4_linear_cutlass
+    from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
 
     weight = weight_tensor.tensor_impl.int_data
     weight_scale = weight_tensor.tensor_impl.scale
     input = input_tensor.tensor_impl.int_data
     input_scale = input_tensor.tensor_impl.scale
 
-    out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias)
+    out = rowwise_scaled_linear_cutlass_s8s4(input, input_scale, weight, weight_scale, bias)
 
     return out
 
diff --git a/torchao/ops.py b/torchao/ops.py
index 840dbc0..272b358 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -20,11 +20,10 @@ lib.define(
     "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor"
 )
 lib.define(
-    "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
+    "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
-lib.define("int4_mm_cutlass(Tensor A, Tensor B) -> Tensor")
 lib.define(
-    "scaled_int4_mm_cutlass(Tensor A, Tensor B, Tensor row_scale, Tensor col_scale) -> Tensor"
+    "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
 
 
@@ -518,7 +517,7 @@ def _(
     return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device)
 
 
-def s8s4_linear_cutlass(
+def rowwise_scaled_linear_cutlass_s8s4(
     input: Tensor,
     input_scale: Tensor,
     weight: Tensor,
@@ -526,23 +525,23 @@ def s8s4_linear_cutlass(
     bias: Tensor,
 ) -> Tensor:
     """
-    CUTLASS-based W4A8 linear operator.
+    CUTLASS-based row-wise scaled linear operator.
     Args:
-        input: input tensor, quantized to 8-bit integer values.
+        input: quantized input tensor, in row-major layout.
         input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension.
-        weight: weight matrix, quantized to 4-bit integer values, in row-major layout.
+        weight: quantized weight matrix, in row-major layout.
         weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension).
         bias: a vector of size equal to number of rows of weight tensor, or None.
     Returns:
         output: result tensor, in row-major layout.
     """
 
-    return torch.ops.torchao.s8s4_linear_cutlass.default(
+    return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default(
         input, input_scale, weight, weight_scale, bias
     )
 
 
-@register_custom_op("torchao::s8s4_linear_cutlass")
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s8s4")
 def _(
     input: Tensor,
     input_scale: Tensor,
@@ -550,6 +549,8 @@ def _(
     weight_scale: Tensor,
     bias: Tensor,
 ) -> Tensor:
+    # FIXME: update this!!!
+
     # Validate dtypes.
     torch._check(
         input.dtype == torch.int8,
@@ -621,29 +622,8 @@ def _(
     )
 
 
-def int4_mm_cutlass(A: Tensor, B: Tensor) -> Tensor:
-    """
-    CUTLASS-based W4A4 matmul.
-    Args:
-        A: first INT4 tensor, packed in INT8 dtype, row-major layout.
-        B: second INT4 tensor, packed in INT8 dtype, column-major layout.
-    Returns:
-        output: result tensor, in row-major layout.
-    """
-    assert A.dtype == B.dtype == torch.int8
-    assert A.ndim == B.ndim == 2
-    assert A.shape[1] == B.shape[0]
-    assert A.is_contiguous() and B.T.is_contiguous()
-    return torch.ops.torchao.int4_mm_cutlass.default(A, B)
-
-
-@register_custom_op("torchao::int4_mm_cutlass")
-def _(A: Tensor, B: Tensor) -> Tensor:
-    return A.new_empty(A.shape[0], B.shape[1], dtype=torch.int32)
-
-
-def scaled_int4_mm_cutlass(
-    A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor
+def rowwise_scaled_linear_cutlass_s4s4(
+        A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor
 ) -> Tensor:
     """
     CUTLASS-based W4A4 scaled-matmul.
@@ -656,15 +636,16 @@ def scaled_int4_mm_cutlass(
         output: result tensor, in row-major layout.
     """
     assert A.dtype == B.dtype == torch.int8
-    assert A.ndim == B.ndim == 2
-    assert A.shape[1] == B.shape[0]
-    assert A.is_contiguous() and B.T.is_contiguous()
-    assert row_scale.ndim == col_scale.ndim == 1
+    assert A.ndim >= 2
+    assert B.ndim == 2
+    assert A.shape[-1] == B.shape[-1]
+    assert A.is_contiguous() and B.is_contiguous()
+    assert row_scale.ndim == col_scale.ndim == 2
     assert row_scale.dtype == col_scale.dtype
     assert row_scale.dtype in (torch.float16, torch.bfloat16)
-    return torch.ops.torchao.scaled_int4_mm_cutlass.default(A, B, row_scale, col_scale)
+    return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default(A, row_scale, B, col_scale, bias)
 
 
-@register_custom_op("torchao::scaled_int4_mm_cutlass")
-def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor:
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s4s4")
+def _(A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor) -> Tensor:
     return row_scale.new_empty(A.shape[0], B.shape[1])

I did some renaming: the s8s4_linear_cutlass sub-directory of torchao/csrc/cuda is renamed to rowwise_scaled_linear_cutlass, and "your" files scaled_linear.h, s4s4_linear_cutlass.cu and s8s4_linear_cutlass.cu are renamed respectively to rowwise_scaled_linear_cutlass.cuh, rowwise_scaled_linear_cutlass_s4s4.cu and rowwise_scaled_linear_cutlass_s8s4.cu. The latter two files are now minimal: they just dispatch according to the input/weight tensor data types, and everything else lives in the .cuh header file. I've updated all references to my W4A8 kernel in the code, and its test and benchmark work. I've added the changes needed for the W4A4 kernel in torchao/ops.py (actually, just adapted them from your scaled_int4_mm_cutlass that was there), so that I was able to add a W4A4 test in tests/test_rowwise_scaled_linear_cutlass.py. Note that, for the sake of running the test, this patch temporarily removes your int4_mm_cutlass and scaled_int4_mm_cutlass from torchao/ops.py - but this could be reverted; the patch is primarily to show how I'd like the C++ side of the changes to look.

Please take a look and let me know what you think.

Regarding your question about the possibility of different types for scales and bias - this pretty much comes for free, and I would be inclined to extend it further so that either of the input/weight scales and/or the bias could be of a different type. On a related note, I have a question: looking at your int4_mm_cutlass and scaled_int4_mm_cutlass operators, do you think we may have a use for the case when either of the input/weight scales and/or the bias could be optional? This kernel could easily be extended to support this too, without any run-time penalty.

My TODO list regarding the attached patch:

  1. The W4A4 test produces results that differ from the reference results, need to investigate this.
  2. Have to re-examine the rowwise_scaled_linear_cutlass_check_inputs method for the S4/S4 case.
  3. Need to revisit rowwise_scaled_linear_cutlass_s8s4 and the corresponding register_custom_op in torchao/ops.py. I modeled these on the ones for the Marlin-based W4A8 kernel, but I can see yours are different in the sense that the checks are in the former method instead of in the latter. TBH, I don't fully get how this stuff works on the Python side of torchao; in particular, is there any need for repeating input checks that, in the case of this kernel, are already there in the C++ code - do you have any suggestions here?
  4. The scales could be made optional in the kernel, just like the bias is.
  5. The code is to be linted before it is eventually committed.

I'll continue on this tomorrow.

@gau-nernst
Copy link
Collaborator Author

Thank you for the patch! Will apply and work on it later today.

do you think we may have a use for the case when either of the input/weight scales and/or the bias could be optional?

I don't think there is a standalone real use case for INT4xINT4->INT32 (yet), similarly for INT8xINT8->INT32, i.e. torch._int_mm(), since we will definitely need some scaling anyway, and perf is better with fused output scaling. I initially added that since it might be interesting, for benchmarking, to see the perf of INT4 tensor cores alone. For research, people might find the raw INT4xINT4->INT32 output without scaling useful for whatever reasons they are exploring?

TBH, I don't fully get how this stuff works on the Python side of torchao; in particular, is there any need for repeating input checks that, in the case of this kernel, are already there in the C++ code - do you have any suggestions here?

From my observation, it's a good idea to have input checks for the meta implementation used by torch.compile for shape tracing (otherwise we won't catch errors at compile time, only at runtime) -> we need some kind of input checks in Python anyway. Hence, I usually don't bother with input checks on the C++ (or triton) side, and put all checks in Python. I place the Python checks in the Python wrapper so that they are shared between the meta and CUDA implementations (actually it's regardless of the implementation) -> technically the op doesn't have input checks, but the Python wrapper does. e.g.

def scaled_int8_mm(
    A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor
) -> Tensor:
    """Compute `(A @ B) * row_scale * col_scale`, where `A` and `B` are INT8 to utilize
    INT8 tensor cores. `col_scale` can be a scalar.
    """
    assert A.dtype is torch.int8 and B.dtype is torch.int8
    assert row_scale.dtype is col_scale.dtype
    assert A.shape[1] == B.shape[0]
    assert row_scale.squeeze().shape == (A.shape[0],)
    assert col_scale.squeeze().shape in ((B.shape[1],), ())
    assert row_scale.is_contiguous()
    assert col_scale.is_contiguous()
    return torch.ops.torchao.scaled_int8_mm(A, B, row_scale, col_scale)

That's just my personal view (and I prefer writing Python over C++ any time 😆).

@gau-nernst
Copy link
Collaborator Author

The W4A4 test produces results that differ from the reference results, need to investigate this.

Found a typo

+  int k = tensor_a.size(1);
+  if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
+    k *= 8 % cutlass::sizeof_bits<ElementA>::value;
+  }

Should be (integer) division instead of modulo. After I fixed that, the tests pass.
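
In other words, the corrected lines would use integer division (a sketch of the fix just described; for int4, this multiplies k by 2, since two 4-bit values are packed per int8 storage element):

int k = tensor_a.size(1);
if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
  // Sub-byte elements are packed into int8 storage, so the logical K
  // dimension is the storage width times the packing factor.
  k *= 8 / cutlass::sizeof_bits<ElementA>::value;
}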

I also took the liberty of slightly changing the test: matmul, scaling, and bias addition are now done in FP32, which should match the numerics of the fused CUTLASS kernel better, especially for the scaling and bias addition parts, since FP16/BF16 matmul is accumulated in FP32 anyway. With this change, we can use torch.testing.assert_close() directly. Let me know if you are ok with it.
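
Roughly, the FP32 reference path in the test now looks like the sketch below (a sketch only: the variable names are placeholders, out stands for the fused kernel's output, and the quantized operands are assumed to already be unpacked into plain integer tensors):

# Dequantize and accumulate everything in FP32, then cast back to the
# scale dtype before comparing against the fused kernel's output.
ref = (x_int.float() * x_scale.float()[:, None]) @ (w_int.float() * w_scale.float()[:, None]).T
if bias is not None:
    ref = ref + bias.float()
ref = ref.to(x_scale.dtype)
torch.testing.assert_close(out, ref)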

With your refactoring, it should be trivial to add W8A8. Should we add W8A8 also? (We might not add it to AQT, since AQT currently uses torch.compile to codegen this path. Though I think inductor only fuses one scaling (either the input scale or the weight scale), so we could get a bit more perf by fusing both input and weight scales; a rough sketch of that path is below.)
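
For reference, a rough sketch of the torch.compile-codegen W8A8 path I mean (illustrative only, not the actual AQT code):

import torch

# Rough sketch: INT8 x INT8 -> INT32 via torch._int_mm, with both scalings applied
# afterwards. As noted above, inductor may fuse only one of the two scalings into
# the matmul epilogue, whereas the CUTLASS kernel fuses both.
@torch.compile
def w8a8_linear(x_i8, w_t_i8, x_scale, w_scale):
    # x_i8: (M, K) int8, w_t_i8: (K, N) int8, x_scale: (M,), w_scale: (N,)
    out_i32 = torch._int_mm(x_i8, w_t_i8)
    return out_i32 * x_scale[:, None] * w_scale[None, :]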

Other ideas:

  • Add row-wise scaled FP8 matmul for sm89. I think we only need to modify the logic for ElementAccumulator and SmArch. Will try. This can be hooked up to torch._scaled_mm() for sm89 (cc: @drisspg)
  • Runtime kernel-param tuning by codegen-ing and compiling the cutlass kernel at runtime 👀. Basically cutlass python (but hard-coded for row-wise scaled linear)
    • To expand on this point: one strength of triton is runtime autotuning, which can pick the best kernel params for different hardware. CUDA kernels in PyTorch are usually tuned for A100/H100 (and triton configs are hard-coded too), so there is perf left on the table. For example, for INT8 matmul on my consumer GPU, I can get 1.5-2.5x faster than torch._int_mm() simply by having better kernel params in triton.

Comment on lines +12 to +13
def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3
Collaborator Author

@alexsamardzic I replaced torchao.utils.benchmark_torch_function_in_microseconds (which is based on torch.utils.benchmark.Timer) with triton's do_bench. This is because I found the PyTorch timer to be unreliable, possibly because it does not clear the L2 cache between runs.

Old (4090, torch.utils.benchmark.Timer)

m k n fp16_latency (µs) W4A8 latency (µs) W4A8 speedup (d/s)
1 8192 8192 124.368 18.6586 6.66545
1 8192 10240 177.832 22.1851 8.01584
1 8192 57344 982.894 298.987 3.28742
1 28672 8192 493.988 152.724 3.23452

New (4090, triton.testing.do_bench)

m k n fp16_latency (µs) W4A8 latency (µs) W4A8 speedup (d/s)
1 8192 8192 166.912 73.728 2.26389
1 8192 10240 201.824 87.04 2.31875
1 8192 57344 1005.57 346.112 2.90533
1 28672 8192 519.136 199.68 2.59984

With the old method, you can see the unusual speedup for the first two rows. I think it's because W is cached in L2; hence the gains disappear when W becomes larger.
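
For anyone curious, this is roughly how the new timing works (illustrative snippet, not the benchmark script itself); triton's do_bench flushes the L2 cache between iterations, which is what removes the caching advantage seen above:

import torch
from triton.testing import do_bench

# Illustrative only: do_bench clears the L2 cache between iterations, so W cannot
# stay resident in cache across runs the way it can with a plain timing loop.
W = torch.randn(8192, 8192, device="cuda", dtype=torch.half)
x = torch.randn(1, 8192, device="cuda", dtype=torch.half)

latency_us = do_bench(lambda: x @ W, return_mode="median") * 1e3  # do_bench reports ms
print(latency_us)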

Lmk if it's ok to have this change. Thank you!

Collaborator

Sure, that's great!

Contributor

For anyone else, I break it down as:

  1. Do you only care about kernel time? Then use: https://github.com/drisspg/transformer_nuggets/blob/0d1cbe6b3d980451f39174be64c98a143b9cebce/transformer_nuggets/utils/benchmark.py#L55. This is what I have found produces the closest results to NCU.
  2. If the overhead does matter and you are more focused on what you would see in an isolated region of code, then the timer is pretty representative: https://github.com/drisspg/transformer_nuggets/blob/0d1cbe6b3d980451f39174be64c98a143b9cebce/transformer_nuggets/utils/benchmark.py#L44

@jerryzh168
Contributor

Thanks @gau-nernst, AQT integration code LGTM

@alexsamardzic
Collaborator

Would it be possible to apply this patch too: patch.txt? Just learned the hard way that this is how it's supposed to work...

Also, I'm wondering: do we need to add some compile-time/runtime checks against anything less than SM80?

@gau-nernst
Collaborator Author

@alexsamardzic Just merged your changes. I think initially I wrote it that way too, but didn't think it would make a big difference.

For the runtime check, you already have one in place:

template<typename ElementA, typename ElementB, typename... Types>
static void select_config(
    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
  const auto dprops = at::cuda::getCurrentDeviceProperties();
  const auto is_sm8x = dprops->major == 8;

  if (is_sm8x) {
    // early return
  }
  TORCH_CHECK(false,
              __func__, " : Operator not supported on SM", dprops->major, ".",
              dprops->minor, " for given operands");
}

For compile time, we officially only support >=sm80 (and only build wheels for >=sm80; also, we can't test <sm80 anyway since CI doesn't have such GPUs):

# Set ARCH list so that we can build fp16 with SM75+, the logic is copied from
# pytorch/builder
TORCH_CUDA_ARCH_LIST="8.0;8.6"
if [[ ${CU_VERSION:-} == "cu124" ]]; then
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
fi

However, there have been some efforts to make sure torchao can still be compiled on sm75, e.g. #1147. My understanding is that, iirc:

  • __CUDA_ARCH__ is only defined in device code. Hence, we can only use __CUDA_ARCH__ inside the kernel (either __device__ or __global__). There are some other restrictions, as outlined here: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch
  • If we want it to error at compile time (i.e. not compile for targets < sm80), we can add static_assert(false) within a __CUDA_ARCH__ if-block (and the __CUDA_ARCH__ block must be inside device code). But this is potentially bad since it prevents users from using other CUDA code in torchao that works with <sm80.
  • If we still want to build it, but error at runtime, we can add __trap() in the kernel
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
    static_assert(false, "Quant-LLM kernel: At least Turing generation (sm75) is required.");
    // __trap(); // fails at runtime instead of compile time
    #endif
  • Using __CUDA_ARCH__ to include/exclude the contents of a whole file is actually undefined behavior, I think (because then you will have a mismatch between device and host code - see the link above).

I'm leaning towards adding __trap() for __CUDA_ARCH__ < 800 (so people can still compile the rest of torchao for sm75). Lmk what you think.

Btw, if it's not obvious, the PR is ready for review. I won't be making any more major changes (unless requested). Happy to receive feedback.

@alexsamardzic
Collaborator

I think initially I wrote it that way too, but didn't think it would make a big difference.

The difference shows up when an epilogue tensor has a different data type from the output tensor - then the initial variant will not work properly. I too was thinking all along that this initial variant, with a tile thread map for each of the tensors used by the epilogue, was the way to go - I just learned the hard way, while debugging my other PR (that extends the row-wise scaled MM implementation in PyTorch), that it isn't, and that there should be only one tile thread map instead. (FWIW, I found that this is the right way while looking into the code generated by the CUTLASS Python library, which always generates a single tile thread map.)

For the runtime check, you already have one in place

Yes, but I mentioned the compile-time check because I was actually looking yesterday into a build issue on SM75 reported by a user, and noticed that the error message prints select_config - the name of the function where the check is performed - which is not very informative. Thus, while I like having this check there, where the dispatching will be performed (among other things) according to the arch (namely, I hope to have an SM9x-specific kernel added in the future - of course not for sub-byte int inputs, but for some other interesting combinations of input data types), I was thinking about moving the TORCH_CHECK from this function to the beginning of the rowwise_scaled_linear_cutlass function - so that the error message becomes more informative regarding the source of the error. Or, maybe better, just replacing __func__ in error messages with rowwise_scaled_linear_cutlass everywhere; it is easy to determine the location of the error from the rest of the message anyway.

For compile time, we officially only support >=sm80 (and only build wheels for >=sm80).

Oh, that's good to know - I was not aware. Still, I've recently encountered a nifty trick in vLLM: check the definition of enable_sm90_or_later here, and how it's used, for example, here. It's an elegant way to skip the most time-consuming part of the compilation when a given architecture is not targeted, and the runtime check should still prevent the no-op that is effectively generated for that architecture from producing wrong results. The only slight complication here is that I'd prefer the enable_smXX_or_later variations to be in a dedicated header file (I'd probably opt for torchao/csrc/cuda/cutlass_extensions/common.cuh too) in order to be able to utilize them from other CUTLASS-based kernels.

Let me know what you think about these two - if you agree, I can make a patch. Other than that, I'm approving the PR (as far as the CUDA code is concerned; I was not looking into the rest). This is awesome, let's have it merged ASAP.

@gau-nernst
Collaborator Author

For the runtime check, either suggestion sounds good!

For the compile-time check, iiuc, enable_sm90_or_later uses the same idea as I mentioned: for an unsupported arch, simply let the kernel be empty (well, technically __trap() will hard-exit the program, and (I hope) the compiler is smart enough not to compile anything below __trap()). So that's good to me too!

I will apply your patch when it is ready. You should have authored this PR - all the major changes were done by you 😅

@alexsamardzic
Collaborator

Here is the patch: patch.txt.

@gau-nernst
Collaborator Author

@drisspg When you are back, can you take another look, then I think we can merge. Thank you!

Contributor

@drisspg drisspg left a comment

Awesome work! One nit

@drisspg drisspg merged commit 1a4c8f9 into pytorch:main Feb 5, 2025
46 of 48 checks passed
@gau-nernst gau-nernst deleted the w4a4 branch February 5, 2025 02:56
Labels: CLA Signed, topic: new feature

Successfully merging this pull request may close these issues.

[Feature Request] W4A4 Quantization Support in torchao