diff --git a/numba_cuda/numba/cuda/stubs.py b/numba_cuda/numba/cuda/stubs.py index fa88985..fb6979a 100644 --- a/numba_cuda/numba/cuda/stubs.py +++ b/numba_cuda/numba/cuda/stubs.py @@ -1,5 +1,5 @@ """ -This script specifies all PTX special objects. +This scripts specifies all PTX special objects. """ import numpy as np from collections import defaultdict @@ -7,6 +7,7 @@ import itertools from inspect import Signature, Parameter + class Stub(object): ''' A stub object to represent special objects that are meaningless @@ -53,6 +54,7 @@ def y(self): def z(self): pass + class threadIdx(Dim3): ''' The thread indices in the current thread block. Each index is an integer diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_print.py b/numba_cuda/numba/cuda/tests/cudapy/test_print.py index 6b3ebac..93492a8 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_print.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_print.py @@ -1,4 +1,5 @@ from numba.cuda.testing import CUDATestCase, skip_on_cudasim +import numpy as np import subprocess import sys import unittest @@ -121,8 +122,8 @@ def test_string(self): def test_dim3(self): output, _ = self.run_code(printdim3_usecase) lines = [line.strip() for line in output.splitlines(True)] - expected = [str((k, j, i)) for i in range(2) for j in range(2) for k in range(2)] - self.assertEqual(sorted(lines), sorted(expected)) + expected = [str(i) for i in np.ndindex(2, 2, 2)] + self.assertEqual(sorted(lines), expected) @skip_on_cudasim('cudasim can print unlimited output') def test_too_many_args(self):