Support for custom task runners (#175)

Allow specifying a Python class to influence task running behavior.
thomasst authored Feb 16, 2021
1 parent 6064c95 commit 6d96607
Showing 7 changed files with 223 additions and 14 deletions.
23 changes: 23 additions & 0 deletions README.rst
@@ -337,6 +337,28 @@ The following options are supported by both ``delay`` and the task decorator:
For example, to retry a task 3 times (for a total of 4 executions), and wait
60 seconds between executions, pass ``retry_method=fixed(60, 3)``.

- ``runner_class``

If given, specifies a Python class that controls how the task is run. The
runner class should inherit from ``tasktiger.runner.BaseRunner`` and
implement the task execution behavior. The default implementation is
available in ``tasktiger.runner.DefaultRunner``. The following behavior can
be achieved:

- Execute specific code before or after the task is executed (in the forked
child process), or customize the way task functions are called in either
single or batch processing.

Note that if you want to execute specific code for all tasks,
you should use the ``CHILD_CONTEXT_MANAGERS`` configuration option.

- Control the hard timeout behavior of a task.

- Execute specific code in the main worker process after a task fails
permanently.

This is an advanced feature and the interface and requirements of the runner
class can change in future TaskTiger versions.
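
For illustration, a minimal sketch (not part of this commit) of a custom
runner that wraps the default single-task execution; ``LoggingRunner`` and
``my_task`` are hypothetical names, and ``tiger`` is assumed to be a
configured ``TaskTiger`` instance:

.. code:: python

    from tasktiger.runner import DefaultRunner

    class LoggingRunner(DefaultRunner):
        def run_single_task(self, task, hard_timeout):
            # Runs in the forked child process around the task call.
            print('starting %s' % task.serialized_func)
            super(LoggingRunner, self).run_single_task(task, hard_timeout)
            print('finished %s' % task.serialized_func)

    @tiger.task(runner_class=LoggingRunner)
    def my_task():
        pass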

The following options can only be specified in the task decorator:

- ``batch``
@@ -408,6 +430,7 @@ Example usage:
.. code:: python

    from tasktiger.exceptions import RetryException
    from tasktiger.retry import exponential, fixed

    def my_task():
        if not ready():
90 changes: 90 additions & 0 deletions tasktiger/runner.py
@@ -0,0 +1,90 @@
from ._internal import import_attribute
from .exceptions import TaskImportError
from .timeouts import UnixSignalDeathPenalty


class BaseRunner:
    """
    Base implementation of the task runner.
    """

    def __init__(self, tiger):
        self.tiger = tiger

    def run_single_task(self, task, hard_timeout):
        """
        Run the given task using the hard timeout in seconds.
        This is called inside of the forked process.
        """
        raise NotImplementedError("Single tasks are not supported.")

    def run_batch_tasks(self, tasks, hard_timeout):
        """
        Run the given tasks using the hard timeout in seconds.
        This is called inside of the forked process.
        """
        raise NotImplementedError("Batch tasks are not supported.")

    def run_eager_task(self, task):
        """
        Run the task eagerly and return the value.
        Note that the task function could be a batch function.
        """
        raise NotImplementedError("Eager tasks are not supported.")

    def on_permanent_error(self, task, execution):
        """
        Called if the task fails permanently.
        A task fails permanently if its status is set to ERROR and it is no
        longer retried.
        This is called in the main worker process.
        """


class DefaultRunner(BaseRunner):
    """
    Default implementation of the task runner.
    """

    def run_single_task(self, task, hard_timeout):
        with UnixSignalDeathPenalty(hard_timeout):
            task.func(*task.args, **task.kwargs)

    def run_batch_tasks(self, tasks, hard_timeout):
        params = [{'args': task.args, 'kwargs': task.kwargs} for task in tasks]
        func = tasks[0].func
        with UnixSignalDeathPenalty(hard_timeout):
            func(params)

    def run_eager_task(self, task):
        func = task.func
        is_batch_func = getattr(func, '_task_batch', False)

        if is_batch_func:
            return func([{'args': task.args, 'kwargs': task.kwargs}])
        else:
            return func(*task.args, **task.kwargs)


def get_runner_class(log, tasks):
    runner_class_paths = {task.serialized_runner_class for task in tasks}
    if len(runner_class_paths) > 1:
        log.error(
            "cannot mix multiple runner classes",
            runner_class_paths=", ".join(str(p) for p in runner_class_paths),
        )
        raise ValueError("Found multiple runner classes in batch task.")

    runner_class_path = runner_class_paths.pop()
    if runner_class_path:
        try:
            return import_attribute(runner_class_path)
        except TaskImportError:
            log.error(
                'could not import runner class',
                runner_class_path=runner_class_path,
            )
            raise
    return DefaultRunner
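
For illustration (not part of this commit), a sketch of how the eager hook
above is exercised when ``ALWAYS_EAGER`` is enabled, mirroring the new
``test_eager_task`` test; ``MyEagerRunner`` and ``my_task`` are hypothetical
names and ``tiger`` is assumed to be a configured ``TaskTiger`` instance:

    from tasktiger import Task
    from tasktiger.runner import BaseRunner

    class MyEagerRunner(BaseRunner):
        def run_eager_task(self, task):
            # Called instead of normal execution when ALWAYS_EAGER is set.
            return 123

    tiger.config['ALWAYS_EAGER'] = True
    task = Task(tiger, my_task, runner_class=MyEagerRunner)
    assert task.delay() == 123
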
21 changes: 16 additions & 5 deletions tasktiger/task.py
@@ -5,7 +5,8 @@
import time

from ._internal import *
from .exceptions import QueueFullException, TaskNotFound
from .exceptions import QueueFullException, TaskImportError, TaskNotFound
from .runner import get_runner_class

__all__ = ['Task']

@@ -26,6 +27,7 @@ def __init__(
        retry_on=None,
        retry_method=None,
        max_queue_size=None,
        runner_class=None,
        # internal variables
        _data=None,
        _state=None,
@@ -76,6 +78,9 @@ def __init__(
        if max_queue_size is None:
            max_queue_size = getattr(func, '_task_max_queue_size', None)

        if runner_class is None:
            runner_class = getattr(func, '_task_runner_class', None)

        # normalize falsy args/kwargs to empty structures
        args = args or []
        kwargs = kwargs or {}
@@ -110,6 +115,9 @@ def __init__(
            ]
        if max_queue_size:
            task['max_queue_size'] = max_queue_size
        if runner_class:
            serialized_runner_class = serialize_func_name(runner_class)
            task['runner_class'] = serialized_runner_class

        self._data = task

@@ -191,6 +199,10 @@ def func(self):
            self._func = import_attribute(self.serialized_func)
        return self._func

    @property
    def serialized_runner_class(self):
        return self._data.get('runner_class')

    @property
    def ts(self):
        """
@@ -298,10 +310,9 @@ def execute(self):
        g['tiger'] = self.tiger

        try:
            if is_batch_func:
                return func([{'args': self.args, 'kwargs': self.kwargs}])
            else:
                return func(*self.args, **self.kwargs)
            runner_class = get_runner_class(self.tiger.log, [self])
            runner = runner_class(self.tiger)
            return runner.run_eager_task(self)
        finally:
            g['current_task_is_batch'] = None
            g['current_tasks'] = None
5 changes: 5 additions & 0 deletions tasktiger/tasktiger.py
@@ -276,6 +276,7 @@ def task(
        schedule=None,
        batch=False,
        max_queue_size=None,
        runner_class=None,
    ):
        """
        Function decorator that defines the behavior of the function when it is
@@ -318,6 +319,8 @@ def _wrap(func):
            func._task_schedule = schedule
            if max_queue_size is not None:
                func._task_max_queue_size = max_queue_size
            if runner_class is not None:
                func._task_runner_class = runner_class

            func.delay = _delay(func)

@@ -389,6 +392,7 @@ def delay(
        retry_on=None,
        retry_method=None,
        max_queue_size=None,
        runner_class=None,
    ):
        """
        Queues a task. See README.rst for an explanation of the options.
@@ -407,6 +411,7 @@
            retry=retry,
            retry_on=retry_on,
            retry_method=retry_method,
            runner_class=runner_class,
        )

        task.delay(when=when, max_queue_size=max_queue_size)
20 changes: 11 additions & 9 deletions tasktiger/worker.py
@@ -20,9 +20,10 @@
from .exceptions import RetryException, TaskNotFound
from .redis_semaphore import Semaphore
from .retry import *
from .runner import get_runner_class
from .stats import StatsThread
from .task import Task
from .timeouts import UnixSignalDeathPenalty, JobTimeoutException
from .timeouts import JobTimeoutException

if sys.version_info < (3, 3):
    from contextlib2 import ExitStack
@@ -362,6 +363,9 @@ def _execute_forked(self, tasks, log):
        try:
            func = tasks[0].func

            runner_class = get_runner_class(log, tasks)
            runner = runner_class(self.tiger)

            is_batch_func = getattr(func, '_task_batch', False)
            g['tiger'] = self.tiger
            g['current_task_is_batch'] = is_batch_func
@@ -371,10 +375,6 @@
            ):
                if is_batch_func:
                    # Batch process if the task supports it.
                    params = [
                        {'args': task.args, 'kwargs': task.kwargs}
                        for task in tasks
                    ]
                    task_timeouts = [
                        task.hard_timeout
                        for task in tasks
@@ -387,8 +387,7 @@
                    )

                    g['current_tasks'] = tasks
                    with UnixSignalDeathPenalty(hard_timeout):
                        func(params)
                    runner.run_batch_tasks(tasks, hard_timeout)

                else:
                    # Process sequentially.
@@ -400,8 +399,7 @@
                        )

                        g['current_tasks'] = [task]
                        with UnixSignalDeathPenalty(hard_timeout):
                            func(*task.args, **task.kwargs)
                        runner.run_single_task(task, hard_timeout)

        except RetryException as exc:
            execution['retry'] = True
@@ -1015,6 +1013,10 @@ def _mark_done():
            _mark_done()
        else:
            task._move(from_state=ACTIVE, to_state=state, when=when)
            if state == ERROR and task.serialized_runner_class:
                runner_class = get_runner_class(log, [task])
                runner = runner_class(self.tiger)
                runner.on_permanent_error(task, execution)

    def _worker_run(self):
        """
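
For illustration (not part of this commit), a sketch of the permanent-failure
hook wired up above: ``on_permanent_error`` runs in the main worker process
once a task's state is set to ERROR and it will no longer be retried.
``AlertingRunner``, ``notify_ops``, and ``risky_task`` are hypothetical names,
and ``tiger`` is assumed to be a configured ``TaskTiger`` instance:

    from tasktiger.runner import DefaultRunner

    class AlertingRunner(DefaultRunner):
        def on_permanent_error(self, task, execution):
            # Runs in the main worker process, not in the forked child.
            notify_ops(
                'task failed permanently',
                task_id=task.id,
                exception=execution.get('exception_name'),
            )

    @tiger.task(runner_class=AlertingRunner)
    def risky_task():
        pass
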
36 changes: 36 additions & 0 deletions tests/tasks.py
@@ -6,6 +6,7 @@

from tasktiger import RetryException, TaskTiger
from tasktiger.retry import fixed
from tasktiger.runner import BaseRunner, DefaultRunner

from .config import DELAY, TEST_DB, REDIS_HOST
from .utils import get_tiger
@@ -179,3 +180,38 @@ class StaticTask(object):
    @staticmethod
    def task():
        pass


class MyRunnerClass(BaseRunner):
    def run_single_task(self, task, hard_timeout):
        assert self.tiger.config == tiger.config
        assert hard_timeout == 300
        assert task.func is simple_task

        with redis.Redis(
            host=REDIS_HOST, db=TEST_DB, decode_responses=True
        ) as conn:
            conn.set('task_id', task.id)

    def run_batch_tasks(self, tasks, hard_timeout):
        assert self.tiger.config == tiger.config
        assert hard_timeout == 300
        assert len(tasks) == 2

        with redis.Redis(
            host=REDIS_HOST, db=TEST_DB, decode_responses=True
        ) as conn:
            conn.set('task_args', ",".join(str(t.args[0]) for t in tasks))

    def run_eager_task(self, task):
        return 123


class MyErrorRunnerClass(DefaultRunner):
    def on_permanent_error(self, task, execution):
        assert task.func is exception_task
        assert execution["exception_name"] == "builtins:Exception"
        with redis.Redis(
            host=REDIS_HOST, db=TEST_DB, decode_responses=True
        ) as conn:
            conn.set('task_id', task.id)
42 changes: 42 additions & 0 deletions tests/test_base.py
@@ -32,6 +32,8 @@
    locked_task,
    long_task_killed,
    long_task_ok,
    MyErrorRunnerClass,
    MyRunnerClass,
    non_batch_task,
    retry_task,
    retry_task_2,
@@ -1236,3 +1238,43 @@ def fake_error(msg):
        self._ensure_queues()
        assert len(errors) == 1
        assert "not found" in errors[0]


class TestRunnerClass(BaseTestCase):
    def test_custom_runner_class_single_task(self):
        task = self.tiger.delay(simple_task, runner_class=MyRunnerClass)
        Worker(self.tiger).run(once=True)
        assert self.conn.get('task_id') == task.id
        self.conn.delete('task_id')
        self._ensure_queues()

    def test_custom_runner_class_batch_task(self):
        self.tiger.delay(batch_task, args=[1], runner_class=MyRunnerClass)
        self.tiger.delay(batch_task, args=[2], runner_class=MyRunnerClass)
        Worker(self.tiger).run(once=True)
        assert self.conn.get('task_args') == "1,2"
        self.conn.delete('task_args')
        self._ensure_queues()

    def test_mixed_runner_class_batch_task(self):
        """Ensure all tasks in a batch task must have the same runner class."""
        self.tiger.delay(batch_task, args=[1], runner_class=MyRunnerClass)
        self.tiger.delay(batch_task, args=[2])
        Worker(self.tiger).run(once=True)
        assert self.conn.get('task_args') is None
        self._ensure_queues(error={'batch': 2})

    def test_permanent_error(self):
        task = self.tiger.delay(
            exception_task, runner_class=MyErrorRunnerClass
        )
        Worker(self.tiger).run(once=True)
        assert self.conn.get('task_id') == task.id
        self.conn.delete('task_id')
        self._ensure_queues(error={'default': 1})

    def test_eager_task(self):
        self.tiger.config['ALWAYS_EAGER'] = True
        task = Task(self.tiger, simple_task, runner_class=MyRunnerClass)
        assert task.delay() == 123
        self._ensure_queues()
