
[QST] How are DRAM accesses optimized during Split K - reduction across threadblocks? #1406

Closed
Rya-Sanovar opened this issue Mar 16, 2024 · 18 comments

Comments

@Rya-Sanovar

Assuming partition=16 in a PartitionedK GEMM with split-K reduction at the threadblock level:
All 16 tiles belonging to the same "row" of block tiles of the A matrix run on different SMs of the GPU. During the reduction phase, some of these tiles need to be written to global memory so they can be loaded by the consumer tiles that perform the reduction.
Wouldn't this increase latency due to the DRAM bandwidth bottleneck, since we now have extra gmem accesses compared to the un-partitioned case, even though we've achieved higher occupancy? How does CUTLASS optimize this?

@thakkarV
Collaborator

thakkarV commented Mar 16, 2024

Same as before - we optimize the CTA rasterization and work tile mapping such that they hit in L2 as much as possible. This applies to the partial reduction as well.
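To make the rasterization point concrete, here is a minimal, hypothetical sketch of a grouped CTA-to-tile mapping (made-up function, not CUTLASS's actual ThreadblockSwizzle): CTAs that are scheduled close together in time get neighboring output tiles, so the A and B panels they load are likely still resident in L2.

```cuda
// Hypothetical, simplified CTA-to-output-tile mapping (not CUTLASS's real
// swizzle). Consecutive CTA ids visit output tiles in small groups so that
// they reuse the same A row-panels and B column-panels while those panels
// are still hot in L2.
#include <cstdio>

struct TileCoord { int m, n; };

TileCoord swizzled_tile(int cta_id, int tiles_m, int tiles_n, int group_m = 8) {
  int tiles_per_group = group_m * tiles_n;
  int group           = cta_id / tiles_per_group;
  int first_m         = group * group_m;
  int rows_in_group   = (tiles_m - first_m < group_m) ? tiles_m - first_m : group_m;
  int local           = cta_id % tiles_per_group;
  // Walk column-major inside the group: consecutive CTAs share B column-panels,
  // and the group is narrow enough that the A row-panels also stay in L2.
  return { first_m + local % rows_in_group, local / rows_in_group };
}

int main() {
  for (int id = 0; id < 16 * 16; ++id) {   // visit order for a 16x16 tile grid
    TileCoord t = swizzled_tile(id, 16, 16);
    std::printf("cta %3d -> tile (%2d, %2d)\n", id, t.m, t.n);
  }
}
```

The same idea applies to the CTAs that produce split-K partials: if the slices of one output tile are rasterized near each other, the partials they exchange have a better chance of being served from L2 rather than DRAM.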

@Rya-Sanovar
Author

Okay, doing this maximizes GPU occupancy, but what about latency? Isn't there a tradeoff here? Or, by optimizing the L2 hit rate, does split-K reduction end up with almost the same inference latency as without it?

@thakkarV
Collaborator

I'm not sure what you mean by this being a tradeoff between occupancy and latency. Latency of what? Individual gmem accesses or the time to completion of the kernel?

Split K increases utilization. Therefore the wallclock time of the kernel goes down. Regardless of gmem latencies for individual tiles, this implies that the end-to-end latency of the kernel decreases.

@Rya-Sanovar
Author

By latency I mean the total time it takes to compute the C matrix if we use split-K reduction, so yes, time to completion of the kernel.

So, if the wallclock time goes down, that means the additional parallelism that split-K achieves compensates for the extra gmem accesses, right? Does this hold for all matrix sizes M, N, K? As in, can there be cases where using split-K actually slows down the time it takes to compute C?

@thakkarV
Collaborator

Yes. Split K nets you speedups when MN are small and K is large. If the outer dims are big then using split K will worsen perf, especially if the contraction dim is small.
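A quick back-of-the-envelope example (made-up sizes and a 128×128 CTA tile, purely for illustration) of why small M, N with large K favors split-K:

```cuda
// Back-of-the-envelope CTA counts with and without split-K.
// Sizes and tile shape are assumptions for illustration only.
#include <cstdio>

int main() {
  int M = 128, N = 128, K = 65536;      // tiny output, huge contraction dim
  int tile_m = 128, tile_n = 128;
  int split_k_slices = 16;

  int tiles_m = (M + tile_m - 1) / tile_m;
  int tiles_n = (N + tile_n - 1) / tile_n;

  int ctas_dp      = tiles_m * tiles_n;          // data-parallel: 1 CTA
  int ctas_split_k = ctas_dp * split_k_slices;   // split-K: 16 CTAs
  int k_per_slice  = K / split_k_slices;         // 4096 per slice

  std::printf("data-parallel: %d CTA(s), split-K: %d CTA(s), K per slice: %d\n",
              ctas_dp, ctas_split_k, k_per_slice);
  // One CTA leaves nearly the whole GPU idle; 16 CTAs recover utilization at
  // the cost of writing and re-reading 16 partial tiles through the workspace.
}
```

Flip the shape around (large M and N, small K) and the data-parallel launch already fills the machine, so the extra workspace traffic from split-K is pure overhead.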

@Rya-Sanovar
Author

  1. Can we say that the speedup with split K would be almost (#tiles along the K mode)x the case without split K? Assuming the total number of block tiles doesn't exceed the number of SMs on the GPU.
  2. And how would we know if K is "large" enough to use split K? Does it only have to be relatively bigger than M and N, or is there a certain value of K beyond which split K actually gives a speedup?

@thakkarV
Collaborator

  1. Yes.
  2. It depends on a lot of factors, such as the exact problem size, architecture, etc. You can use the profiler to decide this splitting factor. Mind you, split K is not really a great load-balancing strategy in general.

@Rya-Sanovar
Author

Rya-Sanovar commented Mar 18, 2024

> mind you, split K is not really a great load balancing strategy in general

  1. Could you elaborate on why this is?
  2. How exactly is the split-K reduction executed? Is it something similar to, say, how intra-warp reduction is done by shfl_down_sync()? Attaching an image for reference:
    [image: diagram of an intra-warp shuffle-down reduction]

@thakkarV
Collaborator

  1. I recommend you read the Stream-K paper (linked in our README) to understand the intricacies of load balancing.
  2. There are two ways to do the reduction. The first is what you show: parallel reductions, which need to use atomics. They are generally not tree based, since CTA scheduling can be dynamic. The second method is serial reduction with semaphores (see the sketch below).
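Here is a rough CUDA sketch of the serial-with-semaphore flavor, with made-up names and a heavily simplified structure (not CUTLASS's actual implementation): every CTA that owns a K-slice of the same output tile waits on a per-tile counter for its turn, accumulates its partial into global memory, and then releases the next slice. A real kernel would do this in the epilogue, after the mainloop has finished computing the partial.

```cuda
// Hypothetical sketch of serial split-K reduction ordered by a per-tile
// semaphore. Names, layout, and structure are simplified for illustration.
#include <cuda_runtime.h>

__global__ void serial_split_k_epilogue(
    float*       C_tile,      // start of this output tile in global memory
    const float* partial,     // this CTA's partial tile (already computed)
    int*         semaphores,  // one counter per output tile, initialized to 0
    int tile_id, int slice_id, int tile_elems)
{
  // ... the GEMM mainloop for slice `slice_id` would run before this point ...

  // One thread spins until it is this slice's turn for the output tile.
  if (threadIdx.x == 0) {
    while (atomicAdd(&semaphores[tile_id], 0) != slice_id) { /* spin */ }
  }
  __syncthreads();

  // Slice 0 initializes the tile; later slices accumulate on top of it.
  for (int i = threadIdx.x; i < tile_elems; i += blockDim.x) {
    float v = partial[i];
    C_tile[i] = (slice_id == 0) ? v : C_tile[i] + v;
  }

  __threadfence();   // make the stores visible before releasing the semaphore
  __syncthreads();
  if (threadIdx.x == 0) {
    atomicExch(&semaphores[tile_id], slice_id + 1);  // hand the turn to slice+1
  }
}
```

The parallel flavor instead lets every slice write its partial with no ordering between slices, and either atomics or a separate reduction pass perform the final combine.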

@Rya-Sanovar
Author

Thanks! I can't find it in the README, can you link it here please?

@thakkarV
Collaborator

https://arxiv.org/abs/2301.03598

@Rya-Sanovar
Author

Rya-Sanovar commented Mar 20, 2024

Thanks. Also, why is split-K's work split into two kernels, GemmKernel and ReductionKernel? Wouldn't launching it as one fused kernel be less costly?

@hwu36
Collaborator

hwu36 commented Mar 20, 2024

Parallel split-K uses a separate reduction kernel; serial split-K fuses the reduction inside the GEMM kernel. The former works better when the split slices are big.

@Rya-Sanovar
Author

I see. I have a few questions on how parallel split-K works:

  1. Before the reduction kernel is launched, are all the partial products present in the smem of the SMs? Or are they present in the workspace (which resides in DRAM, if I'm not wrong)?
  2. Why are the CTA shapes in GemmKernel and ReductionKernel different for split-K? And why can't it be one fused kernel instead of launching two separate ones?
  3. How exactly does parallel split-K reduction work if kPartitionsPerStage=4 and split_k_slices=16, for example?

@hwu36
Collaborator

hwu36 commented Mar 26, 2024

Partial products are in global memory (the workspace) no matter whether it is serial or parallel split-K. Serial split-K does not need a separate reduction kernel, but parallel split-K does.

If split_k_slices=16, there will be 16 partitions. kPartitionsPerStage=4 means we reduce 4 partitions per stage.
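Putting that together, here is a hypothetical sketch of what the separate reduction pass could look like for split_k_slices = 16 and kPartitionsPerStage = 4 (made-up kernel and names, not CUTLASS's actual ReductionKernel): the GEMM kernel has already written 16 partial copies of the output into a gmem workspace of roughly split_k_slices * M * N * sizeof(accumulator) bytes, and the reduction walks them 4 at a time.

```cuda
// Hypothetical sketch of the parallel split-K reduction pass. The GEMM kernel
// has already written `split_k_slices` partial outputs into a workspace laid
// out as [slice][M * N]; this kernel folds them into D, kPartitionsPerStage
// partials per stage (4 stages when split_k_slices = 16).
#include <cuda_runtime.h>

constexpr int kPartitionsPerStage = 4;

__global__ void reduce_split_k(
    float*       D,           // final M x N output
    const float* workspace,   // split_k_slices partial outputs, each M * N
    int          split_k_slices,
    long long    elems)       // M * N
{
  long long idx = blockIdx.x * (long long)blockDim.x + threadIdx.x;
  if (idx >= elems) return;

  float acc = 0.0f;
  for (int base = 0; base < split_k_slices; base += kPartitionsPerStage) {
    // Accumulate one stage of (up to) kPartitionsPerStage partials.
    for (int p = 0; p < kPartitionsPerStage && base + p < split_k_slices; ++p) {
      acc += workspace[(long long)(base + p) * elems + idx];
    }
  }
  D[idx] = acc;   // a real epilogue (alpha/beta scaling, etc.) would go here
}
```

Because this pass is just an elementwise sum over M*N positions, its launch shape has nothing to do with the GEMM's CTA tile, which is also why the two kernels can use different CTA shapes. The workspace traffic it consumes is exactly the extra DRAM traffic asked about at the top of the thread.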

@Rya-Sanovar
Author

@thakkarV In the streamK paper, they've compared “two-tile Stream-K + data-parallel" against data parallel CUTLASS, but I was wondering what the performance difference (in terms of both latency and occupancy) is for:

  1. “two-tile Stream-K + data-parallel" vs basic streamK
  2. “two-tile Stream-K + data-parallel" vs parallel split-K CUTLASS
  3. Basic streamK vs parallel split-K CUTLASS.

In short, how much do the suboptimal cache hits in basic Stream-K affect performance, and how does Stream-K compare to split-K?

I understand this must be contingent on problem sizes and hardware, but what's the general consensus?

@cloudhan

I don't think there will be a "general consensus" on which one is better. The variants might just be trial-and-error results. What they don't mention (explicitly) in the paper is that they are moving from "data parallel" to "task parallel", and the "tasking" causes some kind of contention that they want to amortize.

So the covert plot behind the scenes might be (educated guess):

  1. GPUs are getting more powerful and the number of SMs keeps going up, but the problem sizes do not change because of old algorithms.
    • Thus the "quantization inefficiency" in the tail part of a DP GEMM, or the reduction overhead of a split-K GEMM, is getting larger and larger.
  2. If m,n are fixed, the compute and bandwidth requirements are basically O(k). So they decide to assign as large a k as possible (instead of a fixed one) and assign the task to specific CTAs.
    • This means fewer CTAs being launched, at best O(num_SMs).
    • This means more of the reduction along the k-axis can happen in registers or shared memory, amortizing the split-K reduction overhead.
  3. They still need to load-balance the SMs (the hardware resource).
    • The k is a little bit larger to achieve the best occupancy.
    • Sometimes a "task" needs a split along k (somewhere in between), thus a global reduction.
  4. Those split parts again cause reduction overhead.
    • They are doing everything in a single kernel, so in-place reduction relies on some kind of atomics / memory consistency and coherency.
  5. They find the hardware guys are lazy xD
    • The L2 cache is in (two) partitions in their cash cows (A100, H100, maybe B?00).
    • For different combinations of problem size and hardware, naive Stream-K might perform worse due to L2 contention.
      • Some problems require excessive global reduction, say when M×N is small (basically degenerating to split-K?).
  6. The partitioned L2 is irritating.
    • DP + one-tile SK and two-tile SK + DP are developed to cover some of these cases.
      • They might just want to distribute the synchronization points across the program execution trace to amortize the latency.

I believe the partitioned far+near L2 cache plays an important role here. They don't mention it in the paper because the GPU in the paper is "hypothetical".

