diff --git a/python/rmm/_cuda/stream.pyx b/python/rmm/_cuda/stream.pyx index 0f6c5ab19..541524124 100644 --- a/python/rmm/_cuda/stream.pyx +++ b/python/rmm/_cuda/stream.pyx @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + from libc.stdint cimport uintptr_t from libcpp cimport bool @@ -117,7 +119,10 @@ cdef class Stream: self._cuda_stream, self._owner = stream._cuda_stream, stream._owner -DEFAULT_STREAM = Stream._from_cudaStream_t(cuda_stream_default.value()) +if int(os.environ.get("RMM_PER_THREAD_DEFAULT_STREAM", "0")) != 0: + DEFAULT_STREAM = Stream._from_cudaStream_t(cuda_stream_per_thread.value()) +else: + DEFAULT_STREAM = Stream._from_cudaStream_t(cuda_stream_default.value()) LEGACY_DEFAULT_STREAM = Stream._from_cudaStream_t(cuda_stream_legacy.value()) PER_THREAD_DEFAULT_STREAM = Stream._from_cudaStream_t( cuda_stream_per_thread.value()