diff --git a/src/matcha_ml/constants.py b/src/matcha_ml/constants.py index d0cf14f4..599aead4 100644 --- a/src/matcha_ml/constants.py +++ b/src/matcha_ml/constants.py @@ -19,3 +19,12 @@ }, "deployer": {"seldon": MatchaConfigComponentProperty("deployer", "seldon")}, } + +DEFAULT_STACK = [ + MatchaConfigComponentProperty("orchestrator", "zenml"), + MatchaConfigComponentProperty("experiment_tracker", "mlflow"), + MatchaConfigComponentProperty("data_version_control", "dvc"), + MatchaConfigComponentProperty("deployer", "seldon"), +] + +LLM_STACK = DEFAULT_STACK + [MatchaConfigComponentProperty("vector_database", "chroma")] diff --git a/src/matcha_ml/core/core.py b/src/matcha_ml/core/core.py index 40c12718..ed94f84e 100644 --- a/src/matcha_ml/core/core.py +++ b/src/matcha_ml/core/core.py @@ -18,18 +18,20 @@ MatchaConfigComponentProperty, MatchaConfigService, ) -from matcha_ml.constants import STACK_MODULES -from matcha_ml.core._validation import ( - is_valid_prefix, - is_valid_region, -) + +from matcha_ml.constants import DEFAULT_STACK, LLM_STACK, STACK_MODULES +from matcha_ml.core._validation import is_valid_prefix, is_valid_region from matcha_ml.errors import MatchaError, MatchaInputError from matcha_ml.runners import AzureRunner from matcha_ml.services.analytics_service import AnalyticsEvent, track from matcha_ml.services.global_parameters_service import GlobalParameters from matcha_ml.state import MatchaStateService, RemoteStateManager from matcha_ml.state.matcha_state import MatchaState -from matcha_ml.templates.azure_template import DEFAULT_STACK, LLM_STACK, AzureTemplate +from matcha_ml.templates.azure_template import ( + DEFAULT_STACK_TF, + LLM_STACK_TF, + AzureTemplate, +) class StackTypeMeta( @@ -318,7 +320,7 @@ def provision( ) azure_template = AzureTemplate( - LLM_STACK if stack_name == StackType.LLM.value else DEFAULT_STACK + LLM_STACK_TF if stack_name == StackType.LLM.value else DEFAULT_STACK_TF ) zenml_version = infer_zenml_version() @@ -353,6 +355,27 @@ def stack_set(stack_name: str) -> None: MatchaInputError: if the stack_name is not a valid stack type MatchaError: if there are already resources provisioned. """ + + def _create_stack_component(stack_type: StackType) -> MatchaConfigComponent: + """Create the set of configuration component for the stack. + + Args: + stack_type (StackType): the type of stack to create. + + Returns: + MatchaConfigComponent: the stack component. + """ + stack = MatchaConfigComponent( + name="stack", + properties=[ + MatchaConfigComponentProperty(name="name", value=stack_type.value) + ], + ) + + stack.properties += LLM_STACK if stack_type == StackType.LLM else DEFAULT_STACK + + return stack + if RemoteStateManager().is_state_provisioned(): raise MatchaError( "The remote resources are already provisioned. Changing the stack now will not " @@ -362,12 +385,7 @@ def stack_set(stack_name: str) -> None: if stack_name.lower() not in StackType: raise MatchaInputError(f"{stack_name} is not a valid stack type.") - stack_enum = StackType(stack_name.lower()) - - stack = MatchaConfigComponent( - name="stack", - properties=[MatchaConfigComponentProperty(name="name", value=stack_enum.value)], - ) + stack = _create_stack_component(stack_type=StackType(stack_name.lower())) MatchaConfigService.update(stack) diff --git a/src/matcha_ml/templates/azure_template.py b/src/matcha_ml/templates/azure_template.py index 4a170b7a..e5a40908 100644 --- a/src/matcha_ml/templates/azure_template.py +++ b/src/matcha_ml/templates/azure_template.py @@ -16,7 +16,7 @@ from matcha_ml.state import MatchaState, MatchaStateService from matcha_ml.templates.base_template import BaseTemplate, TemplateVariables -DEFAULT_STACK = [ +DEFAULT_STACK_TF = [ "aks", "resource_group", "mlflow_module", @@ -29,7 +29,7 @@ "zen_server/zenml_helm/templates", "data_version_control_storage", ] -LLM_STACK = DEFAULT_STACK + [ +LLM_STACK_TF = DEFAULT_STACK_TF + [ "chroma", "chroma/chroma_helm", "chroma/chroma_helm/templates", diff --git a/tests/test_cli/test_provision.py b/tests/test_cli/test_provision.py index de912e78..060f3013 100644 --- a/tests/test_cli/test_provision.py +++ b/tests/test_cli/test_provision.py @@ -9,7 +9,7 @@ from typer.testing import CliRunner from matcha_ml.cli.cli import app -from matcha_ml.templates.azure_template import DEFAULT_STACK +from matcha_ml.templates.azure_template import DEFAULT_STACK_TF BASE_DIR = os.path.dirname(os.path.abspath(__file__)) TEMPLATE_DIR = os.path.join( @@ -66,7 +66,7 @@ def assert_infrastructure( module_file_path = os.path.join(destination_path, module_file_name) assert os.path.exists(module_file_path) - for module_name in DEFAULT_STACK: + for module_name in DEFAULT_STACK_TF: for module_file_name in glob.glob( os.path.join(TEMPLATE_DIR, module_name, "*.tf") ): diff --git a/tests/test_cli/test_stack.py b/tests/test_cli/test_stack.py index 4a1ae998..2f5f8192 100644 --- a/tests/test_cli/test_stack.py +++ b/tests/test_cli/test_stack.py @@ -126,7 +126,17 @@ def test_stack_set_file_created( assert result.exit_code == 0 config = MatchaConfigService.read_matcha_config() - assert config.to_dict() == {"stack": {"name": "llm"}} + expected_stack = { + "stack": { + "name": "llm", + "orchestrator": "zenml", + "experiment_tracker": "mlflow", + "data_version_control": "dvc", + "deployer": "seldon", + "vector_database": "chroma", + } + } + assert config.to_dict() == expected_stack def test_stack_set_file_modified( diff --git a/tests/test_core/test_core_provision.py b/tests/test_core/test_core_provision.py index c55a82bf..9a934122 100644 --- a/tests/test_core/test_core_provision.py +++ b/tests/test_core/test_core_provision.py @@ -23,7 +23,7 @@ from matcha_ml.state.matcha_state import ( MatchaState, ) -from matcha_ml.templates.azure_template import DEFAULT_STACK +from matcha_ml.templates.azure_template import DEFAULT_STACK_TF BASE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -154,7 +154,7 @@ def assert_infrastructure( module_file_path = os.path.join(destination_path, module_file_name) assert os.path.exists(module_file_path) - for module_name in DEFAULT_STACK: + for module_name in DEFAULT_STACK_TF: for module_file_name in glob.glob( os.path.join(TEMPLATE_DIR, module_name, "*.tf") ): @@ -329,7 +329,13 @@ def test_stale_remote_state_file_is_removed(matcha_testing_directory: str): "container_name": "test-container", "resource_group_name": "test-rg", }, - "stack": {"name": "default"}, + "stack": { + "name": "default", + "orchestrator": "zenml", + "experiment_tracker": "mlflow", + "data_version_control": "dvc", + "deployer": "seldon", + }, } with mock.patch( diff --git a/tests/test_core/test_stack_set.py b/tests/test_core/test_stack_set.py index f1ac88c6..3a416af4 100644 --- a/tests/test_core/test_stack_set.py +++ b/tests/test_core/test_stack_set.py @@ -4,33 +4,103 @@ import pytest -from matcha_ml.config import MatchaConfig, MatchaConfigService +from matcha_ml.config import ( + MatchaConfig, + MatchaConfigComponent, + MatchaConfigComponentProperty, + MatchaConfigService, +) +from matcha_ml.constants import DEFAULT_STACK, LLM_STACK from matcha_ml.core import stack_set from matcha_ml.errors import MatchaError, MatchaInputError +@pytest.fixture +def expected_matcha_config_llm_stack() -> MatchaConfig: + """A mocked version of the MatchaConfig for the LLM stack. + + Returns: + MatchaConfig: the mocked llm stack config. + """ + return MatchaConfig( + components=[ + MatchaConfigComponent( + name="stack", + properties=[MatchaConfigComponentProperty(name="name", value="llm")] + + LLM_STACK, + ) + ] + ) + + +@pytest.fixture +def expected_matcha_config_default_stack() -> MatchaConfig: + """A mocked version of the MatchaConfig for the default stack. + + Returns: + MatchaConfig: the mocked default stack config. + """ + return MatchaConfig( + components=[ + MatchaConfigComponent( + name="stack", + properties=[MatchaConfigComponentProperty(name="name", value="default")] + + DEFAULT_STACK, + ) + ] + ) + + def test_stack_set_valid_no_existing_file( - matcha_testing_directory, mocked_remote_state_manager_is_state_provisioned_false + matcha_testing_directory, + mocked_remote_state_manager_is_state_provisioned_false, + expected_matcha_config_llm_stack, ): """Test that stack_set creates a config file if one doesn't exist and that it can be read properly. Args: matcha_testing_directory (str): temporary working directory mocked_remote_state_manager_is_state_provisioned_false (RemoteStateManager): A mocked remote state manager + expected_matcha_config_llm_stack (MatchaConfig): the expected configuration if the LLM stack is used. + """ + os.chdir(matcha_testing_directory) + + stack_set(stack_name="llm") + + config = MatchaConfigService.read_matcha_config() + assert config == expected_matcha_config_llm_stack + + +def test_change_stack_expected( + matcha_testing_directory, + mocked_remote_state_manager_is_state_provisioned_false, + expected_matcha_config_llm_stack, + expected_matcha_config_default_stack, +): + """Test that when a stack is changed that the components of that stack change as expected. + + Args: + matcha_testing_directory (str): a temporary working directory. + mocked_remote_state_manager_is_state_provisioned_false (RemoteStateManager): a mocked remote state manager. + expected_matcha_config_llm_stack (MatchaConfig): the expected configuration for the LLM stack. + expected_matcha_config_default_stack (MatchaConfig): the expected configuration for the default stack. """ + # create the stack in the testing directory and assert that it's what we expect os.chdir(matcha_testing_directory) stack_set(stack_name="llm") config = MatchaConfigService.read_matcha_config() - assert config.to_dict() == {"stack": {"name": "llm"}} + assert config == expected_matcha_config_llm_stack + # TODO Having to delete the file is a bit clunky and could be improved. MatchaConfigService.delete_matcha_config() stack_set(stack_name="default") - config = MatchaConfigService.read_matcha_config() - assert config.to_dict() == {"stack": {"name": "default"}} + default_config = MatchaConfigService.read_matcha_config() + assert default_config == expected_matcha_config_default_stack + assert default_config != config def test_stack_set_invalid(