# Pick the p/v dtype-harmonization strategy for the attention inner product
# once, at import time, based on the GPU architecture (SM major version).
#
# NOTE(review): the original code called torch.cuda.get_device_capability()
# unconditionally at module import, which raises on hosts with no visible
# CUDA device (CPU-only CI, CUDA_VISIBLE_DEVICES="") and made the whole
# module unimportable.  Guard the probe and fall back to the conservative
# pre-Ampere path; on CUDA machines behavior is unchanged.
if torch.cuda.is_available():
    _NV_CAP = torch.cuda.get_device_capability()
else:
    # (0, 0) selects the `v -> p.dtype` branch below.
    _NV_CAP = (0, 0)

if _NV_CAP[0] >= 8:

    @triton.jit
    def _convert_pv(p, v):
        """Cast `p` to `v`'s dtype before `tl.dot(p, v)`.

        Used on SM >= 8.0 (Ampere and newer) — presumably because the
        reduced-precision dot is supported/fast there; TODO confirm the
        exact hardware rationale against Triton's tl.dot dtype rules.
        """
        p = p.to(v.dtype)
        return p, v
else:

    @triton.jit
    def _convert_pv(p, v):
        """Cast `v` to `p`'s dtype before `tl.dot(p, v)`.

        Used on pre-Ampere GPUs (SM < 8.0), keeping the dot in `p`'s
        (accumulator) dtype instead of `v`'s storage dtype.
        """
        v = v.to(p.dtype)
        return p, v