diff --git a/kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py b/kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py index dd33665..31b2e05 100644 --- a/kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py +++ b/kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py @@ -22,7 +22,10 @@ def col_major(pid, grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) - pid_m = (pid % grid_n) + # There is a bug. + # pid_m = (pid % grid_n) + # The result is correct, but the speedup is not as good as mentioned in the documentation + pid_m = (pid % grid_m) pid_n = pid // grid_m return pid_m, pid_n