[QST] make_tiled_copy_B generates incompatible layouts #1953

Open
phantaurus opened this issue Nov 20, 2024 · 2 comments


phantaurus commented Nov 20, 2024

What is your question?
Hello!

I am writing an int8 GEMM layer using cute.

I use MMA_Atom<SM80_16x8x32_S32S8S8S32_TN> as my atom MMA, and define my tiled MMA as:

using MMA_Atom_Arch = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using TiledMma = TiledMMA<MMA_Atom_Arch,
                          Layout<Shape<_4, _1, _1>>,   // 4 MMA atoms along M
                          Layout<Shape<_1, _4, _1>>>;

For operand B, my original layout is transposed, so I use

using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, int8_t>;

Then I define tiled copy and use the tiled copy to partition my tensor in shared memory.

 auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma);  

Here are the MMA and smem_tiled_copy_B, plotted with print_latex:
mma_int8.pdf
tiled_copy_B.pdf

The good news is that the destination of smem_tiled_copy_B matches the MMA layout of B.
The bad news is that the source of smem_tiled_copy_B is arranged as ((2, 8), 2):((64, 1), 16) instead of something like (16, 2):(1, 16).

I am not sure why this configuration generates the (2, 8) partition. SmemCopyAtomTransposed is constructed from SM75_U16x8_LDSM_T and int8_t, which internally should use the ldmatrix instruction that takes one 128-bit input per thread each time. So it seems more reasonable for make_tiled_copy_B to have 16 contiguous int8_t values in the inner dimension.

This causes an error when calling cute::copy(), because SM75_U16x8_LDSM_T with int8_t is incompatible with the src layout:

In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy.

instantiation of "void cute::copy_unpack(const cute::Copy_Traits<Operation, Args...>&, const cute::Tensor<TS, SLayout> &, cute::Tensor<TD, DLayout> &) [with Operation=cute::SM75_U16x8_LDSM_T, Args=<>, TS=cute::ViewEngine<cute::smem_ptr<int8_t>>, SLayout=cute::Layout<cute::tuple<cute::tuple<cute::C<2>, cute::C<8>>>, cute::tuple<cute::tuple<int, cute::_1>>>, TD=cute::ViewEngine<int8_t *>, DLayout=cute::Layout<cute::tuple<cute::_16>, cute::tuple<cute::C<1>>>]" 

external/cutlass/include/cute/atom/copy_atom.hpp(104): here

instantiation of "void cute::Copy_Atom<cute::Copy_Traits<Args...>, T>::call(const cute::Tensor<TS, SLayout> &, cute::Tensor<TD, DLayout> &) const [with Args=<cute::SM75_U16x8_LDSM_T>, T=int8_t, TS=cute::ViewEngine<cute::smem_ptr<int8_t>>, SLayout=cute::Layout<cute::tuple<cute::tuple<cute::C<2>, cute::C<8>>>, cute::tuple<cute::tuple<int, cute::_1>>>, TD=cute::ViewEngine<int8_t *>, DLayout=cute::Layout<cute::tuple<cute::_16>, cute::tuple<cute::C<1>>>]"

Could you help me take a look at this issue? Thank you so much!


ccecka commented Nov 20, 2024

> I am not sure why this configuration generates the (2, 8) partition. SmemCopyAtomTransposed is constructed from SM75_U16x8_LDSM_T and int8_t, which internally should use the ldmatrix instruction that takes one 128-bit input per thread each time. So it seems more reasonable for make_tiled_copy_B to have 16 contiguous int8_t values in the inner dimension.

It does seem more reasonable, but that's not how LDSM_T works for int8_t input. As the name hints, U16x8 means LDSM_T was designed for 16-bit types. We can use it with int8_t as well, but then we need two int8_t values for every U16, which is where the inner 2 of the ((2, 8), 2) source layout comes from.

The solution is therefore some shared-memory layout engineering. With the TiledCopy you show, we want each thread to access 128 consecutive bits, as you mention. So if I just follow tiled_copy_B.pdf, this should work:

auto smem_layout_B_atom = Layout<Shape <Shape <_8,  _4>,Shape <_2,_16>>,
                                 Stride<Stride<_2,_256>,Stride<_1,_16>>>{};  // 32N x 32K
auto smem_layout_B = tile_to_shape(smem_layout_B_atom, Shape<_128,_32>{});    // 128N x 32K

but I'm sure you can do better by also considering the stores from global memory and the bank accesses.


phantaurus commented Nov 21, 2024

Thank you so much for your reply! I didn't realize that I could actually manipulate the SmemAtomLayout. Before, I was naively using Shape<_32, _32>, Stride<_1, _32>.

In my case, since B is actually transposed, i.e., row major, I use the following SmemAtomLayout:

auto smem_layout_B_atom = Layout<Shape<Shape<_8, _4>, Shape<_2, _16>>,
                                 Stride<Stride<_2, _16>, Stride<_1, _64>>>{};  // 32N x 32K

For the swizzle, I don't want to swizzle the 16 consecutive indices that one thread loads, so MBase = 4.
My B atom has 32 rows (for K) and 32 columns (for N). Since we work on 8 rows at a time, there are 2 bits left for swizzling. We don't want to swizzle the columns, so we right-shift the lowest 2 column bits by 2 bits, which guarantees that the first 4 threads (0 to 3) access different banks. That gives BBits = 2 and SShift = 2.
If my atom could be 64x64, I would be able to use BBits = 3 and SShift = 3, which would guarantee that all 8 threads of an ldmatrix access different banks.

I am not sure what would need to be changed or considered for the stores from global memory. I use a fairly regular, contiguous, non-swizzled GmemCopy based on SM80_CP_ASYNC_CACHEGLOBAL. It doesn't seem that the GmemCopy can affect the SmemCopy. Am I missing anything here? Thanks!
