Skip to content

Commit

Permalink
Parrotfish for Step Function, Execution Time Constraint Optimization (#…
Browse files Browse the repository at this point in the history
…164)

* feat: get critical path of step function

* feat: optimize functions on critical path

* feat: optimize individual functions

* merge execution time optimization into StepFunction

* update gitignore

* print final memorise

* remove optimize individual functions

* prepare for PR

* update method names

* add unit tests

* separate execution_time_optimizer and add unit tests

* remove execution time optimization from step_function

---------

Co-authored-by: Ziyi Yang <[email protected]>
  • Loading branch information
zyangbi and Ziyi Yang authored Aug 28, 2024
1 parent e73a990 commit 0c8a2f6
Show file tree
Hide file tree
Showing 13 changed files with 456 additions and 41 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,3 @@ venv/
.coverage

dump/

src/step_function/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"arn": "arn:aws:states:us-west-2:898429789601:stateMachine:ImageProcessing",
"region": "us-west-2",
"payload": {}
"arn": "arn:aws:states:us-west-2:898429789601:stateMachine:ImageProcessing",
"region": "us-west-2",
"payload": {}
}
6 changes: 5 additions & 1 deletion src/configuration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ dynamic_sampling_params.max_sample_count=5
and when the calculated coefficient of variation reaches this threshold we terminate the dynamic sampling (Default is 0.05),
} (Optional),
"max_number_of_invocation_attempts": The maximum number of attempts per invocation; when this number is reached, an error is raised. (Optional, Default is 5)
"constraint_execution_time_threshold": The step function execution time threshold constraint. We leverage the execution time model and the step function workflow structure
to recommend a configuration that minimizes cost while adhering to the specified execution time constraint. (Optional, Default is +infinity)
"memory_size_increment": The step size by which memory size is increased to meet execution time threshold. (Optional, Default is 10)
}
```

Expand All @@ -121,7 +124,8 @@ dynamic_sampling_params.max_sample_count=5
{
"arn": "example_step_function_arn",
"region": "example_region",
"payload": "payload"
"payload": "payload",
"constraint_execution_time_threshold": 5000
}
```

1 change: 1 addition & 0 deletions src/configuration/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
MAX_TOTAL_SAMPLE_COUNT = 20
MIN_SAMPLE_PER_CONFIG = 4
TERMINATION_THRESHOLD = 3
# Default step by which a function's memory size is increased while trying to
# meet the execution time threshold (presumably in MB -- confirm against the
# Lambda memory configuration used elsewhere in the project).
MEMORY_SIZE_INCREMENT = 10

LOG_LEVEL = logging.WARNING
6 changes: 5 additions & 1 deletion src/configuration/step_function_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class StepFunctionConfiguration:
def __init__(self, config_file: TextIO):
def __init__(self, config_file: Union[TextIO, dict]):
self._load_config_schema()

# Setup default values
Expand All @@ -16,6 +16,8 @@ def __init__(self, config_file: TextIO):
self.max_total_sample_count = MAX_TOTAL_SAMPLE_COUNT
self.min_sample_per_config = MIN_SAMPLE_PER_CONFIG
self.max_number_of_invocation_attempts = MAX_NUMBER_OF_INVOCATION_ATTEMPTS
self.memory_size_increment = MEMORY_SIZE_INCREMENT
self.constraint_execution_time_threshold = None

# Parse the configuration file
self._deserialize(config_file)
Expand Down Expand Up @@ -58,6 +60,8 @@ def _load_config_schema(self):
},
},
"max_number_of_invocation_attempts": {"type": "integer", "minimum": 0},
"constraint_execution_time_threshold": {"type": "integer", "minimum": 1},
"memory_size_increment": {"type": "integer", "minimum": 1},
},
"required": ["arn", "region", "payload"],
"if": {"not": {"required": ["payload"]}},
Expand Down
3 changes: 3 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def main():
exit(1)

if args.step_function:
# Create step function
step_function = StepFunction(configuration)

# Run cost and execution time optimization
step_function.optimize()

else:
Expand Down
2 changes: 1 addition & 1 deletion src/objective/parametric_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ParametricFunction:
bounds (tuple): Lower and upper bounds on parameters.
"""

function: callable = lambda x, a0, a1, a2: a0 + a1 * np.exp(-x / a2)
function: callable = lambda x, a0, a1, a2: (a0 + a1 * np.exp(-x / a2)) if a2 != 0 else a0
bounds: tuple = ([-np.inf, -np.inf, -np.inf], [np.inf, np.inf, np.inf])
params: any = None

Expand Down
99 changes: 99 additions & 0 deletions src/step_function/execution_time_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from src.exception.step_function_error import StepFunctionError
from src.logger import logger


class ExecutionTimeOptimizer:
    """Greedy memory tuner that pushes a workflow under an execution time threshold.

    On every round the optimizer looks at the tasks on the workflow's critical
    path, estimates how much time each function would save if its memory size
    grew by one increment, and bumps the function with the lowest
    cost-increase / time-reduction ratio.
    """

    def __init__(self, workflow, function_tasks_dict, config):
        """Store the workflow, the function-name -> tasks mapping and the tuning knobs.

        Args:
            workflow: Object exposing ``get_critical_path()`` and ``get_cost()``.
            function_tasks_dict: Mapping from function name to the list of
                tasks that invoke that function.
            config: Object exposing ``memory_size_increment`` and
                ``constraint_execution_time_threshold``.
        """
        self.workflow = workflow
        self.function_tasks_dict = function_tasks_dict
        self.memory_increment = config.memory_size_increment
        self.execution_time_threshold = config.constraint_execution_time_threshold

    def optimize_for_execution_time_constraint(self):
        """Raise memory sizes until the critical path fits the time threshold.

        Does nothing (besides logging a warning) when no threshold is configured.

        Raises:
            StepFunctionError: If no function can reduce the critical path time
                any further while the threshold is still exceeded.
        """
        if self.execution_time_threshold is None:
            logger.warning("No execution time threshold.")
            return

        path_tasks, path_time = self.workflow.get_critical_path()
        logger.info(
            f"Start optimizing step function for execution time, time: {path_time}ms, threshold: {self.execution_time_threshold}ms, cost: {self.workflow.get_cost()}."
        )

        cost_deltas = self._initialize_cost_increases()

        while path_time > self.execution_time_threshold:
            savings = self._calculate_time_reductions(path_tasks)
            candidate = self._find_best_function_to_optimize(cost_deltas, savings)

            # No candidate means nothing on the critical path can get faster.
            if not candidate:
                raise StepFunctionError("Execution time threshold too low.")
            self._update_memory_size_and_cost(candidate, cost_deltas)

            # The critical path may have shifted after the memory bump.
            path_tasks, path_time = self.workflow.get_critical_path()
            logger.debug(
                f"Optimized function {candidate}, time: {path_time}ms, cost: {self.workflow.get_cost()}.\n"
            )

        logger.info(
            f"Finish optimizing step function for execution time, time: {path_time}ms, threshold: {self.execution_time_threshold}ms, cost: {self.workflow.get_cost()}.\n"
        )
        self._print_memory_sizes()

    def _initialize_cost_increases(self):
        """Return, per function, the total cost increase of one memory increment."""
        deltas = {}
        for name, tasks in self.function_tasks_dict.items():
            accumulated = 0.0
            for task in tasks:
                baseline = task.get_cost(task.memory_size)
                bumped = task.get_cost(task.memory_size + self.memory_increment)
                accumulated += bumped - baseline
            deltas[name] = accumulated
        return deltas

    def _calculate_time_reductions(self, critical_path_tasks):
        """Return, per function, the time saved on the critical path by one increment.

        Tasks whose memory size would exceed ``max_memory_size`` are left out.
        """
        savings = {}
        for task in critical_path_tasks:
            target_size = task.memory_size + self.memory_increment
            if target_size > task.max_memory_size:
                continue

            saved = task.get_execution_time() - task.get_execution_time(target_size)
            name = task.function_name
            savings[name] = savings.get(name, 0.0) + saved
        return savings

    def _find_best_function_to_optimize(self, cost_increases, time_reductions):
        """Return the function with the lowest cost-to-time-reduction ratio.

        Only functions with a strictly positive time reduction are considered;
        ``None`` is returned when there is none.
        """
        winner, winner_ratio = None, float('inf')
        for name, saved in time_reductions.items():
            if not saved > 0:
                continue

            ratio = cost_increases[name] / saved
            logger.debug(
                f"ratio: {ratio}, {name}, {self.function_tasks_dict[name][0].memory_size}MB, {cost_increases[name]}, {saved}"
            )
            if ratio < winner_ratio:
                winner, winner_ratio = name, ratio
        return winner

    def _update_memory_size_and_cost(self, best_function, cost_increases):
        """Bump every task of ``best_function`` and refresh its cost increase in place."""
        cost_increases[best_function] = 0.0
        for task in self.function_tasks_dict[best_function]:
            task.increase_memory_size(self.memory_increment)
            current = task.get_cost()
            projected = task.get_cost(task.memory_size + self.memory_increment)
            cost_increases[best_function] += projected - current

    def _print_memory_sizes(self):
        """Print the memory size of the first task of every function."""
        print("Finish optimizing step function for execution time, optimized memory sizes:")
        for name, tasks in self.function_tasks_dict.items():
            print(f"{name}: {tasks[0].memory_size}MB")
86 changes: 82 additions & 4 deletions src/step_function/states.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC
from abc import ABC, abstractmethod
from typing import Tuple

import boto3

Expand All @@ -12,14 +13,22 @@ class State(ABC):
def __init__(self, name: str):
self.name = name

@abstractmethod
def get_cost(self) -> float:
pass


class Task(State):
"""Task: a single Lambda function """

def __init__(self, name: str, function_name: str):
    """Create a task state named ``name`` that runs ``function_name``."""
    super().__init__(name)
    self.function_name = function_name
    # Populated later through set_input().
    self.input = None
    # Callable taking a memory size; get_execution_time() returns its result.
    self.param_function = None
    # Current memory size, changed by increase/decrease_memory_size().
    self.memory_size = None
    # Value that reset_memory_size() restores.
    self.initial_memory_size = None
    # Upper bound on memory size (presumably enforced by callers -- the
    # mutators on this class do not check it; confirm against the optimizer).
    self.max_memory_size = None

def set_input(self, input: str):
    """Store ``input`` on the task (the parameter name shadows the builtin ``input``)."""
    self.input = input
Expand All @@ -37,6 +46,29 @@ def get_output(self, aws_session: boto3.Session) -> str:
logger.debug(f"Finish invoking {self.function_name}, output: {output}")
return output

def increase_memory_size(self, increment: int):
    """Raise the task's current memory size by ``increment`` (no upper-bound check)."""
    self.memory_size = self.memory_size + increment

def decrease_memory_size(self, decrement: int):
    """Lower the task's current memory size by ``decrement`` (no lower-bound check)."""
    self.memory_size = self.memory_size - decrement

def reset_memory_size(self):
    """Restore the current memory size to ``initial_memory_size``."""
    self.memory_size = self.initial_memory_size

def get_execution_time(self, memory_size: int = None):
    """Evaluate the task's parametric function at a memory size.

    Args:
        memory_size: Memory size to evaluate; when ``None`` the task's
            current ``memory_size`` is used.

    Returns:
        Whatever ``param_function`` returns for the chosen memory size.
    """
    size = self.memory_size if memory_size is None else memory_size
    return self.param_function(size)

def get_cost(self, memory_size: int = None):
    """Return memory size multiplied by the parametric function at that size.

    Args:
        memory_size: Memory size to evaluate; when ``None`` the task's
            current ``memory_size`` is used.
    """
    size = self.memory_size if memory_size is None else memory_size
    return size * self.param_function(size)


class Parallel(State):
"""Parallel: parallel workflows (branches) with same input."""
Expand All @@ -48,19 +80,46 @@ def __init__(self, name: str):
def add_branch(self, workflow: "Workflow"):
    """Append ``workflow`` to this state's list of branches."""
    self.branches.append(workflow)

def get_critical_path(self) -> Tuple[list["Task"], float]:
    """Get tasks on critical path and execution time.

    The critical path is taken from the branch with the largest execution
    time (the first one wins on ties).

    Returns:
        A ``(tasks, time)`` tuple. When there are no branches, or no branch
        has a positive execution time, ``([], 0.0)`` is returned so callers
        can always treat the first element as a list.
    """
    max_time = 0.0
    # Start from an empty list rather than None: the declared return type is
    # a list, and callers extend their own path with this value.
    critical_path: list["Task"] = []
    for workflow in self.branches:
        states, time = workflow.get_critical_path()
        if time > max_time:
            max_time = time
            critical_path = states
    return critical_path, max_time

def get_cost(self) -> float:
    """Return the combined cost of all branches (0 when there are none)."""
    total = 0
    for branch in self.branches:
        total = total + branch.get_cost()
    return total


class Map(State):
"""Map: multiple same workflows with different inputs"""
"""Map: multiple same workflows with different inputs."""

def __init__(self, name: str):
super().__init__(name)
self.iterations: list[Workflow] = []
self.items_path = None
self.workflow_def = None
self.items_path = ""
self.iterations: list[Workflow] = []

def add_iteration(self, workflow: "Workflow"):
    """Append ``workflow`` to this state's list of iterations."""
    self.iterations.append(workflow)

def get_critical_path(self) -> Tuple[list["Task"], float]:
    """Get tasks on critical path and execution time.

    The critical path is taken from the iteration with the largest execution
    time (the first one wins on ties).

    Returns:
        A ``(tasks, time)`` tuple. When there are no iterations, or no
        iteration has a positive execution time, ``([], 0.0)`` is returned so
        callers can always treat the first element as a list.
    """
    max_time = 0.0
    # Start from an empty list rather than None: the declared return type is
    # a list, and callers extend their own path with this value.
    critical_path: list["Task"] = []
    for workflow in self.iterations:
        states, time = workflow.get_critical_path()
        if time > max_time:
            max_time = time
            critical_path = states
    return critical_path, max_time

def get_cost(self) -> float:
    """Return the combined cost of all iterations (0 when there are none)."""
    total = 0
    for iteration in self.iterations:
        total = total + iteration.get_cost()
    return total


class Workflow:
"""A workflow, containing a sequence of states."""
Expand All @@ -70,3 +129,22 @@ def __init__(self):

def add_state(self, state: State):
    """Append ``state`` to the end of this workflow's state sequence."""
    self.states.append(state)

def get_critical_path(self) -> Tuple[list["Task"], float]:
    """Collect the tasks on this workflow's critical path and their total time.

    States run in sequence, so the times of all states are added up. A
    ``Task`` contributes itself and its execution time; a ``Parallel`` or
    ``Map`` contributes whatever its own ``get_critical_path()`` reports.
    Other state types are ignored.

    Returns:
        A ``(tasks, total_time)`` tuple; ``([], 0.0)`` for an empty workflow.
    """
    critical_path: list["Task"] = []
    total_time = 0.0

    for state in self.states:
        if isinstance(state, Task):
            critical_path.append(state)
            total_time += state.get_execution_time()
        elif isinstance(state, (Parallel, Map)):
            states, time = state.get_critical_path()
            # A nested state may report no path at all (None or empty);
            # extending with None would raise TypeError, so skip it.
            if states:
                critical_path.extend(states)
            total_time += time

    return critical_path, total_time

def get_cost(self) -> float:
    """Return the combined cost of every state in the workflow (0 when empty)."""
    total = 0
    for state in self.states:
        total = total + state.get_cost()
    return total
Loading

0 comments on commit 0c8a2f6

Please sign in to comment.