[BUG] TMA Cooperative GeMM with Stream-K scheduler hangs #1917
Comments
Hi, @NihalPotdar. Can you please provide a reproducer for this bug?
@jackkosaian Set the scheduler in examples/55_hopper_mixed_dtype_gemm.cu to
Please include your build flags and full steps to reproduce, starting from a checkout of the repo. We often find that users do not use the flags generated by our build system. Please also provide your CUDA toolkit version.
Thanks for the detailed steps. It looks like that example is not calling … This is required for stream-K in order to initialize the counters used to coordinate inter-CTA reduction. If these are not properly initialized, stream-K is likely to hang. Can you please try changing the loop linked above to the following?

    for (int iter = 0; iter < options.iterations; ++iter) {
      // Added: re-initialize the workspace (including the stream-K counters) before every run
      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
      CUTLASS_CHECK(gemm.run());
    }
@jackkosaian That works, thank you. However, for smaller problem shapes like this one, where the atomic-write overhead becomes significant, using a separate reduction wave would make a lot of sense. I noticed that this is currently turned off in the CUTLASS implementation: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/tile_scheduler_params.h#L1051. Are there any plans to fix this, and what is the ETA? Do you have any suggested workarounds in the meantime?
The current plan is to improve reduction performance in CUTLASS 3.7.
@jackkosaian Sounds good. I am also seeing correctness issues with the existing implementation when compared against fp8, like in
How are the tensors initialized: with random floating-point values, or with random integer values within some tight range? How is error checking performed: exact match or relative tolerance? Since stream-K splits a GEMM along the K mode, it can accumulate results in a different order than a non-stream-K GEMM. Because floating-point addition is not associative, the results can differ.
@jackkosaian Sorry for the delayed reply. This is directly based on the … For this specific case, this is testing matrix B as … The source of truth for the verification is an fp8*fp8 matmul with the same inputs. The error checking is performed with … I understand that there is non-determinism coming from the hardware, but I think these bounds are very lenient, and I would expect this verification to succeed. Otherwise, this makes it very hard to use in practical applications that need to preserve quality, such as quantized models.
Thanks for the details. Can you please post a full reproducer so that I can make sure I'm testing the same thing as you?
Describe the bug
GEMM kernels with the following configuration hang for specific GEMM shapes.
Type: uint4_t * half_t
Tile: m=16,n=2560,k=8192
Cluster: 1x1x1
Kernel Schedule: KernelTmaWarpSpecializedCooperative
Epilogue Schedule: TmaWarpSpecializedCooperative
Tile Scheduler: Stream-K
Expected behavior
With CUTLASS 3.x, this kernel simply hangs when called, with no other changes. This is not expected.
Based on #1801, this should have been resolved.