From dd44e7f5357e355276d26b1ec2898e057faf7d82 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 5 Mar 2024 09:55:16 +0800 Subject: [PATCH] fix multinomial sampling (#1239) Co-authored-by: grimoire --- lmdeploy/pytorch/kernels/multinomial_sampling.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/kernels/multinomial_sampling.py b/lmdeploy/pytorch/kernels/multinomial_sampling.py index 476787fb8..44ed75905 100644 --- a/lmdeploy/pytorch/kernels/multinomial_sampling.py +++ b/lmdeploy/pytorch/kernels/multinomial_sampling.py @@ -21,8 +21,8 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, offset = tl.load(Offsets + off, mask=off_mask).to(tl.int32) samp = tl.rand(seed, offset)[:, None] - acc = tl.zeros((BLOCK, ), dtype=Scores.dtype.element_ty) - output = tl.full((BLOCK, ), -1, dtype=Outputs.dtype.element_ty) + acc = tl.zeros((BLOCK, ), dtype=tl.float32) + output = tl.load(Indices + off * stride_ib, mask=off_mask) for b_idx in range(0, num_tokens, BLOCK_N): s_off = b_idx + n_off @@ -30,9 +30,9 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, scores = tl.load(Scores + off[:, None] * stride_sb + s_off[None, :] * stride_st, mask=s_mask, - other=0.0) - cum_scores = acc[:, None] + tl.cumsum(scores, 1).to(acc.dtype) - acc += tl.sum(scores, 1).to(acc.dtype) + other=0.0).to(acc.dtype) + cum_scores = acc[:, None] + tl.cumsum(scores, 1) + acc += tl.sum(scores, 1) pre_cum_scores = cum_scores - scores valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores) @@ -75,7 +75,7 @@ def __kernel_meta(): assert indices.dim() == 2 assert indices.size() == scores.size() - outputs = indices.new_empty(batch_size) + outputs = indices[:, 0].clone() BLOCK = 32 BLOCK_N = 64 @@ -96,5 +96,5 @@ def __kernel_meta(): BLOCK=BLOCK, BLOCK_N=BLOCK_N, **kernel_meta) - torch.cuda.synchronize() + return outputs