Skip to content

Commit

Permalink
feat: optimize task state checks
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Dec 12, 2024
1 parent 0e6e9f0 commit 487feda
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions a_sync/_smart.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T
result = await task
```
"""
if self.done():
if self._state != "PENDING":
return self.result() # May raise too.

self._asyncio_future_blocking = True
Expand All @@ -92,7 +92,7 @@ def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T

logger.debug("awaiting %s", self)
yield self # This tells Task to wait for completion.
if not self.done():
if self._state == "PENDING":
raise RuntimeError("await wasn't used with future")
return self.result() # May raise too.

Expand All @@ -113,9 +113,9 @@ def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> int:
See Also:
- :meth:`_waiter_done_cleanup_callback`
"""
if self.done():
if self._state != "PENDING":
return 0
return sum(getattr(waiter, "num_waiters", 1) or 1 for waiter in self._waiters)
return sum(getattr(waiter, "num_waiters", 1) for waiter in self._waiters)

def _waiter_done_cleanup_callback(
self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask"
Expand All @@ -131,7 +131,7 @@ def _waiter_done_cleanup_callback(
Example:
Automatically called when a waiter task completes.
"""
if not self.done():
if self._state == "PENDING":
self._waiters.remove(waiter)

def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None:
Expand Down Expand Up @@ -405,7 +405,7 @@ def shield(
stacklevel=2,
)
inner = asyncio.ensure_future(arg, loop=loop)
if inner.done():
if inner._state != "PENDING":
# Shortcut.
return inner
loop = asyncio.futures._get_loop(inner)
Expand All @@ -431,7 +431,7 @@ def _inner_done_callback(inner):
outer.set_result(inner.result())

def _outer_done_callback(outer):
if not inner.done():
if inner._state == "PENDING":
inner.remove_done_callback(_inner_done_callback)

inner.add_done_callback(_inner_done_callback)
Expand Down

0 comments on commit 487feda

Please sign in to comment.