diff --git a/lmdeploy/pytorch/kernels/multinomial_sampling.py b/lmdeploy/pytorch/kernels/multinomial_sampling.py index 9ad29c773..476787fb8 100644 --- a/lmdeploy/pytorch/kernels/multinomial_sampling.py +++ b/lmdeploy/pytorch/kernels/multinomial_sampling.py @@ -22,7 +22,7 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, samp = tl.rand(seed, offset)[:, None] acc = tl.zeros((BLOCK, ), dtype=Scores.dtype.element_ty) - output = tl.full((BLOCK, ), -1, dtype=tl.int64) + output = tl.full((BLOCK, ), -1, dtype=Outputs.dtype.element_ty) for b_idx in range(0, num_tokens, BLOCK_N): s_off = b_idx + n_off @@ -31,8 +31,8 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, s_off[None, :] * stride_st, mask=s_mask, other=0.0) - cum_scores = acc[:, None] + tl.cumsum(scores, 1) - acc += tl.sum(scores, 1) + cum_scores = acc[:, None] + tl.cumsum(scores, 1).to(acc.dtype) + acc += tl.sum(scores, 1).to(acc.dtype) pre_cum_scores = cum_scores - scores valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores) diff --git a/tests/pytorch/kernel/test_multinomial_sampling.py b/tests/pytorch/kernel/test_multinomial_sampling.py index 4512a3487..f0f594dde 100644 --- a/tests/pytorch/kernel/test_multinomial_sampling.py +++ b/tests/pytorch/kernel/test_multinomial_sampling.py @@ -19,10 +19,15 @@ def batch_size(self, select_ids): yield len(select_ids) @pytest.fixture - def scores(self, num_tokens, batch_size, select_ids): + def dtype(self, request): + yield request.param + + @pytest.fixture + def scores(self, num_tokens, batch_size, select_ids, dtype): ret = torch.zeros(batch_size, num_tokens).cuda() batch_ids = torch.arange(batch_size).cuda() ret[batch_ids, select_ids] = 1 + ret = ret.to(dtype) yield ret @pytest.fixture @@ -45,6 +50,8 @@ def gt(self, batch_size, select_ids, indices): batch_ids = torch.arange(batch_size).cuda() yield indices[batch_ids, select_ids] + @pytest.mark.parametrize('dtype', + [torch.float32, torch.half, torch.bfloat16]) @pytest.mark.parametrize(['num_tokens', 'select_ids'], [ (8, (4, 2) * 30), (200, (50, 150)),