diff --git a/docs/howto/advanced/cancellation.md b/docs/howto/advanced/cancellation.md index c8bddfe2c..7195de854 100644 --- a/docs/howto/advanced/cancellation.md +++ b/docs/howto/advanced/cancellation.md @@ -69,4 +69,9 @@ async def my_task(context): `context.should_abort()` and `context.should_abort_async()` does poll the database and might flood the database. Ensure you do it only sometimes and not from too many parallel tasks. + +You can use an optional `cache` parameter for limiting the frequency of +database requests. For example, calling `context.should_abort(cache=10)` or +`await context.should_abort_async(cache=10)` will reuse the cached status for +the specified number of seconds without polling the database. ::: diff --git a/docs/reference.rst b/docs/reference.rst index dc6302e0f..eb4d81c97 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -39,7 +39,8 @@ When tasks are created with argument ``pass_context``, they are provided a `JobContext` argument: .. autoclass:: procrastinate.JobContext - :members: app, worker_name, worker_queues, job, task + :members: app, worker_name, worker_queues, job, task, + should_abort, should_abort_async Blueprints ---------- diff --git a/procrastinate/job_context.py b/procrastinate/job_context.py index a35fe5582..0ed9b2aa0 100644 --- a/procrastinate/job_context.py +++ b/procrastinate/job_context.py @@ -49,7 +49,7 @@ class JobContext: Name of the worker (may be useful for logging) worker_queues : ``Optional[Iterable[str]]`` Queues listened by this worker - worker_id : ``int``` + worker_id : ``int`` In case there are multiple async sub-workers, this is the id of the sub-worker. 
job : `Job` Current `Job` instance @@ -65,6 +65,7 @@ class JobContext: task: tasks.Task | None = None job_result: JobResult = attr.ib(factory=JobResult) additional_context: dict = attr.ib(factory=dict) + cache: dict = attr.ib(factory=dict) def log_extra(self, action: str, **kwargs: Any) -> types.JSONDict: extra: types.JSONDict = { @@ -102,20 +103,70 @@ def job_description(self, current_timestamp: float) -> str: return message - def should_abort(self) -> bool: + def should_abort(self, cache: int | None = None) -> bool: + """ + Check if the job should be aborted. + + Parameters + ---------- + cache : ``int``, optional + Cache the job status (in seconds) to reduce the number of database requests + + Returns + ------- + ``bool`` + ``True`` if the job should be aborted, ``False`` otherwise + """ assert self.app assert self.job assert self.job.id job_id = self.job.id - status = self.app.job_manager.get_job_status(job_id) + + if cache is not None: + current_time = time.time() + last_checked_time = self.cache.get("job_status_last_checked") + if last_checked_time is None or current_time - last_checked_time >= cache: + self.cache["job_status_last_checked"] = current_time + self.cache["cached_job_status"] = self.app.job_manager.get_job_status( + self.job.id + ) + status = self.cache["cached_job_status"] + else: + status = self.app.job_manager.get_job_status(job_id) + return status == jobs.Status.ABORTING - async def should_abort_async(self) -> bool: + async def should_abort_async(self, cache: int | None = None) -> bool: + """ + Check if the job should be aborted. 
+ + Parameters + ---------- + cache : ``int``, optional + Cache the job status (in seconds) to reduce the number of database requests + + Returns + ------- + ``bool`` + ``True`` if the job should be aborted, ``False`` otherwise + """ assert self.app assert self.job assert self.job.id job_id = self.job.id - status = await self.app.job_manager.get_job_status_async(job_id) + + if cache is not None: + current_time = time.time() + last_checked_time = self.cache.get("job_status_last_checked") + if last_checked_time is None or current_time - last_checked_time >= cache: + self.cache["job_status_last_checked"] = current_time + self.cache[ + "cached_job_status" + ] = await self.app.job_manager.get_job_status_async(job_id) + status = self.cache["cached_job_status"] + else: + status = await self.app.job_manager.get_job_status_async(job_id) + return status == jobs.Status.ABORTING diff --git a/tests/unit/test_job_context.py b/tests/unit/test_job_context.py index 0d3a5fe0d..59bfe41e7 100644 --- a/tests/unit/test_job_context.py +++ b/tests/unit/test_job_context.py @@ -2,7 +2,7 @@ import pytest -from procrastinate import job_context +from procrastinate import job_context, jobs @pytest.mark.parametrize( @@ -122,3 +122,59 @@ async def test_should_not_abort(app, job_factory): context = job_context.JobContext(app=app, job=job) assert context.should_abort() is False assert await context.should_abort_async() is False + + +async def test_should_abort_with_cache(app, job_factory, mocker): + app.job_manager.get_job_status = mocker.Mock(return_value=jobs.Status.DOING) + await app.job_manager.defer_job_async(job=job_factory()) + job = await app.job_manager.fetch_job(queues=None) + context = job_context.JobContext(app=app, job=job) + assert "job_status_last_checked" not in context.cache + assert "cached_job_status" not in context.cache + + mocker.patch("procrastinate.job_context.time.time", return_value=1000000000) + assert context.should_abort(cache=10) is False + assert 
context.cache["job_status_last_checked"] == 1000000000 + assert context.cache["cached_job_status"] == jobs.Status.DOING + + mocker.patch("procrastinate.job_context.time.time", return_value=1000000005) + assert context.should_abort(cache=10) is False + assert context.cache["job_status_last_checked"] == 1000000000 + assert context.cache["cached_job_status"] == jobs.Status.DOING + app.job_manager.get_job_status.assert_called_once_with(job.id) + app.job_manager.get_job_status.reset_mock() + + mocker.patch("procrastinate.job_context.time.time", return_value=1000000010) + assert context.should_abort(cache=10) is False + assert context.cache["job_status_last_checked"] == 1000000010 + assert context.cache["cached_job_status"] == jobs.Status.DOING + app.job_manager.get_job_status.assert_called_once_with(job.id) + + +async def test_should_abort_async_with_cache(app, job_factory, mocker): + app.job_manager.get_job_status_async = mocker.AsyncMock( + return_value=jobs.Status.DOING + ) + await app.job_manager.defer_job_async(job=job_factory()) + job = await app.job_manager.fetch_job(queues=None) + context = job_context.JobContext(app=app, job=job) + assert "job_status_last_checked" not in context.cache + assert "cached_job_status" not in context.cache + + mocker.patch("procrastinate.job_context.time.time", return_value=1000000000) + assert await context.should_abort_async(cache=10) is False + assert context.cache["job_status_last_checked"] == 1000000000 + assert context.cache["cached_job_status"] == jobs.Status.DOING + + mocker.patch("procrastinate.job_context.time.time", return_value=1000000005) + assert await context.should_abort_async(cache=10) is False + assert context.cache["job_status_last_checked"] == 1000000000 + assert context.cache["cached_job_status"] == jobs.Status.DOING + app.job_manager.get_job_status_async.assert_awaited_once_with(job.id) + app.job_manager.get_job_status_async.reset_mock() + + mocker.patch("procrastinate.job_context.time.time", 
return_value=1000000010) + assert await context.should_abort_async(cache=10) is False + assert context.cache["job_status_last_checked"] == 1000000010 + assert context.cache["cached_job_status"] == jobs.Status.DOING + app.job_manager.get_job_status_async.assert_awaited_once_with(job.id)