-
Notifications
You must be signed in to change notification settings - Fork 791
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support anyio with a Cargo feature
- Loading branch information
Showing
13 changed files
with
242 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Support anyio with a Cargo feature |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
use std::{task::Poll, thread, time::Duration}; | ||
|
||
use futures::{channel::oneshot, future::poll_fn}; | ||
use pyo3::prelude::*; | ||
|
||
#[pyfunction(signature = (seconds, result = None))] | ||
async fn sleep(seconds: f64, result: Option<PyObject>) -> Option<PyObject> { | ||
if seconds <= 0.0 { | ||
let mut ready = false; | ||
poll_fn(|cx| { | ||
if ready { | ||
return Poll::Ready(()); | ||
} | ||
ready = true; | ||
cx.waker().wake_by_ref(); | ||
Poll::Pending | ||
}) | ||
.await; | ||
} else { | ||
let (tx, rx) = oneshot::channel(); | ||
thread::spawn(move || { | ||
thread::sleep(Duration::from_secs_f64(seconds)); | ||
tx.send(()).unwrap(); | ||
}); | ||
rx.await.unwrap(); | ||
} | ||
result | ||
} | ||
|
||
#[pymodule] | ||
pub fn anyio(m: &Bound<'_, PyModule>) -> PyResult<()> { | ||
m.add_function(wrap_pyfunction!(sleep, m)?)?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import asyncio | ||
|
||
from pyo3_pytests.anyio import sleep | ||
import trio | ||
|
||
|
||
def test_asyncio(): | ||
assert asyncio.run(sleep(0)) is None | ||
assert asyncio.run(sleep(0.1, 42)) == 42 | ||
|
||
|
||
def test_trio(): | ||
assert trio.run(sleep, 0) is None | ||
assert trio.run(sleep, 0.1, 42) == 42 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
//! Coroutine implementation using sniffio to select the appropriate implementation, | ||
//! compatible with anyio. | ||
use crate::{ | ||
coroutine::{asyncio::AsyncioWaker, trio::TrioWaker}, | ||
exceptions::PyRuntimeError, | ||
sync::GILOnceCell, | ||
types::PyAnyMethods, | ||
PyObject, PyResult, Python, | ||
}; | ||
|
||
enum AsyncLib { | ||
Asyncio, | ||
Trio, | ||
} | ||
|
||
fn current_async_library(py: Python<'_>) -> PyResult<AsyncLib> { | ||
static CURRENT_ASYNC_LIBRARY: GILOnceCell<Option<PyObject>> = GILOnceCell::new(); | ||
let import = || -> PyResult<_> { | ||
Ok(match py.import("sniffio") { | ||
Ok(module) => Some(module.getattr("current_async_library")?.into()), | ||
Err(_) => None, | ||
}) | ||
}; | ||
let Some(func) = CURRENT_ASYNC_LIBRARY.get_or_try_init(py, import)? else { | ||
return Ok(AsyncLib::Asyncio); | ||
}; | ||
match func.bind(py).call0()?.extract()? { | ||
"asyncio" => Ok(AsyncLib::Asyncio), | ||
"trio" => Ok(AsyncLib::Trio), | ||
rt => Err(PyRuntimeError::new_err(format!("unsupported runtime {rt}"))), | ||
} | ||
} | ||
|
||
/// Sniffio/anyio-compatible coroutine waker. | ||
/// | ||
/// Polling a Rust future calls `sniffio.current_async_library` to select the appropriate | ||
/// implementation, either asyncio or trio. | ||
pub(super) enum AnyioWaker { | ||
/// [`AsyncioWaker`] | ||
Asyncio(AsyncioWaker), | ||
/// [`TrioWaker`] | ||
Trio(TrioWaker), | ||
} | ||
|
||
impl AnyioWaker { | ||
pub(super) fn new(py: Python<'_>) -> PyResult<Self> { | ||
match current_async_library(py)? { | ||
AsyncLib::Asyncio => Ok(Self::Asyncio(AsyncioWaker::new(py)?)), | ||
AsyncLib::Trio => Ok(Self::Trio(TrioWaker::new(py)?)), | ||
} | ||
} | ||
|
||
pub(super) fn yield_(&self, py: Python<'_>) -> PyResult<PyObject> { | ||
match self { | ||
AnyioWaker::Asyncio(w) => w.yield_(py), | ||
AnyioWaker::Trio(w) => w.yield_(py), | ||
} | ||
} | ||
|
||
pub(super) fn yield_waken(py: Python<'_>) -> PyResult<PyObject> { | ||
match current_async_library(py)? { | ||
AsyncLib::Asyncio => AsyncioWaker::yield_waken(py), | ||
AsyncLib::Trio => TrioWaker::yield_waken(py), | ||
} | ||
} | ||
|
||
pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> { | ||
match self { | ||
AnyioWaker::Asyncio(w) => w.wake(py), | ||
AnyioWaker::Trio(w) => w.wake(py), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
//! Coroutine implementation compatible with trio. | ||
use pyo3_macros::pyfunction; | ||
|
||
use crate::{ | ||
intern, | ||
sync::GILOnceCell, | ||
types::{PyAnyMethods, PyCFunction, PyIterator}, | ||
wrap_pyfunction, Bound, Py, PyAny, PyObject, PyResult, Python, | ||
}; | ||
|
||
struct Trio { | ||
cancel_shielded_checkpoint: PyObject, | ||
current_task: PyObject, | ||
current_trio_token: PyObject, | ||
reschedule: PyObject, | ||
succeeded: PyObject, | ||
wait_task_rescheduled: PyObject, | ||
} | ||
impl Trio { | ||
fn get(py: Python<'_>) -> PyResult<&Self> { | ||
static TRIO: GILOnceCell<Trio> = GILOnceCell::new(); | ||
TRIO.get_or_try_init(py, || { | ||
let module = py.import("trio.lowlevel")?; | ||
Ok(Self { | ||
cancel_shielded_checkpoint: module.getattr("cancel_shielded_checkpoint")?.into(), | ||
current_task: module.getattr("current_task")?.into(), | ||
current_trio_token: module.getattr("current_trio_token")?.into(), | ||
reschedule: module.getattr("reschedule")?.into(), | ||
succeeded: module.getattr("Abort")?.getattr("SUCCEEDED")?.into(), | ||
wait_task_rescheduled: module.getattr("wait_task_rescheduled")?.into(), | ||
}) | ||
}) | ||
} | ||
} | ||
|
||
fn yield_from(coro_func: &Bound<'_, PyAny>) -> PyResult<PyObject> { | ||
PyIterator::from_object(&coro_func.call_method0("__await__")?)? | ||
.next() | ||
.expect("cancel_shielded_checkpoint didn't yield") | ||
.map(Into::into) | ||
} | ||
|
||
/// Asyncio-compatible coroutine waker. | ||
/// | ||
/// Polling a Rust future yields `trio.lowlevel.wait_task_rescheduled()`, while `Waker::wake` | ||
/// reschedule the current task. | ||
pub(super) struct TrioWaker { | ||
task: PyObject, | ||
token: PyObject, | ||
} | ||
|
||
impl TrioWaker { | ||
pub(super) fn new(py: Python<'_>) -> PyResult<Self> { | ||
let trio = Trio::get(py)?; | ||
let task = trio.current_task.call0(py)?; | ||
let token = trio.current_trio_token.call0(py)?; | ||
Ok(Self { task, token }) | ||
} | ||
|
||
pub(super) fn yield_(&self, py: Python<'_>) -> PyResult<PyObject> { | ||
static ABORT_FUNC: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new(); | ||
let abort_func = | ||
ABORT_FUNC.get_or_try_init(py, || wrap_pyfunction!(abort_func, py).map(Into::into))?; | ||
let wait_task_rescheduled = Trio::get(py)? | ||
.wait_task_rescheduled | ||
.call1(py, (abort_func,))?; | ||
yield_from(wait_task_rescheduled.bind(py)) | ||
} | ||
|
||
pub(super) fn yield_waken(py: Python<'_>) -> PyResult<PyObject> { | ||
let checkpoint = Trio::get(py)?.cancel_shielded_checkpoint.call0(py)?; | ||
yield_from(checkpoint.bind(py)) | ||
} | ||
|
||
pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> { | ||
self.token.call_method1( | ||
py, | ||
intern!(py, "run_sync_soon"), | ||
(&Trio::get(py)?.reschedule, &self.task), | ||
)?; | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[pyfunction(crate = "crate")] | ||
fn abort_func(py: Python<'_>, _arg: &Bound<'_, PyAny>) -> PyResult<PyObject> { | ||
Ok(Trio::get(py)?.succeeded.clone_ref(py)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters