Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
fix get_stream issue
Browse files Browse the repository at this point in the history
  • Loading branch information
dongxuy04 committed Nov 3, 2023
1 parent 99682af commit e848bc6
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,18 @@
from typing import Union
from .utils import wholememory_dtype_to_torch_dtype, torch_dtype_to_wholememory_dtype

default_cuda_stream_int_ptr = None
default_wholegraph_env_context = None
torch_cpp_ext_loaded = False
torch_cpp_ext_lib = None


def get_stream(use_default=True):
global default_cuda_stream_int_ptr
def get_stream():
cuda_stream_int_ptr = None
if default_cuda_stream_int_ptr is None or not use_default:
cuda_stream = torch.cuda.current_stream()._as_parameter_
if cuda_stream.value is not None:
cuda_stream_int_ptr = cuda_stream.value
else:
cuda_stream_int_ptr = int(0)
if use_default:
default_cuda_stream_int_ptr = cuda_stream_int_ptr
cuda_stream = torch.cuda.current_stream()._as_parameter_
if cuda_stream.value is not None:
cuda_stream_int_ptr = cuda_stream.value
else:
cuda_stream_int_ptr = default_cuda_stream_int_ptr
cuda_stream_int_ptr = int(0)
return cuda_stream_int_ptr


Expand Down

0 comments on commit e848bc6

Please sign in to comment.