diff --git a/a_sync/_smart.py b/a_sync/_smart.py index e61a215e..b6e901d7 100644 --- a/a_sync/_smart.py +++ b/a_sync/_smart.py @@ -6,7 +6,6 @@ """ import asyncio -import contextvars import logging import warnings import weakref @@ -198,9 +197,7 @@ def __init__( if key: self._key = key self._waiters = weakref.WeakSet() - self._callbacks.append( - (SmartFuture._self_done_cleanup_callback, contextvars.copy_context()) - ) + self.add_done_callback(SmartFuture._self_done_cleanup_callback) def __repr__(self): return f"<{type(self).__name__} key={self._key} waiters={self.num_waiters} {self._state}>" @@ -301,9 +298,7 @@ def __init__( """ asyncio.Task.__init__(self, coro, loop=loop, name=name) self._waiters: Set["asyncio.Task[T]"] = set() - self._callbacks.append( - (SmartTask._self_done_cleanup_callback, contextvars.copy_context()) - ) + self.add_done_callback(SmartTask._self_done_cleanup_callback) def smart_task_factory( diff --git a/tests/test_smart.py b/tests/test_smart.py new file mode 100644 index 00000000..e342593b --- /dev/null +++ b/tests/test_smart.py @@ -0,0 +1,15 @@ +import asyncio +import pytest + +from a_sync._smart import SmartTask + + +@pytest.mark.asyncio_cooperative +async def test_smart_task_await(): + await SmartTask(asyncio.sleep(0.1), loop=None) + + +@pytest.mark.asyncio_cooperative +async def test_smart_task_name(): + t = SmartTask(asyncio.sleep(0.1), loop=None, name="test") + assert t.get_name() == "test"