Skip to content

Commit

Permalink
[feat] Add redis lock extension (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
aquamatthias authored Sep 15, 2023
1 parent 7bb829d commit 2963757
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 2 deletions.
108 changes: 108 additions & 0 deletions fixcloudutils/redis/lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

import asyncio
import logging
from typing import TypeVar, Coroutine, Any

from redis.asyncio import Redis
from redis.asyncio.lock import Lock as RedisLock

from fixcloudutils.asyncio import stop_running_task

log = logging.getLogger(__file__)
T = TypeVar("T")


class Lock:
# noinspection PyUnresolvedReferences
"""
Redis based Lock extension.
You cannot use this lock as context manager, but pass the function to be performed.
The lock will be created/released when the action is performed.
The action is canceled if the lock cannot be extended.
Example:
>>> async def perform_locked_action() -> str:
>>> print("Acquired the lock!")
>>> # do stuff here that needs controlled access
>>> return "done"
>>> async def main() -> None:
>>> redis: Redis = ...
>>> lock = Lock(redis, "test_lock", 5)
>>> result = await lock.with_lock(perform_locked_action())
"""

def __init__(self, redis: Redis, lock_name: str, timeout: float):
self.redis = redis
self.lock_name = "redlock__" + lock_name
self.timeout = timeout

async def with_lock(self, coro: Coroutine[T, None, Any]) -> T:
"""
Use this method in a situation where the time it takes for the coroutine to run is unknown.
This method will extend the lock time as long as the coroutine is running (every half of the auto_release_time).
If the lock cannot be extended, the coroutine will be
stopped and an ExtendUnlockedLock exception will be raised.
:param coro: The coroutine to execute.
:return: The result of the coroutine
"""

async with self.lock() as lock:

async def extend_lock() -> None:
while True:
extend_time = self.timeout / 2
await asyncio.sleep(extend_time)
log.debug(f"Extend the lock {self.lock_name} for {self.timeout} seconds.")
await lock.extend(self.timeout) # will throw in case the lock is not owned anymore

# The extend_lock task will never return a result but only an exception
# So we can take the first done, which will be either:
# - the exception from extend_lock
# - the exception from coro
# - the result from coro
done, pending = await asyncio.wait(
[asyncio.create_task(extend_lock()), asyncio.create_task(coro)], return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
await stop_running_task(task)
for task in done:
return task.result() # type: ignore
raise Exception("You should never come here!")

def lock(self) -> RedisLock:
return self.redis.lock(name=self.lock_name, timeout=self.timeout)


if __name__ == "__main__":
import time

shift = 1694685084

def show(*args: Any) -> None:
t = int((time.time() - shift) * 1000)
print(f"[{t}] ", *args)

async def simple_check() -> None:
redis = Redis.from_url("redis://localhost:6379")
lock = Lock(redis, "test_lock", 5)

async def perform_action() -> str:
show("Acquired the lock!")
for i in range(11):
show("Performing work. Still locked.")
await asyncio.sleep(1)
return "{int(time.time() * 1000)} done"

while True:
try:
result = await lock.with_lock(perform_action())
show("Lock released. Result of locked action: ", result)
except Exception as ex:
show("GOT exception", ex)
finally:
await asyncio.sleep(1)

logging.basicConfig(level=logging.DEBUG)
asyncio.run(simple_check())
9 changes: 8 additions & 1 deletion fixcloudutils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import uuid
from datetime import datetime, timezone
from typing import Optional, TypeVar, Union, List, Any

Expand Down Expand Up @@ -95,3 +95,10 @@ def at_idx(current: JsonElement, idx: int) -> JsonElement:
return result

return at_idx(element, 0)


def uuid_str(from_object: Optional[Any] = None) -> str:
if from_object:
return str(uuid.uuid5(uuid.NAMESPACE_DNS, from_object))
else:
return str(uuid.uuid1())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fixcloudutils"
version = "1.4.0"
version = "1.5.0"
authors = [{ name = "Some Engineering Inc." }]
description = "Utilities for fixcloud."
license = { file = "LICENSE" }
Expand Down
42 changes: 42 additions & 0 deletions tests/lock_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
import os

import pytest
from redis.asyncio import Redis

from fixcloudutils.redis.lock import Lock


@pytest.mark.asyncio
@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
async def test_lock(redis: Redis) -> None:
holy_grail = [0] # one task should only modify the holy grail at a time
cond = asyncio.Event() # mark the beginning of the test
number = 0 # counts the number of concurrent tasks

async def try_with_lock(num: int) -> str:
nonlocal number

async def perform_locked_action() -> str:
print(f"[{num}] performing action")
nonlocal holy_grail
holy_grail[0] += 1
holy_grail.append(num)
assert len(holy_grail) == 2
assert holy_grail[-1] == num
assert len(holy_grail) == 2
holy_grail.pop()
print(f"[{num}] performing action done")
return "done"

lock = Lock(redis, "test_lock", 5)
number += 1
await cond.wait() # wait for the test driver to start
return await lock.with_lock(perform_locked_action())

tasks = [asyncio.create_task(try_with_lock(num)) for num in range(10)]
# wait for all tasks to start
while number < 10:
await asyncio.sleep(0.1)
cond.set()
await asyncio.gather(*tasks)

0 comments on commit 2963757

Please sign in to comment.