-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: debuggable Event primitive * feat: CounterLock primitive * feat: add new primitives to main module
- Loading branch information
1 parent
b5bdcc6
commit a150a42
Showing
6 changed files
with
103 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
""" | ||
While not the focus of this lib, this module includes some new primitives and some modified versions of standard asyncio primitives. | ||
""" | ||
|
||
from a_sync.primitives.locks import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
|
||
from functools import cached_property | ||
from logging import Logger, getLogger | ||
|
||
|
||
class _Loggable: | ||
@cached_property | ||
def logger(self) -> Logger: | ||
return getLogger(f"a_sync.{self.__class__.__name__}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
|
||
from a_sync.primitives.locks.counter import CounterLock | ||
from a_sync.primitives.locks.event import Event |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import asyncio | ||
from collections import defaultdict | ||
from typing import Iterable | ||
|
||
|
||
class CounterLock: | ||
""" | ||
A asyncio primative that blocks until the internal counter has reached a specific value. | ||
counter = CounterLock() | ||
A coroutine can now `await counter.wait_for(3)` and it will block until the internal counter >= 3. | ||
Now if some other task executes `counter.value = 5` or `counter.set(5)`, the first coroutine will unblock as 5 >= 3. | ||
The internal counter can only increase. | ||
""" | ||
def __init__(self, start_value: int = 0): | ||
self._value = start_value | ||
self._conditions = defaultdict(asyncio.Event) | ||
self.is_ready = lambda v: self._value >= v | ||
|
||
async def wait_for(self, value: int) -> bool: | ||
if not self.is_ready(value): | ||
await self._conditions[value].wait() | ||
return True | ||
|
||
def set(self, value: int) -> None: | ||
self.value = value | ||
|
||
@property | ||
def value(self) -> int: | ||
return self._value | ||
|
||
@value.setter | ||
def value(self, value: int) -> None: | ||
if value > self._value: | ||
self._value = value | ||
ready = [ | ||
self._conditions.pop(key) | ||
for key in list(self._conditions.keys()) | ||
if key <= self._value | ||
] | ||
for event in ready: | ||
event.set() | ||
elif value < self._value: | ||
raise ValueError("You cannot decrease the value.") | ||
|
||
class CounterLockCluster: | ||
""" | ||
An asyncio primitive that represents 2 or more CounterLock objects. | ||
`wait_for(i)` will block until the value of all CounterLock objects is >= i. | ||
""" | ||
def __init__(self, counter_locks: Iterable[CounterLock]) -> None: | ||
self.locks = list(counter_locks) | ||
|
||
async def wait_for(self, value: int) -> bool: | ||
await asyncio.gather(*[counter_lock.wait_for(value) for counter_lock in self.locks]) | ||
return True | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
|
||
import asyncio | ||
from a_sync.primitives._loggable import _Loggable | ||
|
||
class Event(asyncio.Event, _Loggable): | ||
"""asyncio.Event but with some additional debug logging to help detect deadlocks.""" | ||
def __init__(self): | ||
self._task = None | ||
self._counter = 0 | ||
super().__init__() | ||
|
||
async def wait(self) -> bool: | ||
if self.is_set(): | ||
return True | ||
if self._task is None: | ||
self._task = asyncio.create_task(self._debug_helper()) | ||
return await super().wait() | ||
|
||
async def _debug_helper(self) -> None: | ||
while not self.is_set(): | ||
self.logger.debug(f"Waiting for {self}") | ||
await asyncio.sleep(5) | ||
self._task = None |