diff --git a/numba_cuda/numba/cuda/tests/nrt/mock_numpy.py b/numba_cuda/numba/cuda/tests/nrt/mock_numpy.py index bc3dfba..a59df57 100644 --- a/numba_cuda/numba/cuda/tests/nrt/mock_numpy.py +++ b/numba_cuda/numba/cuda/tests/nrt/mock_numpy.py @@ -1,9 +1,8 @@ -import numpy as np - from numba.core import errors, types from numba.core.extending import overload from numba.np.arrayobj import (_check_const_str_dtype, is_nonelike, - ty_parse_dtype, ty_parse_shape, numpy_empty_nd) + ty_parse_dtype, ty_parse_shape, numpy_empty_nd, + numpy_empty_like_nd) # Typical tests for allocation use array construction (e.g. np.zeros, np.empty, @@ -48,12 +47,20 @@ def impl(shape, dtype): @overload(cuda_empty_like) -def ol_cuda_empty_like(a, dtype=None): - _check_const_str_dtype("zeros_like", dtype) - - # NumPy uses 'a' as the arg name for the array-like - def impl(a, dtype=None): - arr = np.empty_like(a, dtype=dtype) - arr._zero_fill() - return arr +def ol_cuda_empty_like(arr): + + if isinstance(arr, types.Array): + nb_dtype = arr.dtype + else: + nb_dtype = arr + + if isinstance(arr, types.Array): + layout = arr.layout if arr.layout != 'A' else 'C' + retty = arr.copy(dtype=nb_dtype, layout=layout, readonly=False) + else: + retty = types.Array(nb_dtype, 0, 'C') + + def impl(arr): + dtype = None + return numpy_empty_like_nd(arr, dtype, retty) return impl