diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py
index 5a1e7c4706..95d33e1a0e 100644
--- a/src/gt4py/storage/cartesian/utils.py
+++ b/src/gt4py/storage/cartesian/utils.py
@@ -55,12 +55,42 @@
             device_type=core_defs.DeviceType.CUDA,
             array_utils=allocators.cupy_array_utils,
         )
-    else:
+    elif CUPY_DEVICE == core_defs.DeviceType.ROCM:
         _GPUBufferAllocator = allocators.NDArrayBufferAllocator(
             device_type=core_defs.DeviceType.ROCM,
             array_utils=allocators.cupy_array_utils,
         )
 
+        class ROCmNDArray(cp.ndarray):
+            """CuPy array view that defines ``__cuda_array_interface__`` itself.
+
+            Only defined when ``CUPY_DEVICE`` is ROCm.
+            NOTE(review): presumably needed because CuPy does not expose this
+            interface on ROCm builds -- confirm against the CuPy version in use.
+            """
+
+            def __new__(cls, input_array: "cp.ndarray") -> "ROCmNDArray":
+                """Return ``input_array`` as is if already wrapped, else a view of it."""
+                if isinstance(input_array, ROCmNDArray):
+                    return input_array
+                return cp.asarray(input_array).view(cls)
+
+            @property
+            def __cuda_array_interface__(self) -> dict:
+                """Describe the device buffer as a version 3 interface dictionary."""
+                return {
+                    "shape": self.shape,
+                    "typestr": self.dtype.descr[0][1],
+                    "descr": self.dtype.descr,
+                    "stream": 1,
+                    "version": 3,
+                    "strides": self.strides,
+                    "data": (self.data.ptr, False),
+                }
+
+    else:
+        raise ValueError("Cupy is available but no suitable device was found.")
+
 
 def _idx_from_order(order):
     return list(np.argsort(order))
@@ -274,4 +304,11 @@ def allocate_gpu(
         byte_alignment=alignment_bytes,
         aligned_index=aligned_index,
     )
-    return buffer.buffer, cast("cp.ndarray", buffer.ndarray)
+
+    buffer_ndarray = cast("cp.ndarray", buffer.ndarray)
+
+    # Same guard as the one under which ROCmNDArray is defined, so the name always exists here.
+    if CUPY_DEVICE == core_defs.DeviceType.ROCM:
+        buffer_ndarray = ROCmNDArray(buffer_ndarray)
+
+    return buffer.buffer, buffer_ndarray