[QST] make_tiled_copy_B generates incompatible layouts #1953
Comments
It does seem more reasonable, but that's not how the … The solution is therefore some shared-memory layout engineering. With the TiledCopy you show, we want each thread to access 128 consecutive bits, as you mention. So if I just follow the …

auto smem_layout_B_atom = Layout<Shape <Shape <_8,  _4>, Shape <_2, _16>>,
                                 Stride<Stride<_2,_256>, Stride<_1, _16>>>{};  // 32N x 32K
auto smem_layout_B = tile_to_shape(smem_layout_B_atom, Shape<_128,_32>{});     // 128N x 32K

But I'm sure you can do better by also considering the stores from global memory and the bank accesses.
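The reply's atom layout can be sanity-checked on the host without CuTe. The sketch below evaluates the layout function by hand (it is not the CuTe API), assuming that `tile_to_shape` stacks the four 32N x 32K atoms along N, each in its own 1024-byte block. It checks two things: the 128N x 32K tile is a bijection onto 4096 bytes, and every aligned block of 8 N-values by 2 K-values lands in one 16-byte (128-bit) aligned run, which is the per-thread contiguity the reply is after. The helper names are hypothetical.

```cpp
#include <cassert>  // for assert() in callers

// Hand evaluation of the layout from the reply (not the CuTe API):
//   Shape <Shape <_8,  _4>, Shape <_2, _16>>
//   Stride<Stride<_2,_256>, Stride<_1, _16>>   // 32N x 32K
// N (0..31) splits column-major into (8,4); K (0..31) into (2,16).
constexpr int atom_offset(int n, int k) {
    return (n % 8) * 2 + (n / 8) * 256 + (k % 2) * 1 + (k / 2) * 16;
}

// Assumption: tile_to_shape(atom, Shape<_128,_32>{}) places four atoms along N,
// each occupying its own block of 32 * 32 = 1024 int8_t values.
constexpr int tiled_offset(int n, int k) {
    return atom_offset(n % 32, k) + (n / 32) * 1024;
}

// True if every (n, k) of the 128N x 32K tile maps to a distinct offset in [0, 4096).
constexpr bool tiled_layout_is_bijective() {
    bool seen[128 * 32] = {};
    for (int n = 0; n < 128; ++n)
        for (int k = 0; k < 32; ++k) {
            int off = tiled_offset(n, k);
            if (off < 0 || off >= 128 * 32 || seen[off]) return false;
            seen[off] = true;
        }
    return true;
}

// True if every aligned 8N x 2K block occupies one 16-byte-aligned, 16-byte run.
constexpr bool blocks_are_contiguous16() {
    for (int nb = 0; nb < 128; nb += 8)
        for (int kb = 0; kb < 32; kb += 2) {
            int base = tiled_offset(nb, kb);
            if (base % 16 != 0) return false;
            for (int dn = 0; dn < 8; ++dn)
                for (int dk = 0; dk < 2; ++dk) {
                    int off = tiled_offset(nb + dn, kb + dk);
                    if (off < base || off >= base + 16) return false;
                }
        }
    return true;
}

static_assert(tiled_layout_is_bijective(), "layout must be a bijection");
static_assert(blocks_are_contiguous16(), "8N x 2K blocks must be 16 contiguous bytes");
```

Because the checks are `constexpr`, compiling the snippet (C++14 or later) is enough to verify both claims; the same functions can also be called at runtime.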
Thank you so much for your reply! I didn't realize that I could actually manipulate the SmemAtomLayout. Before, I was simply doing the naive Shape<_32, _32>, Stride<_1, _32>. In my case, since B is actually transposed, i.e., row major, I use the following SmemAtomLayout:

auto smem_layout_B_atom = Layout<Shape <Shape <_8, _4>, Shape <_2, _16>>,
                                 Stride<Stride<_2, _16>, Stride<_1, _64>>>{};

For swizzle, I don't want to swizzle the 16 consecutive indices, so MBase = 4. I am not sure what would need to be modified or considered for the stores from global memory. I use a pretty regular, contiguous, non-swizzled GmemCopy, which uses SM80_CP_ASYNC_CACHEGLOBAL. It doesn't seem that the GmemCopy can affect the SmemCopy. Is there anything I'm missing here? Thanks!
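The reasoning behind MBase = 4 can be illustrated on the host. The sketch below hand-evaluates the row-major atom from the comment and then applies an XOR swizzle written in the style of `cute::Swizzle<B, M, S>` for positive S: the B bits starting at bit M + S are XORed into the B bits starting at bit M. Only M = 4 comes from the comment; B = 3 and S = 3 are illustrative assumptions, and the function names are hypothetical. Because the low 4 bits are never touched, each aligned 8N x 2K group of 16 bytes stays contiguous after swizzling, while the whole 32x32 atom is still permuted within its 1024 bytes.

```cpp
#include <cassert>  // for assert() in callers

// Hand evaluation of the row-major atom from the comment (not the CuTe API):
//   Shape <Shape <_8, _4>, Shape <_2, _16>>
//   Stride<Stride<_2,_16>, Stride<_1, _64>>   // 32N x 32K
constexpr int rowmajor_atom_offset(int n, int k) {
    return (n % 8) * 2 + (n / 8) * 16 + (k % 2) * 1 + (k / 2) * 64;
}

// XOR swizzle modeled on cute::Swizzle<B, M, S> with S > 0 (assumed B = 3, S = 3):
// bits [M+S, M+S+B) are shifted right by S and XORed into bits [M, M+B).
// Bits below M (here the low 4 bits, i.e. 16 bytes) are left unchanged.
constexpr int swizzle_offset(int off, int B = 3, int M = 4, int S = 3) {
    int yyy_mask = ((1 << B) - 1) << (M + S);
    return off ^ ((off & yyy_mask) >> S);
}

// True if the swizzled atom is still a bijection onto [0, 1024).
constexpr bool swizzled_atom_is_bijective() {
    bool seen[32 * 32] = {};
    for (int n = 0; n < 32; ++n)
        for (int k = 0; k < 32; ++k) {
            int off = swizzle_offset(rowmajor_atom_offset(n, k));
            if (off < 0 || off >= 32 * 32 || seen[off]) return false;
            seen[off] = true;
        }
    return true;
}

// True if every aligned 8N x 2K block is still one aligned 16-byte run after swizzling.
constexpr bool swizzled_blocks_are_contiguous16() {
    for (int nb = 0; nb < 32; nb += 8)
        for (int kb = 0; kb < 32; kb += 2) {
            int base = swizzle_offset(rowmajor_atom_offset(nb, kb));
            if (base % 16 != 0) return false;
            for (int dn = 0; dn < 8; ++dn)
                for (int dk = 0; dk < 2; ++dk) {
                    int off = swizzle_offset(rowmajor_atom_offset(nb + dn, kb + dk));
                    if (off < base || off >= base + 16) return false;
                }
        }
    return true;
}

static_assert(swizzled_atom_is_bijective(), "swizzled atom must be a bijection");
static_assert(swizzled_blocks_are_contiguous16(), "16-byte groups must survive the swizzle");
```

Other choices of B and S that keep the source and target bit ranges disjoint behave the same way with respect to the low 4 bits; which one avoids bank conflicts best depends on the actual access pattern.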
What is your question?
Hello!
I am writing an int8 GEMM layer using cute.
I use
MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>
as my MMA atom, and define my tiled MMA as:
For element B, my original layout is transposed, so I use
Then I define a tiled copy and use it to partition my tensor in shared memory.
Here I plot the MMA and smem_tiled_copy_B using print_latex.
mma_int8.pdf
tiled_copy_B.pdf
The good news is that the destination of smem_tiled_copy_B matches the MMA layout of B.
The bad news is that the source of smem_tiled_copy_B is arranged like ((2, 8), 2):((64, 1), 16) instead of something like (16, 2):(1, 16).
I am not sure why this configuration generates the (2, 8) partition. SmemCopyAtomTransposed is constructed using SM75_U16x8_LDSM_T and int8_t, which internally should use the ldmatrix instruction, which reads one 128-bit row per thread-supplied address. So it seems more reasonable for make_tiled_copy_B to have 16 contiguous int8_t values in the inner dimension.
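One way to see why the naive atom cannot satisfy a single 128-bit ldmatrix row is to check where the bytes land. The sketch below assumes, as the atom layout proposed in the reply implies, that one 128-bit row for SM75_U16x8_LDSM_T on int8_t should hold 8 N-values by 2 K-values. It reads the naive atom Shape<_32, _32>, Stride<_1, _32> as (n, k) -> n + 32k (mode 0 = N) and counts how many runs of consecutive bytes such an 8N x 2K block occupies. The helper names are hypothetical and the code is a host-side illustration, not CuTe.

```cpp
#include <cassert>  // for assert() in callers

// Naive atom Shape<_32,_32>, Stride<_1,_32>, read as (n, k) -> n + 32 * k.
constexpr int naive_offset(int n, int k) { return n * 1 + k * 32; }

// Number of maximal runs of consecutive byte offsets covered by the
// 8N x 2K block starting at (nb, kb). A single 128-bit ldmatrix row
// address can only cover a block that forms exactly one 16-byte run.
constexpr int runs_in_block(int nb, int kb) {
    bool used[32 * 32] = {};
    for (int dn = 0; dn < 8; ++dn)
        for (int dk = 0; dk < 2; ++dk)
            used[naive_offset(nb + dn, kb + dk)] = true;
    int runs = 0;
    for (int i = 0; i < 32 * 32; ++i)
        if (used[i] && (i == 0 || !used[i - 1])) ++runs;  // start of a new run
    return runs;
}

// With the naive atom the block splits into two 8-byte runs, 32 bytes apart.
static_assert(runs_in_block(0, 0) == 2, "naive atom: two separate 8-byte runs");
```

Under this reading, only 8 of the 16 bytes are contiguous, which is consistent with the 8-element stride-1 inner mode in the printed source layout; a layout such as the one in the reply is what turns the block into a single 16-byte run.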
This generates errors when calling cute::copy(), as SM75_U16x8_LDSM_T for int8 is incompatible with the src layout:
Could you help me take a look at this issue? Thank you so much!