Skip to content

Commit

Permalink
stage unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ziyi Yang committed Aug 14, 2024
1 parent 4d94092 commit 734221b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 30 deletions.
31 changes: 2 additions & 29 deletions src/step_function/states.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import ABC

import boto3

Expand All @@ -12,10 +12,6 @@ class State(ABC):
def __init__(self, name: str):
self.name = name

@abstractmethod
def get_execution_time(self) -> float:
pass


class Task(State):
"""Task: a single Lambda function """
Expand All @@ -41,9 +37,6 @@ def get_output(self, aws_session: boto3.Session) -> str:
logger.debug(f"Finish invoking {self.function_name}, output: {output}")
return output

def get_execution_time(self) -> float:
return 1


class Parallel(State):
"""Parallel: parallel workflows (branches) with same input."""
Expand All @@ -55,14 +48,6 @@ def __init__(self, name: str):
def add_branch(self, workflow: "Workflow"):
self.branches.append(workflow)

def get_execution_time(self) -> float:
"""Returns the longest execution time among all branches."""
max_time = 0
for branch in self.branches:
branch_time = branch.get_execution_time()
max_time = max(max_time, branch_time)
return max_time


class Map(State):
"""Map: multiple same workflows with different inputs"""
Expand All @@ -73,17 +58,9 @@ def __init__(self, name: str):
self.workflow = None
self.items_path = ""

def set_workflow(self, workflow: "Workflow"):
def add_iteration(self, workflow: "Workflow"):
self.iterations.append(workflow)

def get_execution_time(self) -> float:
"""Returns the longest execution time among all branches."""
max_time = 0
for branch in self.iterations:
branch_time = branch.get_execution_time()
max_time = max(max_time, branch_time)
return max_time


class Workflow:
"""A workflow, containing a sequence of states."""
Expand All @@ -93,7 +70,3 @@ def __init__(self):

def add_state(self, state: State):
self.states.append(state)

def get_execution_time(self) -> float:
total_time = sum(state.get_execution_time() for state in self.states)
return total_time
3 changes: 2 additions & 1 deletion src/step_function/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ def _optimize_function(task: Task):
if error is None:
error = e
continue
logger.info(f"Finish optimizing all functions, {results}")
logger.info("Finish optimizing all functions")
print(f"Finish optimizing all functions, {results}")

if error:
raise error
Empty file added tests/step_function/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/step_function/test_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from unittest import mock

import boto3
import pytest

from src.exploration.aws.aws_invoker import AWSInvoker
from src.step_function.states import Task, Parallel, Map, Workflow


@pytest.fixture
def aws_session():
return boto3.Session(region_name="us-west-2")


def test_task_state_invoke(aws_session):
# Arrange
task_state = Task(name="test_task", function_name="test_lambda")
task_state.set_input("input_data")
mock_invoker = mock.Mock(spec=AWSInvoker)
mock_invoker.invoke_for_output.return_value = "output_data"
# Act
output = task_state.get_output(aws_session)
# Assert
assert output == "output_data"


def test_parallel_state_add_branch(aws_session):
# Arrange
parallel_state = Parallel(name="test_parallel")
workflow1 = Workflow()
workflow2 = Workflow()
# Act
parallel_state.add_branch(workflow1)
parallel_state.add_branch(workflow2)
# Assert
assert len(parallel_state.branches) == 2
assert parallel_state.branches[0] == workflow1
assert parallel_state.branches[1] == workflow2


def test_map_state_set_workflow(aws_session):
# Arrange
map_state = Map(name="test_map")
workflow = Workflow()
# Act
map_state.add_iteration(workflow)
# Assert
assert len(map_state.iterations) == 1
assert map_state.iterations[0] == workflow

0 comments on commit 734221b

Please sign in to comment.