From 7bff6e62b883af8142683c62e33baa82c992f953 Mon Sep 17 00:00:00 2001
From: Carolyn Nguyen
Date: Thu, 4 Nov 2021 19:15:37 -0700
Subject: [PATCH] Add retry arg to pipeline constructors to add retrier to
 each pipeline step

---
 .../template/pipeline/inference.py            | 15 ++++-
 src/stepfunctions/template/pipeline/train.py  | 19 ++++---
 tests/integ/test_inference_pipeline.py        |  6 +-
 .../test_training_pipeline_estimators.py      |  8 ++-
 ...t_training_pipeline_framework_estimator.py | 12 ++--
 tests/unit/test_pipeline.py                   | 56 +++++++++++++++++++
 6 files changed, 96 insertions(+), 20 deletions(-)

diff --git a/src/stepfunctions/template/pipeline/inference.py b/src/stepfunctions/template/pipeline/inference.py
index 17a1dbe..9130bda 100644
--- a/src/stepfunctions/template/pipeline/inference.py
+++ b/src/stepfunctions/template/pipeline/inference.py
@@ -39,7 +39,7 @@ class InferencePipeline(WorkflowTemplate):
 
     __allowed_kwargs = ('compression_type', 'content_type', 'pipeline_name')
 
-    def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None, **kwargs):
+    def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None, retry=None, **kwargs):
         """
         Args:
             preprocessor (sagemaker.estimator.EstimatorBase): The estimator used to preprocess and transform the training data.
@@ -54,6 +54,7 @@ def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None
                 * (list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data.
             s3_bucket (str): S3 bucket under which the output artifacts from the training job will be stored. The parent path used is built using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``. In this format, `pipeline_name` refers to the keyword argument provided for TrainingPipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. Also, in the format, `job_name` refers to the job name provided when calling the :meth:`TrainingPipeline.run()` method.
             client (SFN.Client, optional): boto3 client to use for creating and interacting with the inference pipeline in Step Functions. (default: None)
+            retry (Retry): A retrier that defines each pipeline step's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details. (default: None)
 
         Keyword Args:
             compression_type (str, optional): Compression type (Gzip/None) of the file for TransformJob. (default:None)
@@ -64,6 +65,7 @@ def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None
         self.estimator = estimator
         self.inputs = inputs
         self.s3_bucket = s3_bucket
+        self.retry = retry
 
         for key in self.__class__.__allowed_kwargs:
             setattr(self, key, kwargs.pop(key, None))
@@ -158,7 +160,7 @@ def build_workflow_definition(self):
             endpoint_config_name=default_name,
         )
 
-        return Chain([
+        steps = [
             preprocessor_train_step,
             preprocessor_model_step,
             preprocessor_transform_step,
@@ -166,7 +168,14 @@ def build_workflow_definition(self):
             pipeline_model_step,
             endpoint_config_step,
             deploy_step
-        ])
+        ]
+
+        if self.retry:
+            for step in steps:
+                step.add_retry(self.retry)
+
+        return Chain(steps)
+
 
     def pipeline_model_config(self, instance_type, pipeline_model):
         return {
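For reference, a minimal usage sketch of the new `retry` argument (the
`preprocessor` and `estimator` objects, bucket name, and role ARN below are
placeholders, not part of this patch):

    from stepfunctions.steps import Retry
    from stepfunctions.template.pipeline import InferencePipeline

    # Retry SageMaker service errors up to 5 times, waiting 5s, 10s, 20s, ...
    retry = Retry(
        error_equals=["SageMaker.AmazonSageMakerException"],
        interval_seconds=5,
        max_attempts=5,
        backoff_rate=2
    )

    # preprocessor and estimator are sagemaker.estimator.EstimatorBase
    # instances built elsewhere; the retrier is applied to every pipeline step.
    pipeline = InferencePipeline(
        preprocessor=preprocessor,
        estimator=estimator,
        inputs={'train': 's3://my-bucket/data/train'},
        s3_bucket='my-bucket',
        role='arn:aws:iam::123456789012:role/StepFunctionsWorkflowExecutionRole',
        retry=retry
    )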
diff --git a/src/stepfunctions/template/pipeline/train.py b/src/stepfunctions/template/pipeline/train.py
index d2bb4de..344ea5c 100644
--- a/src/stepfunctions/template/pipeline/train.py
+++ b/src/stepfunctions/template/pipeline/train.py
@@ -12,12 +12,7 @@
 # permissions and limitations under the License.
 from __future__ import absolute_import
 
-from sagemaker.utils import base_name_from_image
-from sagemaker.sklearn.estimator import SKLearn
-from sagemaker.model import Model
-from sagemaker.pipeline import PipelineModel
-
-from stepfunctions.steps import TrainingStep, TransformStep, ModelStep, EndpointConfigStep, EndpointStep, Chain, Fail, Catch
+from stepfunctions.steps import TrainingStep, ModelStep, EndpointConfigStep, EndpointStep, Chain, Retry
 from stepfunctions.workflow import Workflow
 from stepfunctions.template.pipeline.common import StepId, WorkflowTemplate
 
@@ -35,7 +30,7 @@ class TrainingPipeline(WorkflowTemplate):
 
     __allowed_kwargs = ('pipeline_name',)
 
-    def __init__(self, estimator, role, inputs, s3_bucket, client=None, **kwargs):
+    def __init__(self, estimator, role, inputs, s3_bucket, client=None, retry=None, **kwargs):
         """
         Args:
             estimator (sagemaker.estimator.EstimatorBase): The estimator to use for training. Can be a BYO estimator, Framework estimator or Amazon algorithm estimator.
@@ -49,12 +44,14 @@ def __init__(self, estimator, role, inputs, s3_bucket, client=None, **kwargs):
                 * (list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data.
             s3_bucket (str): S3 bucket under which the output artifacts from the training job will be stored. The parent path used is built using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``. In this format, `pipeline_name` refers to the keyword argument provided for TrainingPipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. Also, in the format, `job_name` refers to the job name provided when calling the :meth:`TrainingPipeline.run()` method.
             client (SFN.Client, optional): boto3 client to use for creating and interacting with the training pipeline in Step Functions. (default: None)
+            retry (Retry): A retrier that defines each pipeline step's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details. (default: None)
 
         Keyword Args:
             pipeline_name (str, optional): Name of the pipeline. This name will be used to name jobs (if not provided when calling execute()), models, endpoints, and S3 objects created by the pipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. (default:None)
         """
         self.estimator = estimator
         self.inputs = inputs
+        self.retry = retry
 
         for key in self.__class__.__allowed_kwargs:
             setattr(self, key, kwargs.pop(key, None))
@@ -110,7 +107,13 @@ def build_workflow_definition(self):
             endpoint_config_name=default_name,
         )
 
-        return Chain([training_step, model_step, endpoint_config_step, deploy_step])
+        steps = [training_step, model_step, endpoint_config_step, deploy_step]
+
+        if self.retry:
+            for step in steps:
+                step.add_retry(self.retry)
+
+        return Chain(steps)
 
     def execute(self, job_name=None, hyperparameters=None):
         """
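Both templates attach the same retrier to every step before chaining them. A
standalone sketch of that mechanism, using plain Task states in place of the
real pipeline steps (the state names and resource ARNs are illustrative):

    from stepfunctions.steps import Chain, Retry, Task

    retry = Retry(
        error_equals=["SageMaker.AmazonSageMakerException"],
        interval_seconds=5,
        max_attempts=5,
        backoff_rate=2
    )

    # Stand-ins for training_step, model_step, etc.
    steps = [
        Task('Training', resource='arn:aws:states:::sagemaker:createTrainingJob.sync'),
        Task('Create Model', resource='arn:aws:states:::sagemaker:createModel'),
    ]

    # Mirrors the new build_workflow_definition logic: one retrier, every step.
    for step in steps:
        step.add_retry(retry)

    definition = Chain(steps)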
diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py
index 341474c..4a6eae0 100644
--- a/tests/integ/test_inference_pipeline.py
+++ b/tests/integ/test_inference_pipeline.py
@@ -22,7 +22,7 @@
 
 from stepfunctions.template.pipeline import InferencePipeline
 
-from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
+from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
 from tests.integ.timeout import timeout
 from tests.integ.utils import (
     state_machine_delete_wait,
@@ -36,6 +36,7 @@
 BASE_NAME = 'inference-pipeline-integtest'
 COMPRESSED_NPY_DATA = 'mnist.npy.gz'
 
+
 # Fixtures
 @pytest.fixture(scope="module")
 def sklearn_preprocessor(sagemaker_role_arn, sagemaker_session):
@@ -100,7 +101,8 @@ def test_inference_pipeline_framework(
         role=sfn_role_arn,
         compression_type='Gzip',
         content_type='application/x-npy',
-        pipeline_name=unique_name
+        pipeline_name=unique_name,
+        retry=SAGEMAKER_RETRY_STRATEGY
     )
     _ = pipeline.create()
diff --git a/tests/integ/test_training_pipeline_estimators.py b/tests/integ/test_training_pipeline_estimators.py
index 78c8414..801cb5d 100644
--- a/tests/integ/test_training_pipeline_estimators.py
+++ b/tests/integ/test_training_pipeline_estimators.py
@@ -30,7 +30,7 @@
 # import StepFunctions
 from stepfunctions.template.pipeline import TrainingPipeline
 
-from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
+from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
 from tests.integ.timeout import timeout
 from tests.integ.utils import (
     state_machine_delete_wait,
@@ -60,7 +60,8 @@ def pca_estimator(sagemaker_role_arn):
     pca_estimator.mini_batch_size=128
 
     return pca_estimator
-
+
+
 @pytest.fixture(scope="module")
 def inputs(pca_estimator):
     data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
@@ -85,7 +86,8 @@ def test_pca_estimator(sfn_client, sagemaker_session, sagemaker_role_arn, sfn_ro
         role=sfn_role_arn,
         inputs=inputs,
         s3_bucket=bucket_name,
-        pipeline_name = unique_name
+        pipeline_name=unique_name,
+        retry=SAGEMAKER_RETRY_STRATEGY
     )
     tp.create()
diff --git a/tests/integ/test_training_pipeline_framework_estimator.py b/tests/integ/test_training_pipeline_framework_estimator.py
index bc775a7..1dee533 100644
--- a/tests/integ/test_training_pipeline_framework_estimator.py
+++ b/tests/integ/test_training_pipeline_framework_estimator.py
@@ -16,7 +16,7 @@
 import sagemaker
 import os
 
-from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
+from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
 from tests.integ.timeout import timeout
 from stepfunctions.template import TrainingPipeline
 from sagemaker.pytorch import PyTorch
@@ -29,6 +30,7 @@
     get_resource_name_from_arn
 )
 
+
 @pytest.fixture(scope="module")
 def torch_estimator(sagemaker_role_arn):
     script_path = os.path.join(DATA_DIR, "pytorch_mnist", "mnist.py")
@@ -45,6 +46,7 @@ def torch_estimator(sagemaker_role_arn):
         }
     )
 
+
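The integ tests import SAGEMAKER_RETRY_STRATEGY from tests.integ, but its
definition is not part of this diff; presumably it lives alongside DATA_DIR
in tests/integ/__init__.py and mirrors the unit-test constant added below:

    # tests/integ/__init__.py: assumed definition, not shown in this patch
    from stepfunctions.steps import Retry

    SAGEMAKER_RETRY_STRATEGY = Retry(
        error_equals=["SageMaker.AmazonSageMakerException"],
        interval_seconds=5,
        max_attempts=5,
        backoff_rate=2
    )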
@pytest.fixture(scope="module") def sklearn_estimator(sagemaker_role_arn): script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py") @@ -103,7 +105,8 @@ def test_torch_training_pipeline(sfn_client, sagemaker_client, torch_estimator, sfn_role_arn, inputs, sagemaker_session.default_bucket(), - sfn_client + sfn_client, + retry=SAGEMAKER_RETRY_STRATEGY ) pipeline.create() # execute pipeline @@ -138,7 +141,8 @@ def test_sklearn_training_pipeline(sfn_client, sagemaker_client, sklearn_estimat sfn_role_arn, inputs, sagemaker_session.default_bucket(), - sfn_client + sfn_client, + retry=SAGEMAKER_RETRY_STRATEGY ) pipeline.create() # run pipeline @@ -154,4 +158,4 @@ def test_sklearn_training_pipeline(sfn_client, sagemaker_client, sklearn_estimat _pipeline_test_suite(sagemaker_client, training_job_name='estimator-'+endpoint_name, model_name=endpoint_name, endpoint_name=endpoint_name) # teardown - _pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline) \ No newline at end of file + _pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline) diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index c7ab502..de4045e 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -19,6 +19,7 @@ from sagemaker.sklearn.estimator import SKLearn from unittest.mock import MagicMock, patch from stepfunctions.template import TrainingPipeline, InferencePipeline +from stepfunctions.steps import Retry from sagemaker.debugger import DebuggerHookConfig from tests.unit.utils import mock_boto_api_call @@ -27,6 +28,16 @@ STEPFUNCTIONS_EXECUTION_ROLE = 'StepFunctionsExecutionRole' PCA_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1' LINEAR_LEARNER_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1' +SAGEMAKER_RETRY_STRATEGY = Retry( + error_equals=["SageMaker.AmazonSageMakerException"], + interval_seconds=5, + max_attempts=5, + backoff_rate=2 +) +EXPECTED_RETRY = [{'BackoffRate': 2, + 'ErrorEquals': ['SageMaker.AmazonSageMakerException'], + 'IntervalSeconds': 5, + 'MaxAttempts': 5}] @pytest.fixture @@ -235,6 +246,25 @@ def test_pca_training_pipeline(pca_estimator): workflow.execute.assert_called_with(name=job_name, inputs=inputs) +@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_pca_training_pipeline_with_retry_adds_retry_to_each_step(pca_estimator): + s3_inputs = { + 'train': 's3://sagemaker/pca/train' + } + s3_bucket = 'sagemaker-us-east-1' + + pipeline = TrainingPipeline(pca_estimator, STEPFUNCTIONS_EXECUTION_ROLE, s3_inputs, s3_bucket, + retry=SAGEMAKER_RETRY_STRATEGY) + result = pipeline.workflow.definition.to_dict() + + assert result['States']['Training']['Retry'] == EXPECTED_RETRY + assert result['States']['Create Model']['Retry'] == EXPECTED_RETRY + assert result['States']['Configure Endpoint']['Retry'] == EXPECTED_RETRY + assert result['States']['Deploy']['Retry'] == EXPECTED_RETRY + + + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_inference_pipeline(sklearn_preprocessor, linear_learner_estimator): @@ -474,3 +504,29 @@ def test_inference_pipeline(sklearn_preprocessor, linear_learner_estimator): } workflow.execute.assert_called_with(name=job_name, inputs=inputs) + + +@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 
+@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
+def test_inference_pipeline_with_retry_adds_retry_to_each_step(sklearn_preprocessor, linear_learner_estimator):
+    s3_inputs = {
+        'train': 's3://sagemaker-us-east-1/inference/train'
+    }
+    s3_bucket = 'sagemaker-us-east-1'
+
+    pipeline = InferencePipeline(
+        preprocessor=sklearn_preprocessor,
+        estimator=linear_learner_estimator,
+        inputs=s3_inputs,
+        s3_bucket=s3_bucket,
+        role=STEPFUNCTIONS_EXECUTION_ROLE,
+        retry=SAGEMAKER_RETRY_STRATEGY
+    )
+    result = pipeline.get_workflow().definition.to_dict()
+
+    assert result['States']['Train Preprocessor']['Retry'] == EXPECTED_RETRY
+    assert result['States']['Create Preprocessor Model']['Retry'] == EXPECTED_RETRY
+    assert result['States']['Transform Input']['Retry'] == EXPECTED_RETRY
+    assert result['States']['Create Pipeline Model']['Retry'] == EXPECTED_RETRY
+    assert result['States']['Configure Endpoint']['Retry'] == EXPECTED_RETRY
+    assert result['States']['Deploy']['Retry'] == EXPECTED_RETRY
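With the retrier attached, every state in the rendered definition carries the
ASL Retry array asserted above. A quick way to inspect it on a constructed
pipeline (assuming the `pipeline` object from the first sketch):

    # Print each state's Retry policy from the generated state machine.
    definition = pipeline.get_workflow().definition.to_dict()
    for name, state in definition['States'].items():
        print(name, state.get('Retry'))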