Skip to content

Commit

Permalink
add sequencial semaphore latest version
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed Aug 18, 2024
1 parent 420b64f commit 6206f3f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 65 deletions.
55 changes: 23 additions & 32 deletions src/retsu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,47 +226,38 @@ def __init__(
def acquire(self, task_id: str) -> bool:
"""Try to acquire a semaphore slot and ensure FIFO order."""
task_bid = task_id.encode("utf8")
# logging.info(
# f"[Semaphore] Task {task_id} is attempting to acquire a slot."
# )
# Add task to queue
queue_name = f"{self.key}_queue"

# Add task to the queue
self.redis_client.rpush(queue_name, task_id)

while True:
# Get the current task at the front of the queue
current_task_id_tmp = self.redis_client.lindex(queue_name, 0)

if current_task_id_tmp is None:
time.sleep(0.1) # Polling interval to wait for the slot
continue

current_task_id = cast(bytes, current_task_id_tmp)

# logging.info(
# f"[Semaphore] Current task at front: {current_task_id}"
# )

# If the current task is this one, check if a semaphore slot is
# available
if current_task_id != task_bid:
time.sleep(0.1) # Polling interval to wait for the slot
continue

count_tmp = self.redis_client.get(self.key)
current_count = int(cast(bytes, count_tmp) or 0)
# logging.info(f"[Semaphore] Current count: {current_count}")

if current_count < self.max_concurrent_tasks:
# logging.info(f"[Semaphore] Task {task_id} acquired a slot.")
# Get the list of current tasks in the queue
queue_tasks = self.redis_client.lrange(queue_name, 0, -1)
count_tmp = cast(bytes, self.redis_client.get(self.key) or b"0")
current_count = int(count_tmp)

# Check if the task is in the first `max_concurrent_tasks`
# in the queue
# mypy: Item "Awaitable[List[Any]]" of
# "Union[Awaitable[List[Any]], List[Any]]"
# has no attribute "index"
task_position = queue_tasks.index(task_bid) # type: ignore

if (
task_position < self.max_concurrent_tasks
and current_count < self.max_concurrent_tasks
):
# If a slot is available and the task is within the
# allowed concurrent limit
self.redis_client.incr(self.key)
return True

# keep waiting until any slot is available
time.sleep(0.01)
# If no slot is available or task is not in the allowed
# concurrent tasks, keep waiting
time.sleep(0.1)

def release(self) -> None:
"""Release a semaphore slot and remove the task from the queue."""
# logging.info(f"[Semaphore] Releasing a slot.")
self.redis_client.decr(self.key)
self.redis_client.lpop(f"{self.key}_queue")
70 changes: 37 additions & 33 deletions tests/test_task_celery_wrapup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,43 @@

from celery import Task

from .celery_tasks import task_sequence_get_time
from .celery_tasks import (
task_random_get_time,
task_sequence_get_time,
)

# def test_task_random_get_time() -> None:
# """Test task_random_get_time."""
# results: dict[int, float] = {}
# tasks: list[Task] = []
# start_time = time.time()

# for i in range(10):
# task_promise = task_random_get_time.s(
# request_id=i, start_time=start_time
# )
# tasks.append(task_promise.apply_async())
def test_task_random_get_time() -> None:
"""Test task_random_get_time."""
results: dict[int, float] = {}
tasks: list[Task] = []
start_time = time.time()

# for i in range(10):
# task = tasks[i]
# task_id, result = task.get(timeout=10)
# assert i == task_id
# results[task_id] = result
for i in range(10):
task_promise = task_random_get_time.s(
request_id=i, start_time=start_time
)
tasks.append(task_promise.apply_async())

# previous_time = results[0]
# previous_id = 0
# tol = 0.2
# for i in range(10):
# current_time = results[i]
# diff = abs(current_time - previous_time)
# # print(
# # f"task {previous_id}-{i}, diff: {diff}, "
# # f"expected: {diff_expected}"
# # )
# assert diff - tol < 5, f"[EE] Task {previous_id}-{i}"
# previous_time = current_time
# previous_id = i
for i in range(10):
task = tasks[i]
task_id, result = task.get(timeout=10)
assert i == task_id
results[task_id] = result

previous_time = results[0]
previous_id = 0
tol = 0.2
for i in range(10):
current_time = results[i]
diff = abs(current_time - previous_time)
# print(
# f"task {previous_id}-{i}, diff: {diff}, "
# f"expected: {diff_expected}"
# )
assert diff - tol < 5, f"[EE] Task {previous_id}-{i}"
previous_time = current_time
previous_id = i


def test_task_sequence_get_time() -> None:
Expand Down Expand Up @@ -79,10 +83,10 @@ def test_task_sequence_get_time() -> None:
current_time = results[i]
diff = current_time - previous_time
diff_expected = diffs[i]
print(
f"task {previous_id}-{i}, diff: {diff}, "
f"expected: {diff_expected}"
)
# print(
# f"task {previous_id}-{i}, diff: {diff}, "
# f"expected: {diff_expected}"
# )
assert diff >= diff_expected - tol, f"[EE] Task {previous_id}-{i}"
assert diff <= diff_expected + tol, f"[EE] Task {previous_id}-{i}"
previous_time = current_time
Expand Down

0 comments on commit 6206f3f

Please sign in to comment.