Skip to content

Commit

Permalink
Try alternative approach
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Sep 20, 2024
1 parent d62c188 commit 45082cc
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions src/gt4py/storage/cartesian/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,35 @@
device_type=core_defs.DeviceType.CUDA,
array_utils=allocators.cupy_array_utils,
)
else:
elif CUPY_DEVICE == core_defs.DeviceType.ROCM:
_GPUBufferAllocator = allocators.NDArrayBufferAllocator(
device_type=core_defs.DeviceType.ROCM,
array_utils=allocators.cupy_array_utils,
)

class ROCmNDArray(cp.ndarray):
def __new__(cls, input_array: "cp.ndarray") -> ROCmNDArray:
return (
input_array
if isinstance(input_array, ROCmNDArray)
else cp.asarray(input_array).view(cls)
)

@property
def __cuda_array_interface__(self) -> dict:
return {
"shape": self.shape,
"typestr": self.dtype.descr[0][1],
"descr": self.dtype.descr,
"stream": 1,
"version": 3,
"strides": self.strides,
"data": (self.data.ptr, False),
}

else:
raise ValueError("Cupy is available but no suitable device was found.")


def _idx_from_order(order):
return list(np.argsort(order))
Expand Down Expand Up @@ -274,4 +297,10 @@ def allocate_gpu(
byte_alignment=alignment_bytes,
aligned_index=aligned_index,
)
return buffer.buffer, cast("cp.ndarray", buffer.ndarray)

buffer_ndarray = cast("cp.ndarray", buffer.ndarray)

if cp is not None and cp.cuda.get_hipcc_path() is not None:
buffer_ndarray = ROCmNDArray(buffer_ndarray)

return buffer.buffer, buffer_ndarray

0 comments on commit 45082cc

Please sign in to comment.