Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/jerry/function level env #46

Merged
merged 5 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ workflow = MultiXServerlessWorkflow("workflow_name")
"only_regions": [["aws", "us-east-1"], ["aws", "us-east-2"], ["aws", "us-west-1"], ["aws", "us-west-2"]],
"forbidden_regions": None,
},
func_environment_variables=[
{
"key": "example_key",
"value": "example_value"
}
]
providers=[
{
"name": "aws",
Expand All @@ -101,6 +107,7 @@ The meaning of the different parameters is as follows:
- `regions_and_providers`: A dictionary that contains the regions and providers that the function can be deployed to. This can be used to override the global settings in the `config.yml`. If none or an empty dictionary is provided, the global config takes precedence. The dictionary has two keys:
- `only_regions`: A list of regions that the function can be deployed to. If this list is empty, the function can be deployed to any region.
- `forbidden_regions`: A list of regions that the function cannot be deployed to. If this list is empty, the function can be deployed to any region.
- `func_environment_variables`: This parameter represents a list of dictionaries, each designed for setting environment variables specifically for a function. Users must adhere to a structured format within each dictionary. This format requires two entries: "key" and "value". The "key" entry should contain the name of the environment variable, serving as an identifier. The "value" entry holds the corresponding value assigned to that variable.
jerryyiransun marked this conversation as resolved.
Show resolved Hide resolved
- `providers`: A list of providers that the function can be deployed to. This can be used to override the global settings in the `config.yml`. If a list of providers is specified at the function level this takes precedence over the global configurations. If none or an empty list is provided, the global config takes precedence. Each provider is a dictionary with two keys:
- `name`: The name of the provider. This is the name that is used directly in the physical representation of the workflow.
- `config`: A dictionary that contains the configuration for the specific provider.
Expand Down
2 changes: 1 addition & 1 deletion multi_x_serverless/deployment/client/cli/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class EnvironmentVariable(BaseModel):
name: str = Field(..., title="The name of the environment variable")
key: str = Field(..., title="The name of the environment variable")
value: str = Field(..., title="The value of the environment variable")


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
workflow_name: "{{ workflow_name }}"
environment_variables:
- name: "ENV_VAR_1"
- key: "ENV_VAR_1"
value: "value_1"
iam_policy_file: "iam_policy.json"
home_regions: [["aws", us-west-2"]] # Regions are defined as "provider:region" (e.g. aws:us-west-2)
Expand Down
1 change: 1 addition & 0 deletions multi_x_serverless/deployment/client/cli/template/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
}
],
},
func_environment_variables=[{"key": "example_key", "value": "example_value"}],
)
def first_function(event: dict[str, Any]) -> dict[str, Any]:
payload = {
Expand Down
8 changes: 5 additions & 3 deletions multi_x_serverless/deployment/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ def python_version(self) -> str:
return "python3.11"

@property
def environment_variables(self) -> dict[str, Any]:
def environment_variables(self) -> dict[str, str]:
list_of_env_variables: list[dict] = self._lookup("environment_variables")
if list_of_env_variables is None:
return {}
env_variables: dict[str, Any] = {}
env_variables: dict[str, str] = {}
for env_variable in list_of_env_variables:
env_variables[env_variable["name"]] = env_variable["value"]
if not isinstance(env_variable["value"], str):
jerryyiransun marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError("Environment variable value need to be a str")
env_variables[env_variable["key"]] = env_variable["value"]
return env_variables

@property
Expand Down
20 changes: 18 additions & 2 deletions multi_x_serverless/deployment/client/deploy/workflow_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ def build_workflow(self, config: Config) -> Workflow: # pylint: disable=too-man
else:
providers = config.regions_and_providers["providers"]
self._verify_providers(providers)

merged_env_vars = self.merge_environment_variables(
function.func_environment_variables, config.environment_variables
)
resources.append(
Function(
name=function_deployment_name,
# TODO (#22): Add function specific environment variables
environment_variables=config.environment_variables,
environment_variables=merged_env_vars,
runtime=config.python_version,
handler=function.handler,
role=function_role,
Expand Down Expand Up @@ -134,3 +137,16 @@ def get_function_role(self, config: Config, function_name: str) -> IAMRole:
filename = os.path.join(config.project_dir, ".multi-x-serverless", "iam_policy.yml")

return IAMRole(role_name=role_name, policy=filename)

def merge_environment_variables(
self, function_env_vars: list[dict[str, str]] | None, config_env_vars: dict[str, str]
jerryyiransun marked this conversation as resolved.
Show resolved Hide resolved
) -> dict[str, str]:
if not function_env_vars:
return config_env_vars

merged_env_vars: dict[str, str] = dict(config_env_vars)
# overwrite config env vars with function env vars if duplicate
for env_var in function_env_vars:
merged_env_vars[env_var["key"]] = env_var["value"]

return merged_env_vars
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ def __init__(
name: str,
entry_point: bool,
regions_and_providers: dict,
func_environment_variables: list[dict[str, str]],
):
self.function_callable = function_callable
self.name = name
self.entry_point = entry_point
self.handler = function_callable.__name__
self.regions_and_providers = regions_and_providers if len(regions_and_providers) > 0 else None
self.func_environment_variables = func_environment_variables if len(func_environment_variables) > 0 else None
self.validate_function_name()

def validate_function_name(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def register_function(
name: str,
entry_point: bool,
regions_and_providers: dict,
func_environment_variables: list[dict[str, str]],
) -> None:
"""
Register a function as a serverless function.
Expand All @@ -305,7 +306,9 @@ def register_function(
At this point we only need to register the function with the wrapper, the actual deployment will be done
later by the deployment manager.
"""
wrapper = MultiXServerlessFunction(function, name, entry_point, regions_and_providers)
wrapper = MultiXServerlessFunction(
function, name, entry_point, regions_and_providers, func_environment_variables
)
self.functions[function.__name__] = wrapper

# TODO (#22): Add function specific environment variables
Expand All @@ -314,6 +317,7 @@ def serverless_function(
name: Optional[str] = None,
entry_point: bool = False,
regions_and_providers: Optional[dict] = None,
func_environment_variables: Optional[list[dict[str, str]]] = None,
) -> Callable[..., Any]:
"""
Decorator to register a function as a Lambda function.
Expand Down Expand Up @@ -351,6 +355,21 @@ def serverless_function(
if regions_and_providers is None:
regions_and_providers = {}

if func_environment_variables is None:
func_environment_variables = []
else:
if not isinstance(func_environment_variables, list):
raise RuntimeError("func_environment_variables must be a list of dicts")
for env_variable in func_environment_variables:
if not isinstance(env_variable, dict):
raise RuntimeError("func_environment_variables must be a list of dicts")
if "key" not in env_variable or "value" not in env_variable:
raise RuntimeError("func_environment_variables must be a list of dicts with keys 'key' and 'value'")
if not isinstance(env_variable["key"], str):
raise RuntimeError("func_environment_variables must be a list of dicts with 'key' as a string")
if not isinstance(env_variable["value"], str):
raise RuntimeError("func_environment_variables must be a list of dicts with 'value' as a string")

def _register_handler(func: Callable[..., Any]) -> Callable[..., Any]:
handler_name = name if name is not None else func.__name__

Expand All @@ -377,7 +396,7 @@ def wrapper(*args, **kwargs): # type: ignore # pylint: disable=unused-argument
wrapper.routing_decision = {} # type: ignore
wrapper.entry_point = entry_point # type: ignore
wrapper.original_function = func # type: ignore
self.register_function(func, handler_name, entry_point, regions_and_providers)
self.register_function(func, handler_name, entry_point, regions_and_providers, func_environment_variables)
return wrapper

return _register_handler
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def test_build_workflow_multiple_entry_points(self):
function1.name = "function1"
function1.handler = "function1"
function1.regions_and_providers = {}
function1.func_environment_variables = {}
function2 = Mock(spec=MultiXServerlessFunction)
function2.entry_point = True
function2.name = "function2"
function2.handler = "function1"
function2.regions_and_providers = {"providers": []}
function2.func_environment_variables = {}
self.config.workflow_app.functions = {"function1": function1, "function2": function2}
with self.assertRaises(RuntimeError):
self.builder.build_workflow(self.config)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import unittest
from unittest.mock import Mock, patch
from multi_x_serverless.deployment.client.config import Config
from multi_x_serverless.deployment.client.multi_x_serverless_workflow import MultiXServerlessFunction
from multi_x_serverless.deployment.client.deploy.workflow_builder import WorkflowBuilder


class TestWorkflowBuilderFuncEnvVar(unittest.TestCase):
def test_build_func_environment_variables(self):
# function 1 (empty function level environment variables)
function1 = Mock(spec=MultiXServerlessFunction)
function1.entry_point = True
function1.name = "function1"
function1.handler = "function1"
function1.regions_and_providers = {}
function1.func_environment_variables = []

# function 2 (no overlap with global environment variables)
function2 = Mock(spec=MultiXServerlessFunction)
function2.entry_point = False
function2.name = "function2"
function2.handler = "function1"
function2.regions_and_providers = {"providers": []}
function2.func_environment_variables = [{"key": "ENV_3", "value": "function2_env_3"}]

# function 3 (overlap with global environment variables)
function3 = Mock(spec=MultiXServerlessFunction)
function3.entry_point = False
function3.name = "function2"
function3.handler = "function1"
function3.regions_and_providers = {"providers": []}
function3.func_environment_variables = [{"key": "ENV_1", "value": "function3_env_1"}]

self.builder = WorkflowBuilder()
self.config = Mock(spec=Config)
self.config.workflow_name = "test_workflow"
self.config.workflow_app.functions = {"function1": function1, "function2": function2, "function3": function3}
self.config.environment_variables = {
"ENV_1": "global_env_1",
"ENV_2": "global_env_2",
}
self.config.python_version = "3.8"
self.config.home_regions = []
self.config.project_dir = "/path/to/project"
self.config.iam_policy_file = None
self.config.regions_and_providers = {"providers": []}
self.config.workflow_app.get_successors.return_value = []

workflow = self.builder.build_workflow(self.config)

self.assertEqual(len(workflow._resources), 3)
built_func1 = workflow._resources[0]
built_func2 = workflow._resources[1]
built_func3 = workflow._resources[2]
self.assertEqual(
built_func1.environment_variables,
{
"ENV_1": "global_env_1",
"ENV_2": "global_env_2",
},
)
self.assertEqual(
built_func2.environment_variables,
{
"ENV_1": "global_env_1",
"ENV_2": "global_env_2",
"ENV_3": "function2_env_3",
},
)
self.assertEqual(
built_func3.environment_variables,
{
"ENV_1": "function3_env_1",
"ENV_2": "global_env_2",
},
)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion multi_x_serverless/tests/deployment/client/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_python_version(self):
self.assertTrue(self.config.python_version.startswith("python"))

def test_environment_variables(self):
self.config.project_config["environment_variables"] = [{"name": "ENV", "value": "test"}]
self.config.project_config["environment_variables"] = [{"key": "ENV", "value": "test"}]
self.assertEqual(self.config.environment_variables, {"ENV": "test"})

def test_home_regions(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ def function(x):
"forbidden_regions": [["aws", "us-east-2"]],
"providers": providers,
}
func_environment_variables = [{"key": "example_key", "value": "example_value"}]

function_obj = MultiXServerlessFunction(function, name, entry_point, regions_and_providers)
function_obj = MultiXServerlessFunction(
function, name, entry_point, regions_and_providers, func_environment_variables
)

self.assertEqual(function_obj.function_callable, function)
self.assertEqual(function_obj.name, name)
self.assertEqual(function_obj.entry_point, entry_point)
self.assertEqual(function_obj.handler, function.__name__)
self.assertEqual(function_obj.regions_and_providers, regions_and_providers)
self.assertEqual(function_obj.func_environment_variables, func_environment_variables)

def test_is_waiting_for_predecessors(self):
def function(x):
Expand All @@ -36,15 +40,20 @@ def function(x):
name = "test_function"
entry_point = True
regions_and_providers = {}
func_environment_variables = []

function_obj = MultiXServerlessFunction(function, name, entry_point, regions_and_providers)
function_obj = MultiXServerlessFunction(
function, name, entry_point, regions_and_providers, func_environment_variables
)

self.assertFalse(function_obj.is_waiting_for_predecessors())

def function(x):
return get_predecessor_data()

function_obj = MultiXServerlessFunction(function, name, entry_point, regions_and_providers)
function_obj = MultiXServerlessFunction(
function, name, entry_point, regions_and_providers, func_environment_variables
)

self.assertTrue(function_obj.is_waiting_for_predecessors())

Expand All @@ -55,12 +64,19 @@ def function(x):
name = "test_function"
entry_point = True
regions_and_providers = {}
func_environment_variables = []

function_obj = MultiXServerlessFunction(function, name, entry_point, regions_and_providers)
function_obj = MultiXServerlessFunction(
function, name, entry_point, regions_and_providers, func_environment_variables
)

function_obj.validate_function_name()

function_obj.name = "test:function"

with self.assertRaises(ValueError):
function_obj.validate_function_name()


if __name__ == "__main__":
unittest.main()
Loading