Skip to content

Commit

Permalink
Revert refactor of Redis keys in test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
vishesh10 committed Dec 11, 2023
1 parent b9a5e87 commit 94eba11
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 150 deletions.
7 changes: 4 additions & 3 deletions tasktiger/redis_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from redis import Redis

from ._internal import ACTIVE, ERROR, QUEUED, SCHEDULED
from .constants import EXECUTIONS, EXECUTIONS_COUNT, TASK

try:
from redis.commands.core import Script
Expand Down Expand Up @@ -595,9 +596,9 @@ def _bool_to_str(v: bool) -> str:
def _none_to_empty_str(v: Optional[str]) -> str:
return v or ""

key_task_id = key_func("task", id)
key_task_id_executions = key_func("task", id, "executions")
key_task_id_executions_count = key_func("task", id, "executions_count")
key_task_id = key_func(TASK, id)
key_task_id_executions = key_func(TASK, id, EXECUTIONS)
key_task_id_executions_count = key_func(TASK, id, EXECUTIONS_COUNT)
key_from_state = key_func(from_state)
key_to_state = key_func(to_state) if to_state else ""
key_active_queue = key_func(ACTIVE, queue)
Expand Down
15 changes: 8 additions & 7 deletions tasktiger/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
serialize_func_name,
serialize_retry_method,
)
from .constants import EXECUTIONS, EXECUTIONS_COUNT, TASK
from .exceptions import QueueFullException, TaskImportError, TaskNotFound
from .runner import BaseRunner, get_runner_class
from .types import RetryStrategy
Expand Down Expand Up @@ -393,7 +394,7 @@ def delay(

pipeline = tiger.connection.pipeline()
pipeline.sadd(tiger._key(state), self.queue)
pipeline.set(tiger._key("task", self.id), serialized_task)
pipeline.set(tiger._key(TASK, self.id), serialized_task)
# In case of unique tasks, don't update the score.
tiger.scripts.zadd(
tiger._key(state, self.queue),
Expand Down Expand Up @@ -454,11 +455,11 @@ def from_id(
latest). If the task doesn't exist, None is returned.
"""
pipeline = tiger.connection.pipeline()
pipeline.get(tiger._key("task", task_id))
pipeline.get(tiger._key(TASK, task_id))
pipeline.zscore(tiger._key(state, queue), task_id)
if load_executions:
pipeline.lrange(
tiger._key("task", task_id, "executions"), -load_executions, -1
tiger._key(TASK, task_id, EXECUTIONS), -load_executions, -1
)
(
serialized_data,
Expand Down Expand Up @@ -526,10 +527,10 @@ def tasks_from_queue(
]
if load_executions:
pipeline = tiger.connection.pipeline()
pipeline.mget([tiger._key("task", item[0]) for item in items])
pipeline.mget([tiger._key(TASK, item[0]) for item in items])
for item in items:
pipeline.lrange(
tiger._key("task", item[0], "executions"),
tiger._key(TASK, item[0], EXECUTIONS),
-load_executions,
-1,
)
Expand Down Expand Up @@ -586,8 +587,8 @@ def n_executions(self) -> int:
Queries and returns the number of past task executions.
"""
pipeline = self.tiger.connection.pipeline()
pipeline.exists(self.tiger._key("task", self.id))
pipeline.get(self.tiger._key("task", self.id, "executions_count"))
pipeline.exists(self.tiger._key(TASK, self.id))
pipeline.get(self.tiger._key(TASK, self.id, EXECUTIONS_COUNT))

exists, executions_count = pipeline.execute()
if not exists:
Expand Down
15 changes: 7 additions & 8 deletions tasktiger/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
serialize_func_name,
serialize_retry_method,
)
from .constants import EXECUTIONS, EXECUTIONS_COUNT, TASK
from .exceptions import (
RetryException,
StopRetry,
Expand Down Expand Up @@ -345,7 +346,7 @@ def _worker_queue_expired_tasks(self) -> None:
self.config["REQUEUE_EXPIRED_TASKS_BATCH_SIZE"],
)

for (queue, task_id) in task_data:
for queue, task_id in task_data:
self.log.debug("expiring task", queue=queue, task_id=task_id)
self._did_work = True
try:
Expand Down Expand Up @@ -374,7 +375,7 @@ def _worker_queue_expired_tasks(self) -> None:
# have a task without a task object.

# XXX: Ideally, the following block should be atomic.
if not self.connection.get(self._key("task", task_id)):
if not self.connection.get(self._key(TASK, task_id)):
self.log.error("not found", queue=queue, task_id=task_id)
task = Task(
self.tiger,
Expand Down Expand Up @@ -812,7 +813,7 @@ def _process_queue_tasks(

# Get all tasks
serialized_tasks = self.connection.mget(
[self._key("task", task_id) for task_id in task_ids]
[self._key(TASK, task_id) for task_id in task_ids]
)

# Parse tasks
Expand Down Expand Up @@ -1053,7 +1054,7 @@ def _mark_done() -> None:
should_log_error = True
# Get execution info (for logging and retry purposes)
execution = self.connection.lindex(
self._key("task", task.id, "executions"), -1
self._key(TASK, task.id, EXECUTIONS), -1
)

if execution:
Expand Down Expand Up @@ -1242,10 +1243,8 @@ def _store_task_execution(
serialized_execution = json.dumps(execution)

for task in tasks:
executions_key = self._key("task", task.id, "executions")
executions_count_key = self._key(
"task", task.id, "executions_count"
)
executions_key = self._key(TASK, task.id, EXECUTIONS)
executions_count_key = self._key(TASK, task.id, EXECUTIONS_COUNT)

pipeline = self.connection.pipeline()
pipeline.incr(executions_count_key)
Expand Down
69 changes: 22 additions & 47 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,6 @@
linear,
)
from tasktiger._internal import serialize_func_name
from tasktiger.constants import (
ACTIVE,
ERROR,
EXECUTIONS,
EXECUTIONS_COUNT,
QUEUED,
REDIS_PREFIX,
SCHEDULED,
TASK,
)

from .config import DELAY
from .tasks import (
Expand Down Expand Up @@ -79,6 +69,7 @@ def teardown_method(self, method):
def _ensure_queues(
self, queued=None, active=None, error=None, scheduled=None
):

expected_queues = {
"queued": {name for name, n in (queued or {}).items() if n},
"active": {name for name, n in (active or {}).items() if n},
Expand All @@ -98,9 +89,7 @@ def _ensure_queue(typ, data):
task_ids = self.conn.zrange("t:%s:%s" % (typ, name), 0, -1)
assert len(task_ids) == n
ret[name] = [
json.loads(
self.conn.get(f"{REDIS_PREFIX}:{TASK}:%s" % task_id)
)
json.loads(self.conn.get("t:task:%s" % task_id))
for task_id in task_ids
]
assert [task["id"] for task in ret[name]] == task_ids
Expand Down Expand Up @@ -138,7 +127,7 @@ def test_simple_task(self):

Worker(self.tiger).run(once=True)
self._ensure_queues(queued={"default": 0})
assert not self.conn.exists(f"{REDIS_PREFIX}:{TASK}:%s" % task["id"])
assert not self.conn.exists("t:task:%s" % task["id"])

@pytest.mark.skipif(
sys.version_info < (3, 3), reason="__qualname__ unavailable"
Expand All @@ -151,7 +140,7 @@ def test_staticmethod_task(self):

Worker(self.tiger).run(once=True)
self._ensure_queues(queued={"default": 0})
assert not self.conn.exists(f"{REDIS_PREFIX}:{TASK}:%s" % task["id"])
assert not self.conn.exists("t:task:%s" % task["id"])

def test_task_delay(self):
decorated_task.delay(1, 2, a=3, b=4)
Expand Down Expand Up @@ -269,7 +258,7 @@ def test_exception_task(self, store_tracebacks):
assert task["func"] == "tests.tasks:exception_task"

executions = self.conn.lrange(
f"{REDIS_PREFIX}:{TASK}:%s:{EXECUTIONS}" % task["id"], 0, -1
"t:task:%s:executions" % task["id"], 0, -1
)
assert len(executions) == 1
execution = json.loads(executions[0])
Expand All @@ -285,9 +274,7 @@ def test_exception_task(self, store_tracebacks):
@pytest.mark.parametrize("max_stored_executions", [2, 3, 6, 11, None])
def test_max_stored_executions(self, max_stored_executions):
def _get_stored_executions():
return self.conn.llen(
f"{REDIS_PREFIX}:{TASK}:{task.id}:{EXECUTIONS}"
)
return self.conn.llen(f"t:task:{task.id}:executions")

task = self.tiger.delay(
exception_task,
Expand Down Expand Up @@ -322,7 +309,7 @@ def test_long_task_killed(self):
assert task["func"] == "tests.tasks:long_task_killed"

executions = self.conn.lrange(
f"{REDIS_PREFIX}:{TASK}:%s:{EXECUTIONS}" % task["id"], 0, -1
"t:task:%s:executions" % task["id"], 0, -1
)
assert len(executions) == 1
execution = json.loads(executions[0])
Expand Down Expand Up @@ -677,17 +664,9 @@ def test_retry_executions_count(self, count):
Worker(self.tiger).run(once=True)

assert (
int(
self.conn.get(
f"{REDIS_PREFIX}:{TASK}:{task.id}:{EXECUTIONS_COUNT}"
)
)
== count
)
assert (
self.conn.llen(f"{REDIS_PREFIX}:{TASK}:{task.id}:{EXECUTIONS}")
== count
int(self.conn.get(f"t:task:{task.id}:executions_count")) == count
)
assert self.conn.llen(f"t:task:{task.id}:executions") == count

def test_batch_1(self):
self.tiger.delay(batch_task, args=[1])
Expand Down Expand Up @@ -1080,15 +1059,11 @@ def test_update_scheduled_time(self):
task = Task(self.tiger, simple_task, unique=True)
task.delay(when=datetime.timedelta(minutes=5))
self._ensure_queues(scheduled={"default": 1})
old_score = self.conn.zscore(
f"{REDIS_PREFIX}:{SCHEDULED}:default", task.id
)
old_score = self.conn.zscore("t:scheduled:default", task.id)

task.update_scheduled_time(when=datetime.timedelta(minutes=6))
self._ensure_queues(scheduled={"default": 1})
new_score = self.conn.zscore(
f"{REDIS_PREFIX}:{SCHEDULED}:default", task.id
)
new_score = self.conn.zscore("t:scheduled:default", task.id)

# The difference can be slightly over 60 due to processing time, but
# shouldn't be much higher.
Expand Down Expand Up @@ -1356,16 +1331,16 @@ def test_task_disappears(self):
time.sleep(DELAY)

# Remove the task object while the task is processing.
assert self.conn.delete(f"{REDIS_PREFIX}:{TASK}:{task.id}") == 1
assert self.conn.delete("t:task:{}".format(task.id)) == 1

# Kill the worker while it's still processing the task.
os.kill(worker.pid, signal.SIGKILL)

# _ensure_queues() breaks here because it can't find the task
assert self.conn.scard(f"{REDIS_PREFIX}:{QUEUED}") == 0
assert self.conn.scard(f"{REDIS_PREFIX}:{ACTIVE}") == 1
assert self.conn.scard(f"{REDIS_PREFIX}:{ERROR}") == 0
assert self.conn.scard(f"{REDIS_PREFIX}:{SCHEDULED}") == 0
assert self.conn.scard("t:queued") == 0
assert self.conn.scard("t:active") == 1
assert self.conn.scard("t:error") == 0
assert self.conn.scard("t:scheduled") == 0

# Capture logger
errors = []
Expand All @@ -1381,10 +1356,10 @@ def fake_error(msg):
Worker(self.tiger).run(once=True)

assert len(errors) == 0
assert self.conn.scard(f"{REDIS_PREFIX}:{QUEUED}") == 0
assert self.conn.scard(f"{REDIS_PREFIX}:{ACTIVE}") == 1
assert self.conn.scard(f"{REDIS_PREFIX}:{ERROR}") == 0
assert self.conn.scard(f"{REDIS_PREFIX}:{SCHEDULED}") == 0
assert self.conn.scard("t:queued") == 0
assert self.conn.scard("t:active") == 1
assert self.conn.scard("t:error") == 0
assert self.conn.scard("t:scheduled") == 0

# After waiting and re-running the worker, queues will clear.
time.sleep(2 * DELAY)
Expand Down Expand Up @@ -1432,7 +1407,7 @@ def test_child_hanging_forever(self):
assert task["func"] == "tests.tasks:sleep_task"

executions = self.conn.lrange(
f"{REDIS_PREFIX}:{TASK}:%s:{EXECUTIONS}" % task["id"], 0, -1
"t:task:%s:executions" % task["id"], 0, -1
)
assert len(executions) == 1
execution = json.loads(executions[0])
Expand Down Expand Up @@ -1489,7 +1464,7 @@ def test_decorated_child_hard_timeout_precedence(self):
assert task["func"] == "tests.tasks:decorated_task_sleep_timeout"

executions = self.conn.lrange(
f"{REDIS_PREFIX}:{TASK}:%s:{EXECUTIONS}" % task["id"], 0, -1
"t:task:%s:executions" % task["id"], 0, -1
)
assert len(executions) == 1
execution = json.loads(executions[0])
Expand Down
3 changes: 1 addition & 2 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import structlog

from tasktiger import TaskTiger, Worker
from tasktiger.constants import REDIS_PREFIX, TASK
from tasktiger.logging import tasktiger_processor

from .test_base import BaseTestCase
Expand Down Expand Up @@ -56,7 +55,7 @@ def test_structlog_processor(self):

Worker(self.tiger).run(once=True)
self._ensure_queues(queued={"foo_qux": 0})
assert not self.conn.exists(f"{REDIS_PREFIX}:{TASK}:%s" % task["id"])
assert not self.conn.exists("t:task:%s" % task["id"])


class TestSetupStructlog(BaseTestCase):
Expand Down
Loading

0 comments on commit 94eba11

Please sign in to comment.