Skip to content

Commit

Permalink
Add device-side support for int.bit_count (which just lowers to cuda.popc).
Browse files Browse the repository at this point in the history

Expand tests for cuda.popc to include smaller integer types.
  • Loading branch information
brycelelbach committed Oct 4, 2024
1 parent 9c31f59 commit c8690d8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
6 changes: 5 additions & 1 deletion numba_cuda/numba/cuda/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numba.core import cgutils
from numba.core.errors import RequireLiteralValue
from numba.core.typing import signature
from numba.core.extending import overload_attribute
from numba.core.extending import overload_attribute, overload_method
from numba.cuda import nvvmutils
from numba.cuda.extending import intrinsic

Expand Down Expand Up @@ -196,3 +196,7 @@ def syncthreads_or(typingctx, predicate):
'''
fname = 'llvm.nvvm.barrier0.or'
return _syncthreads_predicate(typingctx, predicate, fname)

@overload_method(types.Integer, 'bit_count', target='cuda')
def integer_bit_count(i):
    """Overload ``bit_count`` on integer types for the CUDA target.

    The returned implementation simply forwards its argument to
    ``cuda.popc``.
    """
    # NOTE(review): Python's ``int.bit_count`` counts the bits of the
    # absolute value; confirm ``cuda.popc`` agrees for negative signed inputs.
    def bit_count_impl(i):
        return cuda.popc(i)

    return bit_count_impl
48 changes: 44 additions & 4 deletions numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def simple_popc(ary, c):
ary[0] = cuda.popc(c)


def simple_bit_count(ary, c):
ary[0] = c.bit_count()


def simple_fma(ary, a, b, c):
    """Store the result of ``cuda.fma(a, b, c)`` in the first slot of ``ary``."""
    fused = cuda.fma(a, b, c)
    ary[0] = fused

Expand Down Expand Up @@ -550,17 +554,53 @@ def foo(out):

self.assertTrue(np.all(arr))

def test_popc_u1(self):
    """Check that ``cuda.popc`` reports 8 set bits for ``uint8`` 0xFF."""
    compiled = cuda.jit("void(int32[:], uint8)")(simple_popc)
    # The output buffer dtype must match the ``int32[:]`` parameter declared
    # in the kernel signature; a narrower array would be written past its
    # element size.
    ary = np.zeros(1, dtype=np.int32)
    compiled[1, 1](ary, np.uint8(0xFF))
    self.assertEqual(ary[0], 8)

def test_popc_u2(self):
    """Check that ``cuda.popc`` reports 16 set bits for ``uint16`` 0xFFFF."""
    compiled = cuda.jit("void(int32[:], uint16)")(simple_popc)
    # The output buffer dtype must match the ``int32[:]`` parameter declared
    # in the kernel signature; a narrower array would be written past its
    # element size.
    ary = np.zeros(1, dtype=np.int32)
    compiled[1, 1](ary, np.uint16(0xFFFF))
    self.assertEqual(ary[0], 16)

def test_popc_u4(self):
    """Exercise ``cuda.popc`` on ``uint32`` inputs with 4 and 32 bits set."""
    kernel = cuda.jit("void(int32[:], uint32)")(simple_popc)
    out = np.zeros(1, dtype=np.int32)
    # Each case is (kernel input, expected number of set bits).
    cases = ((0xF0, 4), (np.uint32(0xFFFFFFFF), 32))
    for value, expected in cases:
        kernel[1, 1](out, value)
        self.assertEqual(out[0], expected)

def test_popc_u8(self):
    """Exercise ``cuda.popc`` on ``uint64`` inputs with 4 and 64 bits set."""
    kernel = cuda.jit("void(int32[:], uint64)")(simple_popc)
    out = np.zeros(1, dtype=np.int32)
    # Each case is (kernel input, expected number of set bits).
    cases = ((0xF00000000000, 4), (np.uint64(0xFFFFFFFFFFFFFFFF), 64))
    for value, expected in cases:
        kernel[1, 1](out, value)
        self.assertEqual(out[0], expected)

def test_bit_count_u1(self):
    """Check that ``bit_count`` reports 8 set bits for ``uint8`` 0xFF."""
    compiled = cuda.jit("void(int32[:], uint8)")(simple_bit_count)
    # The output buffer dtype must match the ``int32[:]`` parameter declared
    # in the kernel signature; a narrower array would be written past its
    # element size.
    ary = np.zeros(1, dtype=np.int32)
    compiled[1, 1](ary, np.uint8(0xFF))
    self.assertEqual(ary[0], 8)

def test_bit_count_u2(self):
    """Check that ``bit_count`` reports 16 set bits for ``uint16`` 0xFFFF."""
    compiled = cuda.jit("void(int32[:], uint16)")(simple_bit_count)
    # The output buffer dtype must match the ``int32[:]`` parameter declared
    # in the kernel signature; a narrower array would be written past its
    # element size.
    ary = np.zeros(1, dtype=np.int32)
    compiled[1, 1](ary, np.uint16(0xFFFF))
    self.assertEqual(ary[0], 16)

def test_bit_count_u4(self):
    """Exercise ``bit_count`` on a ``uint32`` input with all 32 bits set."""
    kernel = cuda.jit("void(int32[:], uint32)")(simple_bit_count)
    out = np.zeros(1, dtype=np.int32)
    all_ones = np.uint32(0xFFFFFFFF)
    kernel[1, 1](out, all_ones)
    self.assertEqual(out[0], 32)

def test_bit_count_u8(self):
    """Exercise ``bit_count`` on a ``uint64`` input with all 64 bits set."""
    kernel = cuda.jit("void(int32[:], uint64)")(simple_bit_count)
    out = np.zeros(1, dtype=np.int32)
    all_ones = np.uint64(0xFFFFFFFFFFFFFFFF)
    kernel[1, 1](out, all_ones)
    self.assertEqual(out[0], 64)

def test_fma_f4(self):
compiled = cuda.jit("void(f4[:], f4, f4, f4)")(simple_fma)
Expand Down

0 comments on commit c8690d8

Please sign in to comment.