[FEA] Make the rmm._cuda.stream.Stream a part of the public API #1770

Open
Matt711 opened this issue Dec 18, 2024 · 1 comment
Assignees
Labels
feature request New feature or request

Comments


Matt711 commented Dec 18, 2024

Is your feature request related to a problem? Please describe.
We're starting to expose streams in pylibcudf APIs (xref rapidsai/cudf#15163 and rapidsai/cudf#17620). Can we make Stream part of the public API?

Describe the solution you'd like
Migrate rmm._cuda.stream.Stream to rmm.pylibrmm.stream.Stream. We would deprecate rmm._cuda.stream.Stream in 25.02 and remove it in 25.04.
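
One way the deprecation window could be handled is a module-level `__getattr__` (PEP 562) left behind in the old module, so that the old import path keeps working but warns. The following is only a rough sketch of that idea, not RMM's actual code: the `Stream` class here is a stand-in, the helper name `make_deprecated_getattr` is hypothetical, and the choice of `FutureWarning` is an assumption.

```python
import warnings


class Stream:
    """Stand-in for the class at its new home (rmm.pylibrmm.stream.Stream)."""

    def __init__(self, obj=None):
        self.obj = obj


def make_deprecated_getattr(old_module, new_module, exports):
    """Build a module-level __getattr__ (PEP 562) that warns on old-path access.

    Hypothetical helper: `exports` maps attribute names to the objects that
    now live in `new_module`.
    """

    def __getattr__(name):
        if name in exports:
            warnings.warn(
                f"{old_module}.{name} is deprecated; "
                f"use {new_module}.{name} instead.",
                FutureWarning,  # assumption: the warning category is a choice
                stacklevel=2,
            )
            return exports[name]
        raise AttributeError(f"module {old_module!r} has no attribute {name!r}")

    return __getattr__


# In the old module, this function would be bound to the module-level
# name `__getattr__`; here it is kept under a plain name for illustration.
legacy_getattr = make_deprecated_getattr(
    "rmm._cuda.stream", "rmm.pylibrmm.stream", {"Stream": Stream}
)
```

With this in place, looking up `Stream` through the old path returns the same class object as the new path, so `isinstance` checks keep working during the deprecation period.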

@Matt711 Matt711 added the feature request New feature or request label Dec 18, 2024
@Matt711 Matt711 self-assigned this Dec 18, 2024
@Matt711 Matt711 changed the title [FEA] Make the rmm._cuda.stream.Stream apart of the public API [FEA] Make the rmm._cuda.stream.Stream a part of the public API Dec 18, 2024

vyasr commented Dec 19, 2024

cuda-python is proposing a __cuda_stream__ protocol that any stream type could implement to support interoperability. I would suggest that we both implement that protocol and accept any object that supports it. Concretely, I would replace the current stream constructor with something like the following:

from functools import singledispatchmethod

import numba.cuda


cdef class Stream:
    def __init__(self, obj=None):
        self._init_from_stream(obj)

    @singledispatchmethod
    def _init_from_stream(self, stream):
        # Fallback: accept any object implementing the __cuda_stream__ protocol.
        try:
            protocol = getattr(stream, "__cuda_stream__")
        except AttributeError:
            raise ValueError(
                "stream must be a supported stream type or implement the __cuda_stream__ protocol"
            )
        if protocol[0] != 0:
            raise ValueError("Only protocol version 0 is supported")

        self._cuda_stream = stream
        self._owner = stream

    @_init_from_stream.register
    def _(self, stream: Stream):
        # Another RMM stream: share its handle and its owner.
        self._cuda_stream, self._owner = stream._cuda_stream, stream._owner

    @_init_from_stream.register
    def _(self, stream: numba.cuda.cudadrv.driver.Stream):
        # Whatever special handling for numba if needed, but ideally could just use the protocol
        ...
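
To make the control flow of the suggestion easier to try out, here is a pure-Python sketch that runs without CUDA or Cython. It is not RMM's implementation. It assumes, as the indexing in the snippet above implies, that `__cuda_stream__` is an attribute holding a `(version, handle)` tuple; the proposal was still in flux, so the real protocol may differ. `FakeForeignStream` is a hypothetical third-party stream type used only for illustration, and a plain `isinstance` check stands in for `singledispatchmethod`.

```python
class Stream:
    """Pure-Python stand-in for the proposed constructor (not RMM's Cython class)."""

    def __init__(self, obj):
        if isinstance(obj, Stream):
            # Another Stream: share the handle and keep the original owner alive.
            self._cuda_stream, self._owner = obj._cuda_stream, obj._owner
            return
        try:
            protocol = obj.__cuda_stream__
        except AttributeError:
            raise ValueError(
                "stream must be a Stream or implement the __cuda_stream__ protocol"
            ) from None
        # Assumption: the protocol value is a (version, handle) tuple.
        version, handle = protocol
        if version != 0:
            raise ValueError("Only protocol version 0 is supported")
        self._cuda_stream = handle
        self._owner = obj  # hold a reference so the foreign stream outlives us

    @property
    def __cuda_stream__(self):
        # Implement the protocol as well, so other libraries can consume this type.
        return (0, self._cuda_stream)


class FakeForeignStream:
    """Hypothetical third-party stream exposing the protocol as an attribute."""

    def __init__(self, handle, version=0):
        self.__cuda_stream__ = (version, handle)


foreign = FakeForeignStream(handle=1234)
wrapped = Stream(foreign)   # accepted through the protocol
copied = Stream(wrapped)    # accepted through the Stream fast path
```

Implementing the protocol on `Stream` itself means the class is both a consumer and a producer, which is the "both implement and accept" point made in the comment above.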

Projects
Status: To-do
Development

No branches or pull requests

2 participants