Skip to content

Commit

Permalink
Added max-tasks-per-child parameter. (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Apr 19, 2024
1 parent ed37be4 commit 8f74c63
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
15 changes: 15 additions & 0 deletions taskiq/cli/worker/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class WorkerArgs:
no_propagate_errors: bool = False
max_fails: int = -1
ack_type: AcknowledgeType = AcknowledgeType.WHEN_SAVED
max_tasks_per_child: Optional[int] = None
wait_tasks_timeout: Optional[float] = None

@classmethod
def from_cli(
Expand Down Expand Up @@ -197,6 +199,19 @@ def from_cli(
choices=[ack_type.name.lower() for ack_type in AcknowledgeType],
help="When to acknowledge message.",
)
parser.add_argument(
"--max-tasks-per-child",
type=int,
default=None,
help="Maximum number of tasks to execute per child process.",
)
parser.add_argument(
"--wait-tasks-timeout",
type=float,
default=None,
help="Maximum time to wait for all current tasks "
"to finish before exiting.",
)

namespace = parser.parse_args(args)
# If there are any patterns specified, remove default.
Expand Down
2 changes: 2 additions & 0 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
max_prefetch=args.max_prefetch,
propagate_exceptions=not args.no_propagate_errors,
ack_type=args.ack_type,
max_tasks_to_execute=args.max_tasks_per_child,
wait_tasks_timeout=args.wait_tasks_timeout,
**receiver_kwargs, # type: ignore
)
loop.run_until_complete(receiver.listen())
Expand Down
14 changes: 14 additions & 0 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(
run_starup: bool = True,
ack_type: Optional[AcknowledgeType] = None,
on_exit: Optional[Callable[["Receiver"], None]] = None,
max_tasks_to_execute: Optional[int] = None,
wait_tasks_timeout: Optional[float] = None,
) -> None:
self.broker = broker
self.executor = executor
Expand All @@ -68,6 +70,8 @@ def __init__(
self.on_exit = on_exit
self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED
self.known_tasks: Set[str] = set()
self.max_tasks_to_execute = max_tasks_to_execute
self.wait_tasks_timeout = wait_tasks_timeout
for task in self.broker.get_all_tasks().values():
self._prepare_task(task.task_name, task.original_func)
self.sem: "Optional[asyncio.Semaphore]" = None
Expand Down Expand Up @@ -342,12 +346,20 @@ async def prefetcher(
:param queue: queue for prefetched data.
"""
fetched_tasks: int = 0
iterator = self.broker.listen()

while True:
try:
await self.sem_prefetch.acquire()
if (
self.max_tasks_to_execute
and fetched_tasks >= self.max_tasks_to_execute
):
logger.info("Max number of tasks executed.")
break
message = await iterator.__anext__()
fetched_tasks += 1
await queue.put(message)
except asyncio.CancelledError:
break
Expand Down Expand Up @@ -389,6 +401,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
self.sem_prefetch.release()
message = await queue.get()
if message is QUEUE_DONE:
logger.info("Waiting for running tasks to complete.")
await asyncio.wait(tasks, timeout=self.wait_tasks_timeout)
break

task = asyncio.create_task(
Expand Down

0 comments on commit 8f74c63

Please sign in to comment.