Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sleepcoo committed Feb 16, 2025
1 parent 2aa081b commit a0d0d11
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
8 changes: 2 additions & 6 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,8 @@ def grouped_gemm_triton_kernel(

if group_k > 0 and group_n > 0:
a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
scale_b
+ (expert_id * bs_stride_0)
+ (n_range_start + offs_bsn) * bs_stride_1
)
offs_bsn = (n_range_start + offs_bn) // group_n
b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def _load_fp8_scale(

class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self):
super().__init__()
self.block_quant = False

def create_weights(
Expand Down

0 comments on commit a0d0d11

Please sign in to comment.