-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] Add redis lock extension (#11)
- Loading branch information
1 parent
7bb829d
commit 2963757
Showing
4 changed files
with
159 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import logging | ||
from typing import TypeVar, Coroutine, Any | ||
|
||
from redis.asyncio import Redis | ||
from redis.asyncio.lock import Lock as RedisLock | ||
|
||
from fixcloudutils.asyncio import stop_running_task | ||
|
||
log = logging.getLogger(__file__) | ||
T = TypeVar("T") | ||
|
||
|
||
class Lock: | ||
# noinspection PyUnresolvedReferences | ||
""" | ||
Redis based Lock extension. | ||
You cannot use this lock as context manager, but pass the function to be performed. | ||
The lock will be created/released when the action is performed. | ||
The action is canceled if the lock cannot be extended. | ||
Example: | ||
>>> async def perform_locked_action() -> str: | ||
>>> print("Acquired the lock!") | ||
>>> # do stuff here that needs controlled access | ||
>>> return "done" | ||
>>> async def main() -> None: | ||
>>> redis: Redis = ... | ||
>>> lock = Lock(redis, "test_lock", 5) | ||
>>> result = await lock.with_lock(perform_locked_action()) | ||
""" | ||
|
||
def __init__(self, redis: Redis, lock_name: str, timeout: float): | ||
self.redis = redis | ||
self.lock_name = "redlock__" + lock_name | ||
self.timeout = timeout | ||
|
||
async def with_lock(self, coro: Coroutine[T, None, Any]) -> T: | ||
""" | ||
Use this method in a situation where the time it takes for the coroutine to run is unknown. | ||
This method will extend the lock time as long as the coroutine is running (every half of the auto_release_time). | ||
If the lock cannot be extended, the coroutine will be | ||
stopped and an ExtendUnlockedLock exception will be raised. | ||
:param coro: The coroutine to execute. | ||
:return: The result of the coroutine | ||
""" | ||
|
||
async with self.lock() as lock: | ||
|
||
async def extend_lock() -> None: | ||
while True: | ||
extend_time = self.timeout / 2 | ||
await asyncio.sleep(extend_time) | ||
log.debug(f"Extend the lock {self.lock_name} for {self.timeout} seconds.") | ||
await lock.extend(self.timeout) # will throw in case the lock is not owned anymore | ||
|
||
# The extend_lock task will never return a result but only an exception | ||
# So we can take the first done, which will be either: | ||
# - the exception from extend_lock | ||
# - the exception from coro | ||
# - the result from coro | ||
done, pending = await asyncio.wait( | ||
[asyncio.create_task(extend_lock()), asyncio.create_task(coro)], return_when=asyncio.FIRST_COMPLETED | ||
) | ||
for task in pending: | ||
await stop_running_task(task) | ||
for task in done: | ||
return task.result() # type: ignore | ||
raise Exception("You should never come here!") | ||
|
||
def lock(self) -> RedisLock: | ||
return self.redis.lock(name=self.lock_name, timeout=self.timeout) | ||
|
||
|
||
if __name__ == "__main__": | ||
import time | ||
|
||
shift = 1694685084 | ||
|
||
def show(*args: Any) -> None: | ||
t = int((time.time() - shift) * 1000) | ||
print(f"[{t}] ", *args) | ||
|
||
async def simple_check() -> None: | ||
redis = Redis.from_url("redis://localhost:6379") | ||
lock = Lock(redis, "test_lock", 5) | ||
|
||
async def perform_action() -> str: | ||
show("Acquired the lock!") | ||
for i in range(11): | ||
show("Performing work. Still locked.") | ||
await asyncio.sleep(1) | ||
return "{int(time.time() * 1000)} done" | ||
|
||
while True: | ||
try: | ||
result = await lock.with_lock(perform_action()) | ||
show("Lock released. Result of locked action: ", result) | ||
except Exception as ex: | ||
show("GOT exception", ex) | ||
finally: | ||
await asyncio.sleep(1) | ||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
asyncio.run(simple_check()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import asyncio | ||
import os | ||
|
||
import pytest | ||
from redis.asyncio import Redis | ||
|
||
from fixcloudutils.redis.lock import Lock | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running") | ||
async def test_lock(redis: Redis) -> None: | ||
holy_grail = [0] # one task should only modify the holy grail at a time | ||
cond = asyncio.Event() # mark the beginning of the test | ||
number = 0 # counts the number of concurrent tasks | ||
|
||
async def try_with_lock(num: int) -> str: | ||
nonlocal number | ||
|
||
async def perform_locked_action() -> str: | ||
print(f"[{num}] performing action") | ||
nonlocal holy_grail | ||
holy_grail[0] += 1 | ||
holy_grail.append(num) | ||
assert len(holy_grail) == 2 | ||
assert holy_grail[-1] == num | ||
assert len(holy_grail) == 2 | ||
holy_grail.pop() | ||
print(f"[{num}] performing action done") | ||
return "done" | ||
|
||
lock = Lock(redis, "test_lock", 5) | ||
number += 1 | ||
await cond.wait() # wait for the test driver to start | ||
return await lock.with_lock(perform_locked_action()) | ||
|
||
tasks = [asyncio.create_task(try_with_lock(num)) for num in range(10)] | ||
# wait for all tasks to start | ||
while number < 10: | ||
await asyncio.sleep(0.1) | ||
cond.set() | ||
await asyncio.gather(*tasks) |