Skip to content

Commit

Permalink
chore: add a stress test for JoinedTaskWorker
Browse files Browse the repository at this point in the history
  • Loading branch information
provos committed Sep 3, 2024
1 parent 5a27ae4 commit b4e4904
Showing 1 changed file with 118 additions and 1 deletion.
119 changes: 118 additions & 1 deletion tests/planai/test_joined_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import threading
import time
import unittest
from typing import List, Set, Type

Expand Down Expand Up @@ -60,6 +62,7 @@ def consume_work(self, task: Task1):
class TaskWorker3(JoinedTaskWorker):
join_type: Type[TaskWorker] = TaskWorker1
_processed_count: int = PrivateAttr(0)
_processed_items: int = PrivateAttr(0)

def consume_work(self, task: Task2):
super().consume_work(task)
Expand All @@ -69,7 +72,9 @@ def consume_work_joined(self, tasks: List[Task2]):
if len(prefixes) != 1:
raise ValueError("All tasks must have the same prefix", prefixes)

self._processed_count += 1
with self._lock:
self._processed_count += 1
self._processed_items += len(tasks)


class TestJoinedTaskWorker(unittest.TestCase):
Expand Down Expand Up @@ -117,5 +122,117 @@ def test_joined_task_worker(self):
self.graph._thread_pool.shutdown(wait=True)


class InitialTask(Task):
    """Seed task injected into the graph; ``data`` is a unique label."""

    data: str


class IntermediateTask(Task):
    """Result produced by the middle stage; ``source`` records the worker name."""

    data: str
    source: str


class FinalTask(Task):
    """Aggregated output holding the ``data`` of one joined batch of tasks."""

    data: List[str]


class InitialTaskWorker(TaskWorker):
    """Fans each incoming task out into three derived ``InitialTask`` items."""

    output_types: Set[Type[Task]] = {InitialTask}

    def consume_work(self, task: InitialTask):
        # Emit three derived tasks, sleeping a random few milliseconds after
        # each publish to mimic real, jittery workloads.
        for i in range(3):
            derived = InitialTask(data=f"{task.data}-{i}")
            self.publish_work(derived, task)
            time.sleep(random.uniform(0.001, 0.01))


class IntermediateTaskWorker(TaskWorker):
    """Turns each ``InitialTask`` into three ``IntermediateTask`` results."""

    output_types: Set[Type[Task]] = {IntermediateTask}

    def consume_work(self, task: InitialTask):
        # Build a fresh task object for every publish. The original reused one
        # instance across all three publish_work calls; publish_work presumably
        # attaches provenance/metadata to the instance — TODO confirm — so a
        # shared object risks cross-contamination between publishes. This also
        # matches InitialTaskWorker, which constructs a new task per iteration.
        for _ in range(3):
            output = IntermediateTask(data=f"Processed-{task.data}", source=self.name)
            self.publish_work(output, task)
            time.sleep(random.uniform(0.001, 0.01))  # Simulate some work


# Module-level sink for joined results; FinalJoinedTaskWorker appends to it
# under its lock, and the stress test below asserts on its contents.
final_task_data = []


class FinalJoinedTaskWorker(JoinedTaskWorker):
    """Joins intermediate results that share an ``InitialTaskWorker`` ancestor."""

    join_type: Type[TaskWorker] = InitialTaskWorker

    def consume_work(self, task: IntermediateTask):
        super().consume_work(task)

    def consume_work_joined(self, tasks: List[IntermediateTask]):
        # Collect the payloads first, then append under the lock since the
        # result list is shared module-level state.
        collected = [item.data for item in tasks]
        with self._lock:
            final_task_data.append(FinalTask(data=collected))


class TestJoinedTaskWorkerStress(unittest.TestCase):
    """Stress-tests JoinedTaskWorker with many tasks arriving concurrently."""

    def setUp(self):
        # Reset the shared module-level result list so a repeated run (or any
        # other test in the same process that touched it) cannot leak stale
        # entries into this test's len()-based assertions. The original never
        # cleared it, which makes the test fail on a second invocation.
        final_task_data.clear()

        self.graph = Graph(name="Stress Test Graph")
        self.dispatcher = Dispatcher(self.graph)
        self.graph._dispatcher = self.dispatcher

        self.initial_worker = InitialTaskWorker()
        self.intermediate_worker = IntermediateTaskWorker()
        self.final_worker = FinalJoinedTaskWorker()

        self.graph.add_workers(
            self.initial_worker, self.intermediate_worker, self.final_worker
        )
        self.graph.set_dependency(self.initial_worker, self.intermediate_worker).next(
            self.final_worker
        )

    def test_joined_task_worker_stress(self):
        num_initial_tasks = 100
        initial_tasks = [
            InitialTask(data=f"Initial {i}") for i in range(num_initial_tasks)
        ]

        # Run the dispatcher loop in the background while work is added.
        dispatch_thread = threading.Thread(target=self.dispatcher.dispatch)
        dispatch_thread.start()

        # Feed work from a separate thread to exercise concurrent ingestion.
        def add_initial_work():
            for task in initial_tasks:
                self.dispatcher.add_work(self.initial_worker, task)

        add_work_thread = threading.Thread(target=add_initial_work)
        add_work_thread.start()

        # Wait for all work to be processed before asserting anything.
        add_work_thread.join()
        self.dispatcher.wait_for_completion()
        self.dispatcher.stop()
        dispatch_thread.join()

        # Each seed fans out 3x at the initial worker, 3x again at the
        # intermediate worker, and the join stage adds one notify per group.
        self.assertEqual(
            self.dispatcher.total_completed_tasks,
            num_initial_tasks * 3
            + num_initial_tasks * 3 * 3
            + num_initial_tasks * (3 + 1),  # 3 results plus notify
        )
        self.assertEqual(self.dispatcher.total_failed_tasks, 0)

        # One joined FinalTask per InitialTaskWorker output (3 per seed task).
        self.assertEqual(len(final_task_data), num_initial_tasks * 3)

        # Verify that each final task contains exactly 3 intermediate results.
        for task in final_task_data:
            self.assertEqual(len(task.data), 3)

        # The dispatcher should be fully drained once the graph completes.
        self.assertEqual(self.dispatcher.work_queue.qsize(), 0)
        self.assertEqual(self.dispatcher.active_tasks, 0)

        self.graph._thread_pool.shutdown(wait=True)


if __name__ == "__main__":
    # Allow running this test module directly without a test runner.
    unittest.main()

0 comments on commit b4e4904

Please sign in to comment.