From 2c696ae0d0c49238bd37d3b151da3bdbdffb431e Mon Sep 17 00:00:00 2001 From: Nicolas Harlem Eide Date: Fri, 10 Dec 2021 10:59:38 +0100 Subject: [PATCH] feat: Change ECS Task to use `IntegrationPattern` for input --- src/stepfunctions/steps/compute.py | 61 +++++++++++++++--------------- tests/unit/test_compute_steps.py | 8 ++-- 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/src/stepfunctions/steps/compute.py b/src/stepfunctions/steps/compute.py index 3b8e450..654eacd 100644 --- a/src/stepfunctions/steps/compute.py +++ b/src/stepfunctions/steps/compute.py @@ -15,7 +15,8 @@ from enum import Enum from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn, \ + is_integration_pattern_valid LAMBDA_SERVICE_NAME = "lambda" GLUE_SERVICE_NAME = "glue" @@ -161,17 +162,25 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): class EcsRunTaskStep(Task): - """ Creates a Task State to run Amazon ECS or Fargate Tasks. See `Manage Amazon ECS or Fargate Tasks with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_completion=True, wait_for_callback=False, **kwargs): + supported_integration_patterns = [ + IntegrationPattern.WaitForCompletion, + IntegrationPattern.WaitForTaskToken, + IntegrationPattern.CallAndContinue + ] + + def __init__(self, state_id, wait_for_completion=True, integration_pattern=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the ecs job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the ecs job and proceed to the next step. (default: True) - wait_for_callback(bool, optional): Boolean value set to `True` if the Task state should wait for callback to resume the operation. (default: False) + integration_pattern (stepfunctions.steps.integration_resources.IntegrationPattern, optional): Service integration pattern used to call the integrated service. This is mutually exclusive from wait_for_completion Supported integration patterns (default: None): + * WaitForCompletion: Wait for the state machine execution to complete before going to the next state. (See `Run A Job `_ for more details.) + * WaitForTaskToken: Wait for the state machine execution to return a task token before progressing to the next state (See `Wait for a Callback with the Task Token `_ for more details.) + * CallAndContinue: Call StartExecution and progress to the next state (See `Request Response `_ for more details.) timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. @@ -182,31 +191,23 @@ def __init__(self, state_id, wait_for_completion=True, wait_for_callback=False, result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - if wait_for_completion and wait_for_callback: - raise ValueError("Only one of wait_for_completion and wait_for_callback can be true") - - if wait_for_callback: - """ - Example resource arn: arn:aws:states:::ecs:runTask.waitForTaskToken - """ - - kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, - EcsApi.RunTask, - IntegrationPattern.WaitForTaskToken) - elif wait_for_completion: - """ - Example resource arn: arn:aws:states:::ecs:runTask.sync - """ - - kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, - EcsApi.RunTask, - IntegrationPattern.WaitForCompletion) - else: - """ - Example resource arn: arn:aws:states:::ecs:runTask - """ - - kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, - EcsApi.RunTask) + if wait_for_completion and integration_pattern: + raise ValueError( + "Only one of wait_for_completion and integration_pattern set. " + "Set wait_for_completion to False if you wish to use integration_pattern." + ) + + # The old implementation type still has to be supported until a new + # major is realeased. + if wait_for_completion: + integration_pattern = IntegrationPattern.WaitForCompletion + if not wait_for_completion and not integration_pattern: + integration_pattern = IntegrationPattern.CallAndContinue + + is_integration_pattern_valid(integration_pattern, + self.supported_integration_patterns) + kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, + EcsApi.RunTask, + integration_pattern) super(EcsRunTaskStep, self).__init__(state_id, **kwargs) diff --git a/tests/unit/test_compute_steps.py b/tests/unit/test_compute_steps.py index 8427c35..8261396 100644 --- a/tests/unit/test_compute_steps.py +++ b/tests/unit/test_compute_steps.py @@ -17,6 +17,7 @@ from unittest.mock import patch from stepfunctions.steps.compute import LambdaStep, GlueStartJobRunStep, BatchSubmitJobStep, EcsRunTaskStep +from stepfunctions.steps.integration_resources import IntegrationPattern @patch.object(boto3.session.Session, 'region_name', 'us-east-1') @@ -102,7 +103,6 @@ def test_batch_submit_job_step_creation(): @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_ecs_run_task_step_creation(): step = EcsRunTaskStep('Ecs Job', wait_for_completion=False) - assert step.to_dict() == { 'Type': 'Task', 'Resource': 'arn:aws:states:::ecs:runTask', @@ -110,9 +110,8 @@ def test_ecs_run_task_step_creation(): } step = EcsRunTaskStep('Ecs Job', - wait_for_callback=True, + integration_pattern=IntegrationPattern.WaitForTaskToken, wait_for_completion=False) - assert step.to_dict() == { 'Type': 'Task', 'Resource': 'arn:aws:states:::ecs:runTask.waitForTaskToken', @@ -122,7 +121,6 @@ def test_ecs_run_task_step_creation(): step = EcsRunTaskStep('Ecs Job', parameters={ 'TaskDefinition': 'Task' }) - assert step.to_dict() == { 'Type': 'Task', 'Resource': 'arn:aws:states:::ecs:runTask.sync', @@ -135,4 +133,4 @@ def test_ecs_run_task_step_creation(): with pytest.raises(ValueError): step = EcsRunTaskStep('Ecs Job', wait_for_completion=True, - wait_for_callback=True) + integration_pattern=IntegrationPattern.WaitForTaskToken)