Skip to content

Commit

Permalink
fix multinomial sampling (#1239)
Browse files Browse the repository at this point in the history
Co-authored-by: grimoire <[email protected]>
  • Loading branch information
grimoire and grimoire authored Mar 5, 2024
1 parent 7dd97fd commit dd44e7f
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions lmdeploy/pytorch/kernels/multinomial_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,
offset = tl.load(Offsets + off, mask=off_mask).to(tl.int32)

samp = tl.rand(seed, offset)[:, None]
acc = tl.zeros((BLOCK, ), dtype=Scores.dtype.element_ty)
output = tl.full((BLOCK, ), -1, dtype=Outputs.dtype.element_ty)
acc = tl.zeros((BLOCK, ), dtype=tl.float32)
output = tl.load(Indices + off * stride_ib, mask=off_mask)

for b_idx in range(0, num_tokens, BLOCK_N):
s_off = b_idx + n_off
s_mask = off_mask[:, None] & (s_off[None, :] < num_tokens)
scores = tl.load(Scores + off[:, None] * stride_sb +
s_off[None, :] * stride_st,
mask=s_mask,
other=0.0)
cum_scores = acc[:, None] + tl.cumsum(scores, 1).to(acc.dtype)
acc += tl.sum(scores, 1).to(acc.dtype)
other=0.0).to(acc.dtype)
cum_scores = acc[:, None] + tl.cumsum(scores, 1)
acc += tl.sum(scores, 1)

pre_cum_scores = cum_scores - scores
valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
Expand Down Expand Up @@ -75,7 +75,7 @@ def __kernel_meta():
assert indices.dim() == 2
assert indices.size() == scores.size()

outputs = indices.new_empty(batch_size)
outputs = indices[:, 0].clone()

BLOCK = 32
BLOCK_N = 64
Expand All @@ -96,5 +96,5 @@ def __kernel_meta():
BLOCK=BLOCK,
BLOCK_N=BLOCK_N,
**kernel_meta)
torch.cuda.synchronize()

return outputs

0 comments on commit dd44e7f

Please sign in to comment.