diff --git a/docs/howto/advanced/cancellation.md b/docs/howto/advanced/cancellation.md index d5d924bb7..4726294ee 100644 --- a/docs/howto/advanced/cancellation.md +++ b/docs/howto/advanced/cancellation.md @@ -37,9 +37,19 @@ app.job_manager.cancel_job_by_id(33, abort=True) await app.job_manager.cancel_job_by_id_async(33, abort=True) ``` +Behind the scenes, the worker receives a Postgres notification every time a job is requested to abort (unless `listen_notify=False`). + +The worker also polls (respecting `polling_interval`) the database for abortion requests, as long as the worker is running at least one job (in the absence of a running job, there is nothing to abort). + +:::{note} +When a job is requested to abort and that job fails, it will not be retried (regardless of the retry strategy). +::: + ## Handle an abortion request inside the task -In our task, we can check (for example, periodically) if the task should be +### Sync tasks + +In a sync task, we can check (for example, periodically) if the task should be aborted. If we want to respect that abortion request (we don't have to), we raise a `JobAborted` error. Any message passed to `JobAborted` (e.g. `raise JobAborted("custom message")`) will end up in the logs. @@ -53,10 +63,31 @@ def my_task(context): do_something_expensive() ``` -Behind the scenes, the worker receives a Postgres notification every time a job is requested to abort, (unless `listen_notify=False`). +### Async tasks -The worker also polls (respecting `polling_interval`) the database for abortion requests, as long as the worker is running at least one job (in the absence of running job, there is nothing to abort). +Async tasks (coroutines) are cancelled via the [asyncio cancellation](https://docs.python.org/3/library/asyncio-task.html#task-cancellation) mechanism. -:::{note} -When a job is requested to abort and that job fails, it will not be retried (regardless of the retry strategy). 
-::: +```python +@app.task() +async def my_task(): + do_something_synchronous() + # if the job is aborted while it waits for do_something to complete, asyncio.CancelledError will be raised here + await do_something() +``` + +If you want to have some custom behavior at cancellation time, use a combination of [shielding](https://docs.python.org/3/library/asyncio-task.html#shielding-from-cancellation) and capturing `except asyncio.CancelledError`. + +```python +@app.task() +async def my_task(): + try: + important_task = asyncio.create_task(something_important()) + # shield something_important from being cancelled + await asyncio.shield(important_task) + except asyncio.CancelledError: + # capture the error and wait for something important to complete + await important_task + # raise if the job should be marked as aborted, or swallow CancelledError if the job should be + # marked as succeeded + raise +``` diff --git a/procrastinate/worker.py b/procrastinate/worker.py index 6c4fd2d86..52e700bbf 100644 --- a/procrastinate/worker.py +++ b/procrastinate/worker.py @@ -119,6 +119,7 @@ async def _persist_job_status( job: jobs.Job, status: jobs.Status, retry_decision: retry.RetryDecision | None, + context: job_context.JobContext, ): if retry_decision: await self.app.job_manager.retry_job( @@ -138,6 +139,13 @@ async def _persist_job_status( job=job, status=status, delete_job=delete_job ) + self._job_ids_to_abort.discard(job.id) + + self.logger.debug( + f"Acknowledged job completion {job.call_string}", + extra=self._log_extra(action="finish_task", context=context, status=status), + ) + def _log_job_outcome( self, status: jobs.Status, @@ -257,18 +265,20 @@ async def ensure_async() -> Callable[..., Awaitable]: job_retry=job_retry, exc_info=exc_info, ) - await self._persist_job_status( - job=job, status=status, retry_decision=retry_decision - ) - - self._job_ids_to_abort.discard(job.id) - self.logger.debug( - f"Acknowledged job completion {job.call_string}", - 
extra=self._log_extra( - action="finish_task", context=context, status=status - ), + persist_job_status_task = asyncio.create_task( + self._persist_job_status( + job=job, + status=status, + retry_decision=retry_decision, + context=context, + ) ) + try: + await asyncio.shield(persist_job_status_task) + except asyncio.CancelledError: + await persist_job_status_task + raise async def _fetch_and_process_jobs(self): """Fetch and process jobs until there is no job left or asked to stop""" @@ -372,8 +382,29 @@ async def _poll_jobs_to_abort(self): def _handle_abort_jobs_requested(self, job_ids: Iterable[int]): running_job_ids = {c.job.id for c in self._running_jobs.values() if c.job.id} - self._job_ids_to_abort |= set(job_ids) - self._job_ids_to_abort &= running_job_ids + new_job_ids_to_abort = (running_job_ids & set(job_ids)) - self._job_ids_to_abort + + for process_job_task, context in self._running_jobs.items(): + if context.job.id in new_job_ids_to_abort: + self._abort_job(process_job_task, context) + + def _abort_job( + self, process_job_task: asyncio.Task, context: job_context.JobContext + ): + self._job_ids_to_abort.add(context.job.id) + + log_message: str + if not context.task: + log_message = "Received a request to abort a job but the job has no associated task. No action to perform" + elif not asyncio.iscoroutinefunction(context.task.func): + log_message = "Received a request to abort a synchronous job. Job is responsible for aborting by checking context.should_abort" + else: + log_message = "Received a request to abort an asynchronous job. 
Cancelling asyncio task" + process_job_task.cancel() + + self.logger.debug( + log_message, extra=self._log_extra(action="abort_job", context=context) + ) async def _shutdown(self, side_tasks: list[asyncio.Task]): """ diff --git a/tests/acceptance/test_async.py b/tests/acceptance/test_async.py index ce4c9109a..c854917d3 100644 --- a/tests/acceptance/test_async.py +++ b/tests/acceptance/test_async.py @@ -110,12 +110,9 @@ def example_task(): @pytest.mark.parametrize("mode", ["listen", "poll"]) async def test_abort_async_task(async_app: app_module.App, mode): - @async_app.task(queue="default", name="task1", pass_context=True) - async def task1(context): - while True: - await asyncio.sleep(0.02) - if context.should_abort(): - raise JobAborted + @async_app.task(queue="default", name="task1") + async def task1(): + await asyncio.sleep(0.5) job_id = await task1.defer_async() diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py index 72a6101b3..f9913836c 100644 --- a/tests/unit/test_worker.py +++ b/tests/unit/test_worker.py @@ -5,6 +5,7 @@ from typing import cast import pytest +from pytest_mock import MockerFixture from procrastinate.app import App from procrastinate.exceptions import JobAborted @@ -509,7 +510,7 @@ def task_func(a, b): assert "to retry" not in record.message -async def test_run_job_aborted(app: App, worker, caplog): +async def test_run_job_raising_job_aborted(app: App, worker, caplog): caplog.set_level("INFO") @app.task(queue="yay", name="task_func") @@ -530,6 +531,73 @@ async def task_func(): assert "Aborted" in record.message +async def test_abort_async_job(app: App, worker): + @app.task(queue="yay", name="task_func") + async def task_func(): + await asyncio.sleep(0.2) + + job_id = await task_func.defer_async() + + await start_worker(worker) + await app.job_manager.cancel_job_by_id_async(job_id, abort=True) + await asyncio.sleep(0.01) + status = await app.job_manager.get_job_status_async(job_id) + assert status == Status.ABORTED + + +async 
def test_abort_async_job_while_finishing(app: App, worker, mocker: MockerFixture): + """ + Tests that aborting a job after that job completes but before the job status is updated + does not prevent the job status from being updated + """ + connector = cast(InMemoryConnector, app.connector) + original_finish_job_run = connector.finish_job_run + + complete_finish_job_event = asyncio.Event() + + async def delayed_finish_job_run(**arguments): + await complete_finish_job_event.wait() + return await original_finish_job_run(**arguments) + + connector.finish_job_run = mocker.AsyncMock(name="finish_job_run") + connector.finish_job_run.side_effect = delayed_finish_job_run + + @app.task(queue="yay", name="task_func") + async def task_func(): + pass + + job_id = await task_func.defer_async() + + await start_worker(worker) + await app.job_manager.cancel_job_by_id_async(job_id, abort=True) + await asyncio.sleep(0.01) + complete_finish_job_event.set() + await asyncio.sleep(0.01) + status = await app.job_manager.get_job_status_async(job_id) + assert status == Status.SUCCEEDED + + +async def test_abort_async_job_preventing_cancellation(app: App, worker): + """ + Tests that an async job can prevent itself from being aborted + """ + + @app.task(queue="yay", name="task_func") + async def task_func(): + try: + await asyncio.sleep(0.2) + except asyncio.CancelledError: + pass + + job_id = await task_func.defer_async() + + await start_worker(worker) + await app.job_manager.cancel_job_by_id_async(job_id, abort=True) + await asyncio.sleep(0.01) + status = await app.job_manager.get_job_status_async(job_id) + assert status == Status.SUCCEEDED + + @pytest.mark.parametrize( "worker", [