diff --git a/src/retsu/core.py b/src/retsu/core.py index 2c5c864..cc83c11 100644 --- a/src/retsu/core.py +++ b/src/retsu/core.py @@ -226,47 +226,38 @@ def __init__( def acquire(self, task_id: str) -> bool: """Try to acquire a semaphore slot and ensure FIFO order.""" task_bid = task_id.encode("utf8") - # logging.info( - # f"[Semaphore] Task {task_id} is attempting to acquire a slot." - # ) - # Add task to queue queue_name = f"{self.key}_queue" + + # Add task to the queue self.redis_client.rpush(queue_name, task_id) while True: - # Get the current task at the front of the queue - current_task_id_tmp = self.redis_client.lindex(queue_name, 0) - - if current_task_id_tmp is None: - time.sleep(0.1) # Polling interval to wait for the slot - continue - - current_task_id = cast(bytes, current_task_id_tmp) - - # logging.info( - # f"[Semaphore] Current task at front: {current_task_id}" - # ) - - # If the current task is this one, check if a semaphore slot is - # available - if current_task_id != task_bid: - time.sleep(0.1) # Polling interval to wait for the slot - continue - - count_tmp = self.redis_client.get(self.key) - current_count = int(cast(bytes, count_tmp) or 0) - # logging.info(f"[Semaphore] Current count: {current_count}") - - if current_count < self.max_concurrent_tasks: - # logging.info(f"[Semaphore] Task {task_id} acquired a slot.") + # Get the list of current tasks in the queue + queue_tasks = self.redis_client.lrange(queue_name, 0, -1) + count_tmp = cast(bytes, self.redis_client.get(self.key) or b"0") + current_count = int(count_tmp) + + # Check if the task is in the first `max_concurrent_tasks` + # in the queue + # mypy: Item "Awaitable[List[Any]]" of + # "Union[Awaitable[List[Any]], List[Any]]" + # has no attribute "index" + task_position = queue_tasks.index(task_bid) # type: ignore + + if ( + task_position < self.max_concurrent_tasks + and current_count < self.max_concurrent_tasks + ): + # If a slot is available and the task is within the + # allowed concurrent limit self.redis_client.incr(self.key) return True - # keep waiting until any slot is available - time.sleep(0.01) + # If no slot is available or task is not in the allowed + # concurrent tasks, keep waiting + time.sleep(0.1) def release(self) -> None: """Release a semaphore slot and remove the task from the queue.""" - # logging.info(f"[Semaphore] Releasing a slot.") self.redis_client.decr(self.key) self.redis_client.lpop(f"{self.key}_queue") diff --git a/tests/test_task_celery_wrapup.py b/tests/test_task_celery_wrapup.py index 348633e..34c2ce5 100644 --- a/tests/test_task_celery_wrapup.py +++ b/tests/test_task_celery_wrapup.py @@ -6,39 +6,43 @@ from celery import Task -from .celery_tasks import task_sequence_get_time +from .celery_tasks import ( + task_random_get_time, + task_sequence_get_time, +) -# def test_task_random_get_time() -> None: -# """Test task_random_get_time.""" -# results: dict[int, float] = {} -# tasks: list[Task] = [] -# start_time = time.time() -# for i in range(10): -# task_promise = task_random_get_time.s( -# request_id=i, start_time=start_time -# ) -# tasks.append(task_promise.apply_async()) +def test_task_random_get_time() -> None: + """Test task_random_get_time.""" + results: dict[int, float] = {} + tasks: list[Task] = [] + start_time = time.time() -# for i in range(10): -# task = tasks[i] -# task_id, result = task.get(timeout=10) -# assert i == task_id -# results[task_id] = result + for i in range(10): + task_promise = task_random_get_time.s( + request_id=i, start_time=start_time + ) + tasks.append(task_promise.apply_async()) -# previous_time = results[0] -# previous_id = 0 -# tol = 0.2 -# for i in range(10): -# current_time = results[i] -# diff = abs(current_time - previous_time) -# # print( -# # f"task {previous_id}-{i}, diff: {diff}, " -# # f"expected: {diff_expected}" -# # ) -# assert diff - tol < 5, f"[EE] Task {previous_id}-{i}" -# previous_time = current_time -# previous_id = i + for i in range(10): + task = tasks[i] + task_id, result = task.get(timeout=10) + assert i == task_id + results[task_id] = result + + previous_time = results[0] + previous_id = 0 + tol = 0.2 + for i in range(10): + current_time = results[i] + diff = abs(current_time - previous_time) + # print( + # f"task {previous_id}-{i}, diff: {diff}, " + # f"expected: {diff_expected}" + # ) + assert diff - tol < 5, f"[EE] Task {previous_id}-{i}" + previous_time = current_time + previous_id = i def test_task_sequence_get_time() -> None: @@ -79,10 +83,10 @@ def test_task_sequence_get_time() -> None: current_time = results[i] diff = current_time - previous_time diff_expected = diffs[i] - print( - f"task {previous_id}-{i}, diff: {diff}, " - f"expected: {diff_expected}" - ) + # print( + # f"task {previous_id}-{i}, diff: {diff}, " + # f"expected: {diff_expected}" + # ) assert diff >= diff_expected - tol, f"[EE] Task {previous_id}-{i}" assert diff <= diff_expected + tol, f"[EE] Task {previous_id}-{i}" previous_time = current_time