Skip to content

Commit

Permalink
copy from numba PR #8458
Browse files Browse the repository at this point in the history
  • Loading branch information
dlee992 authored and gmarkall committed Oct 7, 2024
1 parent 9ed01c5 commit 7928f2d
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 9 deletions.
38 changes: 29 additions & 9 deletions numba_cuda/numba/cuda/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
'hrcp', 'hrint',
'htrunc', 'hdiv']

reshape_funcs = ['nocopy_empty_reshape', 'numba_attempt_nocopy_reshape']


class _Kernel(serialize.ReduceMixin):
'''
Expand Down Expand Up @@ -105,15 +107,33 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
if self.cooperative:
lib.needs_cudadevrt = True

res = [fn for fn in cuda_fp16_math_funcs
if (f'__numba_wrapper_{fn}' in lib.get_asm_str())]

if res:
# Path to the source containing the foreign function
basedir = os.path.dirname(os.path.abspath(__file__))
functions_cu_path = os.path.join(basedir,
'cpp_function_wrappers.cu')
link.append(functions_cu_path)
def link_to_library_functions(library_functions, library_path,
prefix=None):
"""
Dynamically links to library functions by searching for their names
in the specified library and linking to the corresponding source
file.
"""
if prefix is not None:
library_functions = [f"{prefix}{fn}" for fn in
library_functions]

found_functions = [fn for fn in library_functions
if f'{fn}' in lib.get_asm_str()]

if found_functions:
basedir = os.path.dirname(os.path.abspath(__file__))
source_file_path = os.path.join(basedir, library_path)
link.append(source_file_path)

return found_functions

# Link to the helper library functions if needed
link_to_library_functions(reshape_funcs, 'reshape_funcs.cu')
# Link to the CUDA FP16 math library functions if needed
link_to_library_functions(cuda_fp16_math_funcs,
'cpp_function_wrappers.cu',
'__numba_wrapper_')

for filepath in link:
lib.add_linking_file(filepath)
Expand Down
151 changes: 151 additions & 0 deletions numba_cuda/numba/cuda/reshape_funcs.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Handle reshaping of zero-sized array.
* See numba_attempt_nocopy_reshape() below.
*/
#define NPY_MAXDIMS 32

typedef long int npy_intp;

extern "C" __device__ int
nocopy_empty_reshape(npy_intp nd, const npy_intp *dims, const npy_intp *strides,
npy_intp newnd, const npy_intp *newdims,
npy_intp *newstrides, npy_intp itemsize,
int is_f_order)
{
int i;
/* Just make the strides vaguely reasonable
* (they can have any value in theory).
*/
for (i = 0; i < newnd; i++)
newstrides[i] = itemsize;
return 1; /* reshape successful */
}

/*
* Straight from Numpy's _attempt_nocopy_reshape()
* (np/core/src/multiarray/shape.c).
* Attempt to reshape an array without copying data
*
* This function should correctly handle all reshapes, including
* axes of length 1. Zero strides should work but are untested.
*
* If a copy is needed, returns 0
* If no copy is needed, returns 1 and fills `npy_intp *newstrides`
* with appropriate strides
*/
extern "C" __device__ int
numba_attempt_nocopy_reshape(npy_intp nd, const npy_intp *dims, const npy_intp *strides,
npy_intp newnd, const npy_intp *newdims,
npy_intp *newstrides, npy_intp itemsize,
int is_f_order)
{
int oldnd;
npy_intp olddims[NPY_MAXDIMS];
npy_intp oldstrides[NPY_MAXDIMS];
npy_intp np, op, last_stride;
int oi, oj, ok, ni, nj, nk;

oldnd = 0;
/*
* Remove axes with dimension 1 from the old array. They have no effect
* but would need special cases since their strides do not matter.
*/
for (oi = 0; oi < nd; oi++) {
if (dims[oi]!= 1) {
olddims[oldnd] = dims[oi];
oldstrides[oldnd] = strides[oi];
oldnd++;
}
}

np = 1;
for (ni = 0; ni < newnd; ni++) {
np *= newdims[ni];
}
op = 1;
for (oi = 0; oi < oldnd; oi++) {
op *= olddims[oi];
}
if (np != op) {
/* different total sizes; no hope */
return 0;
}

if (np == 0) {
/* the Numpy code does not handle 0-sized arrays */
return nocopy_empty_reshape(nd, dims, strides,
newnd, newdims, newstrides,
itemsize, is_f_order);
}

/* oi to oj and ni to nj give the axis ranges currently worked with */
oi = 0;
oj = 1;
ni = 0;
nj = 1;
while (ni < newnd && oi < oldnd) {
np = newdims[ni];
op = olddims[oi];

while (np != op) {
if (np < op) {
/* Misses trailing 1s, these are handled later */
np *= newdims[nj++];
} else {
op *= olddims[oj++];
}
}

/* Check whether the original axes can be combined */
for (ok = oi; ok < oj - 1; ok++) {
if (is_f_order) {
if (oldstrides[ok+1] != olddims[ok]*oldstrides[ok]) {
/* not contiguous enough */
return 0;
}
}
else {
/* C order */
if (oldstrides[ok] != olddims[ok+1]*oldstrides[ok+1]) {
/* not contiguous enough */
return 0;
}
}
}

/* Calculate new strides for all axes currently worked with */
if (is_f_order) {
newstrides[ni] = oldstrides[oi];
for (nk = ni + 1; nk < nj; nk++) {
newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1];
}
}
else {
/* C order */
newstrides[nj - 1] = oldstrides[oj - 1];
for (nk = nj - 1; nk > ni; nk--) {
newstrides[nk - 1] = newstrides[nk]*newdims[nk];
}
}
ni = nj++;
oi = oj++;
}

/*
* Set strides corresponding to trailing 1s of the new shape.
*/
if (ni >= 1) {
last_stride = newstrides[ni - 1];
}
else {
last_stride = itemsize;
}
if (is_f_order) {
last_stride *= newdims[ni - 1];
}
for (nk = ni; nk < newnd; nk++) {
newstrides[nk] = last_stride;
}

return 1;
}
74 changes: 74 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@
from unittest.mock import call, patch


def array_reshape1d(arr, newshape, got):
y = arr.reshape(newshape)
for i in range(y.shape[0]):
got[i] = y[i]


def array_reshape2d(arr, newshape, got):
y = arr.reshape(newshape)
for i in range(y.shape[0]):
for j in range(y.shape[1]):
got[i, j] = y[i, j]


def array_reshape3d(arr, newshape, got):
y = arr.reshape(newshape)
for i in range(y.shape[0]):
for j in range(y.shape[1]):
for k in range(y.shape[2]):
got[i, j, k] = y[i, j, k]


def array_reshape(arr, newshape):
return arr.reshape(newshape)


@skip_on_cudasim('CUDA Array Interface is not supported in the simulator')
class TestCudaArrayInterface(ContextResettingTestCase):
def assertPointersEqual(self, a, b):
Expand Down Expand Up @@ -430,6 +455,55 @@ def f(x, y):
# Ensure that synchronize was not called
mock_sync.assert_not_called()

# @skip_unless_cuda_python('NVIDIA Binding needed for NVRTC')
def test_array_reshape(self):
def check(pyfunc, kernelfunc, arr, shape):
kernel = cuda.jit(kernelfunc)
expected = pyfunc(arr, shape)
got = np.zeros(expected.shape, dtype=arr.dtype)
kernel[1, 1](arr, shape, got)
self.assertPreciseEqual(got, expected)

def check_only_shape(kernelfunc, arr, shape, expected_shape):
kernel = cuda.jit(kernelfunc)
got = np.zeros(expected_shape, dtype=arr.dtype)
kernel[1, 1](arr, shape, got)
self.assertEqual(got.shape, expected_shape)
self.assertEqual(got.size, arr.size)

# 0-sized arrays
def check_empty(arr):
check(array_reshape, array_reshape1d, arr, 0)
check(array_reshape, array_reshape1d, arr, (0,))
check(array_reshape, array_reshape3d, arr, (1, 0, 2))
check_only_shape(array_reshape2d, arr, (0, -1), (0, 0))
check_only_shape(array_reshape2d, arr, (4, -1), (4, 0))
check_only_shape(array_reshape3d, arr, (-1, 0, 4), (0, 0, 4))

# C-contiguous
arr = np.arange(24)
check(array_reshape, array_reshape1d, arr, (24,))
check(array_reshape, array_reshape2d, arr, (4, 6))
check(array_reshape, array_reshape2d, arr, (8, 3))
check(array_reshape, array_reshape3d, arr, (8, 1, 3))

arr = np.arange(24).reshape((1, 8, 1, 1, 3, 1))
check(array_reshape, array_reshape1d, arr, (24,))
check(array_reshape, array_reshape2d, arr, (4, 6))
check(array_reshape, array_reshape2d, arr, (8, 3))
check(array_reshape, array_reshape3d, arr, (8, 1, 3))

# Test negative shape value
arr = np.arange(25).reshape(5,5)
check(array_reshape, array_reshape1d, arr, -1)
check(array_reshape, array_reshape1d, arr, (-1,))
check(array_reshape, array_reshape2d, arr, (-1, 5))
check(array_reshape, array_reshape3d, arr, (5, -1, 5))
check(array_reshape, array_reshape3d, arr, (5, 5, -1))

arr = np.array([])
check_empty(arr)


if __name__ == "__main__":
unittest.main()

0 comments on commit 7928f2d

Please sign in to comment.