Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Mar 2, 2024
1 parent f81404a commit 6e2c618
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/kernels/multinomial_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,

samp = tl.rand(seed, offset)[:, None]
acc = tl.zeros((BLOCK, ), dtype=Scores.dtype.element_ty)
output = tl.full((BLOCK, ), -1, dtype=tl.int64)
output = tl.full((BLOCK, ), -1, dtype=Outputs.dtype.element_ty)

for b_idx in range(0, num_tokens, BLOCK_N):
s_off = b_idx + n_off
Expand All @@ -31,8 +31,8 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,
s_off[None, :] * stride_st,
mask=s_mask,
other=0.0)
cum_scores = acc[:, None] + tl.cumsum(scores, 1)
acc += tl.sum(scores, 1)
cum_scores = acc[:, None] + tl.cumsum(scores, 1).to(acc.dtype)
acc += tl.sum(scores, 1).to(acc.dtype)

pre_cum_scores = cum_scores - scores
valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
Expand Down
9 changes: 8 additions & 1 deletion tests/pytorch/kernel/test_multinomial_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ def batch_size(self, select_ids):
yield len(select_ids)

@pytest.fixture
def scores(self, num_tokens, batch_size, select_ids):
def dtype(self, request):
yield request.param

@pytest.fixture
def scores(self, num_tokens, batch_size, select_ids, dtype):
ret = torch.zeros(batch_size, num_tokens).cuda()
batch_ids = torch.arange(batch_size).cuda()
ret[batch_ids, select_ids] = 1
ret = ret.to(dtype)
yield ret

@pytest.fixture
Expand All @@ -45,6 +50,8 @@ def gt(self, batch_size, select_ids, indices):
batch_ids = torch.arange(batch_size).cuda()
yield indices[batch_ids, select_ids]

@pytest.mark.parametrize('dtype',
[torch.float32, torch.half, torch.bfloat16])
@pytest.mark.parametrize(['num_tokens', 'select_ids'], [
(8, (4, 2) * 30),
(200, (50, 150)),
Expand Down

0 comments on commit 6e2c618

Please sign in to comment.