From b6506ad7ba8abe314dfd648791e64b2af2f1b907 Mon Sep 17 00:00:00 2001 From: Ritesh Kadmawala Date: Tue, 31 Oct 2017 19:56:18 +0530 Subject: [PATCH] feat: Allow custom serializer and deserializers for task --- tasktiger/__init__.py | 10 ++++++++++ tasktiger/task.py | 13 ++++++------- tasktiger/test_helpers.py | 29 +++++++++++++++++++++++++++++ tasktiger/worker.py | 7 +++---- tests/test_base.py | 36 +++++++++++++++++++++++++++++++++++- tests/utils.py | 9 ++++++--- 6 files changed, 89 insertions(+), 15 deletions(-) diff --git a/tasktiger/__init__.py b/tasktiger/__init__.py index 2c81d82a..882f0329 100644 --- a/tasktiger/__init__.py +++ b/tasktiger/__init__.py @@ -159,6 +159,13 @@ def __init__(self, connection=None, config=None, setup_structlog=False): # If non-empty, a worker excludes the given queues from processing. 'EXCLUDE_QUEUES': [], + + # Serializer / Deserializer to use for serializing/deserializing tasks + + 'SERIALIZER': json.dumps, + + 'DESERIALIZER': json.loads + } if config: self.config.update(config) @@ -193,6 +200,9 @@ def __init__(self, connection=None, config=None, setup_structlog=False): # List of task functions that are executed periodically. self.periodic_task_funcs = {} + self._serialize = self.config['SERIALIZER'] + self._deserialize = self.config['DESERIALIZER'] + def _get_current_task(self): if g['current_tasks'] is None: raise RuntimeError('Must be accessed from within a task') diff --git a/tasktiger/task.py b/tasktiger/task.py index ca400e5b..d8b8b941 100644 --- a/tasktiger/task.py +++ b/tasktiger/task.py @@ -1,5 +1,4 @@ import datetime -import json import redis import time @@ -280,7 +279,7 @@ def delay(self, when=None): # When using ALWAYS_EAGER, make sure we have serialized the task to # ensure there are no serialization errors. 
- serialized_task = json.dumps(self._data) + serialized_task = self.tiger._serialize(self._data) if tiger.config['ALWAYS_EAGER'] and state == QUEUED: return self.execute() @@ -341,8 +340,8 @@ def from_id(self, tiger, queue, state, task_id, load_executions=0): serialized_executions = [] # XXX: No timestamp for now if serialized_data: - data = json.loads(serialized_data) - executions = [json.loads(e) for e in serialized_executions if e] + data = tiger._deserialize(serialized_data) + executions = [tiger._deserialize(e) for e in serialized_executions if e] return Task(tiger, queue=queue, _data=data, _state=state, _executions=executions) else: @@ -380,8 +379,8 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000, results = pipeline.execute() for serialized_data, serialized_executions, ts in zip(results[0], results[1:], tss): - data = json.loads(serialized_data) - executions = [json.loads(e) for e in serialized_executions if e] + data = tiger._deserialize(serialized_data) + executions = [tiger._deserialize(e) for e in serialized_executions if e] task = Task(tiger, queue=queue, _data=data, _state=state, _ts=ts, _executions=executions) @@ -390,7 +389,7 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000, else: data = tiger.connection.mget([tiger._key('task', item[0]) for item in items]) for serialized_data, ts in zip(data, tss): - data = json.loads(serialized_data) + data = tiger._deserialize(serialized_data) task = Task(tiger, queue=queue, _data=data, _state=state, _ts=ts) tasks.append(task) diff --git a/tasktiger/test_helpers.py b/tasktiger/test_helpers.py index 4a44367c..cff1f729 100644 --- a/tasktiger/test_helpers.py +++ b/tasktiger/test_helpers.py @@ -1,3 +1,7 @@ +import json +import datetime +import decimal + from .task import Task from .worker import Worker @@ -30,3 +34,28 @@ def run_worker(self, tiger, raise_on_errors=True, **kwargs): has_errors = True if has_errors and raise_on_errors: raise Exception('One or more tasks have 
failed.') + + +class CustomJSONEncoder(json.JSONEncoder): + """ + A JSON encoder that allows for more common Python data types. + + In addition to the defaults handled by ``json``, this also supports: + + * ``datetime.datetime`` + * ``datetime.date`` + * ``datetime.time`` + * ``decimal.Decimal`` + + """ + def default(self, data): + if isinstance(data, (datetime.datetime, datetime.date, datetime.time)): + return data.isoformat() + elif isinstance(data, decimal.Decimal): + return str(data) + else: + return super(CustomJSONEncoder, self).default(data) + + +def custom_serializer(obj): + return json.dumps(obj, cls=CustomJSONEncoder) diff --git a/tasktiger/worker.py b/tasktiger/worker.py index 4e4bd967..63532373 100644 --- a/tasktiger/worker.py +++ b/tasktiger/worker.py @@ -1,7 +1,6 @@ from collections import OrderedDict import errno import fcntl -import json import os import random import select @@ -327,7 +326,7 @@ def _execute_forked(self, tasks, log): ''.join(traceback.format_exception(*exc_info)) execution['success'] = success execution['host'] = socket.gethostname() - serialized_execution = json.dumps(execution) + serialized_execution = self.tiger._serialize(execution) for task in tasks: self.connection.rpush(self._key('task', task.id, 'executions'), serialized_execution) @@ -544,7 +543,7 @@ def _process_queue_tasks(self, queue, queue_lock, task_ids, now, log): tasks = [] for task_id, serialized_task in zip(task_ids, serialized_tasks): if serialized_task: - task_data = json.loads(serialized_task) + task_data = self.tiger._deserialize(serialized_task) else: # In the rare case where we don't find the task which is # queued (see ReliabilityTestCase.test_task_disappears), @@ -739,7 +738,7 @@ def _mark_done(): self._key('task', task.id, 'executions'), -1) if execution: - execution = json.loads(execution) + execution = self.tiger._deserialize(execution) if execution and execution.get('retry'): if 'retry_method' in execution: diff --git a/tests/test_base.py 
b/tests/test_base.py index 4590619c..09bd5995 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -8,9 +8,12 @@ import time from multiprocessing import Pool, Process +from decimal import Decimal + from tasktiger import (JobTimeoutException, StopRetry, Task, TaskNotFound, Worker, exponential, fixed, linear) from tasktiger._internal import serialize_func_name +from tasktiger.test_helpers import custom_serializer from .config import DELAY from .tasks import (batch_task, decorated_task, decorated_task_simple_func, @@ -23,8 +26,10 @@ class BaseTestCase: + CONFIG = {} + def setup_method(self, method): - self.tiger = get_tiger() + self.tiger = get_tiger(**self.CONFIG) self.conn = self.tiger.connection self.conn.flushdb() @@ -1012,3 +1017,32 @@ def test_single_worker_queue(self): self._ensure_queues() worker.join() + + +class TestCustomSerializer(BaseTestCase): + + CONFIG = { + 'SERIALIZER': custom_serializer + } + + def test_task(self): + tmpfile = tempfile.NamedTemporaryFile() + task_args = (tmpfile.name, 'test', 5) + task_kwargs = dict(a=datetime.datetime.now(), + b=Decimal("5.05")) + + self.tiger.delay(file_args_task, args=task_args, kwargs=task_kwargs) + queues = self._ensure_queues(queued={'default': 1}) + task = queues['queued']['default'][0] + assert task['func'] == 'tests.tasks:file_args_task' + + Worker(self.tiger).run(once=True) + self._ensure_queues(queued={'default': 0}) + json_data = tmpfile.read().decode('utf8') + assert json.loads(json_data) == { + 'args': ['test', 5], + 'kwargs': { + 'a': task_kwargs['a'].isoformat(), + 'b': str(task_kwargs['b']) + } + } \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index 1836643a..d08a2a3f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,7 +28,7 @@ def __exit__(self, *args): setattr(self.orig_obj, self.func_name, self.orig_func) -def get_tiger(): +def get_tiger(**kwargs): """ Sets up logging and returns a new tasktiger instance. 
""" @@ -38,7 +38,7 @@ def get_tiger(): ) logging.basicConfig(format='%(message)s') conn = redis.Redis(db=TEST_DB, decode_responses=True) - tiger = TaskTiger(connection=conn, config={ + config = { # We need this 0 here so we don't pick up scheduled tasks when # doing a single worker run. 'SELECT_TIMEOUT': 0, @@ -56,7 +56,10 @@ def get_tiger(): }, 'SINGLE_WORKER_QUEUES': ['swq'], - }) + } + + config.update(kwargs) + tiger = TaskTiger(connection=conn, config=config) tiger.log.setLevel(logging.CRITICAL) return tiger