Why does this cutlass_tensorop_s1688gemm_f16_64x128_64x2_tn_align4 GEMM have blocks filled with zeros? #338
-
You are running an NT (col x row) GEMM, not a TN (row x col) GEMM. As you may have noticed, the cutlass profiler swaps and transposes the operands. (The reason is that cutlass only supports row-major output, but the profiler needs column-major output to match cublas, so A x B -> C becomes B' x A' -> C'.) However, the swap/transpose of NT is still NT, and the swap/transpose of TN is still TN. To avoid confusion, simply don't do any swap/transpose yourself; let cutlass handle it (https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/device/gemm.h#L528-L533).
BTW, your reference code does not include …
-
I modified example 12 to 1) use your tile size and problem size, 2) do bias but no ReLU, and 3) dump the result. Here is the diff: …
To run it, first apply the edits above, then: …
-
The latest CUDA 11.6 fixes the -G bug.
-
After running the profiler, I found that cutlass_tensorop_s1688gemm_f16_64x128_64x2_tn_align4 is the best GEMM for my problem size. I took the parameters from the profiler-generated source, then wrote a simple test program with a cutlass::gemm::device::Gemm reproducing the parameters of cutlass_tensorop_s1688gemm_f16_64x128_64x2_tn_align4. However, the resulting matrix has blocks filled with zeros, and I would like to understand why. The image compares the expected values (left) with the actual GEMM output (right).
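For context, a device::Gemm instantiation matching that kernel name might look roughly like the following. This is a sketch reconstructed from the name only, not the actual test program from this thread: the warp shape, epilogue, swizzle, and output layout are assumptions.

```cpp
#include <cutlass/gemm/device/gemm.h>

// Sketch of cutlass_tensorop_s1688gemm_f16_64x128_64x2_tn_align4:
//   f16 operands, f32 accumulate ("s1688" = Tensor Core 16x8x8 mma, Sm75),
//   64x128 threadblock tile, k-tile 64, 2 stages, TN layouts, alignment 4.
// Warp shape, epilogue, and row-major output below are guesses.
using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,      // A: "t" (transposed in BLAS terms)
    cutlass::half_t, cutlass::layout::ColumnMajor,   // B: "n"
    float,           cutlass::layout::RowMajor,      // C/D (assumed)
    float,                                           // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<64, 128, 64>,           // threadblock tile
    cutlass::gemm::GemmShape<32, 64, 64>,            // warp tile (assumed)
    cutlass::gemm::GemmShape<16, 8, 8>,              // instruction ("1688")
    cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                               // stages
    4, 4>;                                           // alignment of A and B
```

If any of these template arguments differ from the profiler-generated source (in particular the layouts and alignments), the kernel can silently fall back or misread operands, which is worth ruling out when output blocks come back zero.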
Test code to create and execute the GEMM and to calculate expected values: