Abort async tasks using asyncio #1190

Merged: 5 commits, Sep 12, 2024
43 changes: 37 additions & 6 deletions docs/howto/advanced/cancellation.md
@@ -37,9 +37,19 @@ app.job_manager.cancel_job_by_id(33, abort=True)
await app.job_manager.cancel_job_by_id_async(33, abort=True)
```

Behind the scenes, the worker receives a Postgres notification every time a job is requested to abort (unless `listen_notify=False`).

The worker also polls (respecting `polling_interval`) the database for abortion requests, as long as the worker is running at least one job (in the absence of a running job, there is nothing to abort).
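As a rough sketch of how those two mechanisms are configured when starting a worker (the option names `listen_notify` and `polling_interval` come from the paragraphs above; the queue name and interval value are illustrative):

```python
# Sketch only: listen_notify and polling_interval are the options described
# above; "default" and 5.0 are illustrative values.
await app.run_worker_async(
    queues=["default"],
    listen_notify=True,    # receive Postgres notifications about abort requests
    polling_interval=5.0,  # poll for abort requests while at least one job runs
)
```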

:::{note}
When a job is requested to abort and that job fails, it will not be retried (regardless of the retry strategy).
:::

## Handle an abortion request inside the task

In our task, we can check (for example, periodically) if the task should be
### Sync tasks

In a sync task, we can check (for example, periodically) if the task should be
aborted. If we want to respect that abortion request (we don't have to), we raise a
`JobAborted` error. Any message passed to `JobAborted` (e.g.
`raise JobAborted("custom message")`) will end up in the logs.
@@ -53,10 +63,31 @@ def my_task(context):
        do_something_expensive()
```
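The hunk above shows only a fragment of the sync example. Filled out along the lines the docs describe, it could look like this (a sketch; `do_something_expensive` stands in for your actual work):

```python
from procrastinate.exceptions import JobAborted

@app.task(pass_context=True)
def my_task(context):
    for _ in range(100):
        # periodically check whether an abort was requested
        if context.should_abort():
            raise JobAborted("aborted between two work units")
        do_something_expensive()
```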

Behind the scenes, the worker receives a Postgres notification every time a job is requested to abort (unless `listen_notify=False`).
### Async tasks

The worker also polls (respecting `polling_interval`) the database for abortion requests, as long as the worker is running at least one job (in the absence of a running job, there is nothing to abort).
Async tasks (coroutines) are cancelled via the [asyncio cancellation](https://docs.python.org/3/library/asyncio-task.html#task-cancellation) mechanism.

:::{note}
When a job is requested to abort and that job fails, it will not be retried (regardless of the retry strategy).
:::
```python
@app.task()
async def my_task():
    do_something_synchronous()
    # if the job is aborted while it waits for do_something to complete,
    # asyncio.CancelledError will be raised here
    await do_something()
```

If you want custom behavior at cancellation time, combine [shielding](https://docs.python.org/3/library/asyncio-task.html#shielding-from-cancellation) with catching `asyncio.CancelledError`.

```python
@app.task()
async def my_task():
    try:
        important_task = asyncio.create_task(something_important())
        # shield something_important from being cancelled
        await asyncio.shield(important_task)
    except asyncio.CancelledError:
        # catch the error and wait for something_important to complete
        await important_task
        # re-raise if the job should be marked as aborted, or swallow the
        # CancelledError if the job should be marked as succeeded
        raise
```
55 changes: 43 additions & 12 deletions procrastinate/worker.py
@@ -119,6 +119,7 @@ async def _persist_job_status(
        job: jobs.Job,
        status: jobs.Status,
        retry_decision: retry.RetryDecision | None,
        context: job_context.JobContext,
    ):
        if retry_decision:
            await self.app.job_manager.retry_job(
@@ -138,6 +139,13 @@
                job=job, status=status, delete_job=delete_job
            )

        self._job_ids_to_abort.discard(job.id)

        self.logger.debug(
            f"Acknowledged job completion {job.call_string}",
            extra=self._log_extra(action="finish_task", context=context, status=status),
        )

    def _log_job_outcome(
        self,
        status: jobs.Status,
@@ -257,18 +265,20 @@ async def ensure_async() -> Callable[..., Awaitable]:
                job_retry=job_retry,
                exc_info=exc_info,
            )
            await self._persist_job_status(
                job=job, status=status, retry_decision=retry_decision
            )

            self._job_ids_to_abort.discard(job.id)

            self.logger.debug(
                f"Acknowledged job completion {job.call_string}",
                extra=self._log_extra(
                    action="finish_task", context=context, status=status
                ),
            )
            persist_job_status_task = asyncio.create_task(
                self._persist_job_status(
                    job=job,
                    status=status,
                    retry_decision=retry_decision,
                    context=context,
                )
            )
            try:
                await asyncio.shield(persist_job_status_task)
            except asyncio.CancelledError:
                await persist_job_status_task
                raise

    async def _fetch_and_process_jobs(self):
        """Fetch and process jobs until there is no job left or asked to stop"""
@@ -372,8 +382,29 @@ async def _poll_jobs_to_abort(self):

    def _handle_abort_jobs_requested(self, job_ids: Iterable[int]):
        running_job_ids = {c.job.id for c in self._running_jobs.values() if c.job.id}
        self._job_ids_to_abort |= set(job_ids)
        self._job_ids_to_abort &= running_job_ids
        new_job_ids_to_abort = (running_job_ids & set(job_ids)) - self._job_ids_to_abort

        for process_job_task, context in self._running_jobs.items():
            if context.job.id in new_job_ids_to_abort:
                self._abort_job(process_job_task, context)

    def _abort_job(
        self, process_job_task: asyncio.Task, context: job_context.JobContext
    ):
        self._job_ids_to_abort.add(context.job.id)

        log_message: str
        if not context.task:
            log_message = "Received a request to abort a job but the job has no associated task. No action to perform"
        elif not asyncio.iscoroutinefunction(context.task.func):
            log_message = "Received a request to abort a synchronous job. Job is responsible for aborting by checking context.should_abort"
        else:
            log_message = "Received a request to abort an asynchronous job. Cancelling asyncio task"
            process_job_task.cancel()

        self.logger.debug(
            log_message, extra=self._log_extra(action="abort_job", context=context)
        )

    async def _shutdown(self, side_tasks: list[asyncio.Task]):
        """
9 changes: 3 additions & 6 deletions tests/acceptance/test_async.py
@@ -110,12 +110,9 @@ def example_task():

@pytest.mark.parametrize("mode", ["listen", "poll"])
async def test_abort_async_task(async_app: app_module.App, mode):
    @async_app.task(queue="default", name="task1", pass_context=True)
    async def task1(context):
        while True:
            await asyncio.sleep(0.02)
            if context.should_abort():
                raise JobAborted
    @async_app.task(queue="default", name="task1")
    async def task1():
        await asyncio.sleep(0.5)

    job_id = await task1.defer_async()

70 changes: 69 additions & 1 deletion tests/unit/test_worker.py
@@ -5,6 +5,7 @@
from typing import cast

import pytest
from pytest_mock import MockerFixture

from procrastinate.app import App
from procrastinate.exceptions import JobAborted
@@ -509,7 +510,7 @@ def task_func(a, b):

    assert "to retry" not in record.message

async def test_run_job_aborted(app: App, worker, caplog):
async def test_run_job_raising_job_aborted(app: App, worker, caplog):
    caplog.set_level("INFO")

    @app.task(queue="yay", name="task_func")
@@ -530,6 +531,73 @@ async def task_func():
    assert "Aborted" in record.message


async def test_abort_async_job(app: App, worker):
    @app.task(queue="yay", name="task_func")
    async def task_func():
        await asyncio.sleep(0.2)

    job_id = await task_func.defer_async()

    await start_worker(worker)
    await app.job_manager.cancel_job_by_id_async(job_id, abort=True)
    await asyncio.sleep(0.01)
    status = await app.job_manager.get_job_status_async(job_id)
    assert status == Status.ABORTED


async def test_abort_async_job_while_finishing(app: App, worker, mocker: MockerFixture):
    """
    Tests that aborting a job after that job completes, but before the job status
    is updated, does not prevent the job status from being updated
    """
    connector = cast(InMemoryConnector, app.connector)
    original_finish_job_run = connector.finish_job_run

    complete_finish_job_event = asyncio.Event()

    async def delayed_finish_job_run(**arguments):
        await complete_finish_job_event.wait()
        return await original_finish_job_run(**arguments)

    connector.finish_job_run = mocker.AsyncMock(name="finish_job_run")
    connector.finish_job_run.side_effect = delayed_finish_job_run

    @app.task(queue="yay", name="task_func")
    async def task_func():
        pass

    job_id = await task_func.defer_async()

    await start_worker(worker)
    await app.job_manager.cancel_job_by_id_async(job_id, abort=True)
    await asyncio.sleep(0.01)
    complete_finish_job_event.set()
    await asyncio.sleep(0.01)
    status = await app.job_manager.get_job_status_async(job_id)
    assert status == Status.SUCCEEDED


async def test_abort_async_job_preventing_cancellation(app: App, worker):
    """
    Tests that an async job can prevent itself from being aborted
    """

    @app.task(queue="yay", name="task_func")
    async def task_func():
        try:
            await asyncio.sleep(0.2)
        except asyncio.CancelledError:
            pass

    job_id = await task_func.defer_async()

    await start_worker(worker)
    await app.job_manager.cancel_job_by_id_async(job_id, abort=True)
    await asyncio.sleep(0.01)
    status = await app.job_manager.get_job_status_async(job_id)
    assert status == Status.SUCCEEDED


@pytest.mark.parametrize(
    "worker",
    [