Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
DeMoriarty committed Feb 12, 2022
1 parent 51ecc87 commit b323af6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
6 changes: 3 additions & 3 deletions torchpq/kernels/CustomKernel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import cupy as cp
import torch
from torchpq.kernels import DEVICE
from torchpq.kernels.default_device import get_default_device

@cp.memoize(for_each_device=True)
def cunnex(func_name, func_body):
Expand All @@ -13,11 +13,11 @@ def __init__(self, ptr):
class CustomKernel:
def __init__(self):
self._use_torch_in_cupy_malloc()
self.stream = Stream(torch.cuda.current_stream(DEVICE).cuda_stream)
self.stream = Stream(torch.cuda.current_stream(get_default_device()).cuda_stream)

@staticmethod
def _torch_alloc(size):
tensor = torch.empty(size, dtype=torch.uint8, device=DEVICE)
tensor = torch.empty(size, dtype=torch.uint8, device=get_default_device())
return cp.cuda.MemoryPointer(
cp.cuda.UnownedMemory(tensor.data_ptr(), size, tensor), 0)

Expand Down
5 changes: 2 additions & 3 deletions torchpq/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import cupy as cp
DEVICE = cp.cuda.Device().id

from .CustomKernel import CustomKernel
from .CustomKernel import Stream

from .default_device import get_default_device, set_default_device

from .GetAddressByIDCuda import GetAddressByIDCuda
from .GetDivByAddressCuda import GetDivByAddressCuda
from .GetDivByAddressV2Cuda import GetDivByAddressV2Cuda
Expand Down
12 changes: 12 additions & 0 deletions torchpq/kernels/default_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import cupy as cp
import torch
__device = cp.cuda.Device().id

def get_default_device():
global __device
return __device

def set_default_device(device_id):
assert device_id < torch.cuda.device_count()
global __device
__device = device_id

0 comments on commit b323af6

Please sign in to comment.