chore: Add retry to pipeline template constructors to add a retrier to each pipeline step #179

Open · wants to merge 2 commits into main
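This PR adds an optional retry argument to the pipeline template constructors; when provided, the retrier is attached to every step of the generated workflow. A minimal usage sketch based on the diff (the estimator, bucket, and role ARN are placeholders):

from stepfunctions.steps import Retry
from stepfunctions.template.pipeline import TrainingPipeline

# Retry every pipeline step on SageMaker service exceptions.
retry = Retry(
    error_equals=["SageMaker.AmazonSageMakerException"],
    interval_seconds=5,
    max_attempts=5,
    backoff_rate=2
)

pipeline = TrainingPipeline(
    estimator=my_estimator,  # placeholder: any sagemaker.estimator.EstimatorBase
    role="arn:aws:iam::012345678901:role/StepFunctionsExecutionRole",  # placeholder ARN
    inputs={"train": "s3://my-bucket/pca/train"},  # placeholder input channel
    s3_bucket="my-bucket",  # placeholder bucket
    retry=retry  # new: applied to each step via step.add_retry(retry)
)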
15 changes: 12 additions & 3 deletions src/stepfunctions/template/pipeline/inference.py
@@ -39,7 +39,7 @@ class InferencePipeline(WorkflowTemplate):

__allowed_kwargs = ('compression_type', 'content_type', 'pipeline_name')

def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None, **kwargs):
def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None, retry=None, **kwargs):
"""
Args:
preprocessor (sagemaker.estimator.EstimatorBase): The estimator used to preprocess and transform the training data.
@@ -54,6 +54,7 @@ def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None
* (list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data.
s3_bucket (str): S3 bucket under which the output artifacts from the training job will be stored. The parent path used is built using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``. In this format, `pipeline_name` refers to the keyword argument provided for TrainingPipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. Also, in the format, `job_name` refers to the job name provided when calling the :meth:`TrainingPipeline.run()` method.
client (SFN.Client, optional): boto3 client to use for creating and interacting with the inference pipeline in Step Functions. (default: None)
retry (Retry): A retrier that defines the each pipeline step's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details. (default: None)
Review comment (Contributor):
Suggested change:
- retry (Retry): A retrier that defines the each pipeline step's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details. (default: None)
+ retry (Retry): A retrier that defines the retry policy for each step in the pipeline. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details. (default: None)

Any reason not to make this a list, to support multiple retriers?
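For illustration only, a hypothetical variant that accepts either one retrier or a list of them (not part of this PR):

# Hypothetical: normalize a single Retry or a list of retriers before attaching.
retries = self.retry if isinstance(self.retry, (list, tuple)) else [self.retry]
for step in steps:
    for retrier in retries:
        step.add_retry(retrier)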


Keyword Args:
compression_type (str, optional): Compression type (Gzip/None) of the file for TransformJob. (default:None)
@@ -64,6 +65,7 @@ def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None
self.estimator = estimator
self.inputs = inputs
self.s3_bucket = s3_bucket
self.retry = retry

for key in self.__class__.__allowed_kwargs:
setattr(self, key, kwargs.pop(key, None))
@@ -158,15 +160,22 @@ def build_workflow_definition(self):
endpoint_config_name=default_name,
)

return Chain([
steps = [
preprocessor_train_step,
preprocessor_model_step,
preprocessor_transform_step,
training_step,
pipeline_model_step,
endpoint_config_step,
deploy_step
])
]

if self.retry:
for step in steps:
step.add_retry(self.retry)

return Chain(steps)


def pipeline_model_config(self, instance_type, pipeline_model):
return {
19 changes: 11 additions & 8 deletions src/stepfunctions/template/pipeline/train.py
@@ -12,12 +12,7 @@
# permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker.utils import base_name_from_image
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel

from stepfunctions.steps import TrainingStep, TransformStep, ModelStep, EndpointConfigStep, EndpointStep, Chain, Fail, Catch
from stepfunctions.steps import TrainingStep, ModelStep, EndpointConfigStep, EndpointStep, Chain, Retry
from stepfunctions.workflow import Workflow
from stepfunctions.template.pipeline.common import StepId, WorkflowTemplate

@@ -35,7 +30,7 @@ class TrainingPipeline(WorkflowTemplate):

__allowed_kwargs = ('pipeline_name',)

def __init__(self, estimator, role, inputs, s3_bucket, client=None, **kwargs):
def __init__(self, estimator, role, inputs, s3_bucket, client=None, retry=None, **kwargs):
"""
Args:
estimator (sagemaker.estimator.EstimatorBase): The estimator to use for training. Can be a BYO estimator, Framework estimator or Amazon algorithm estimator.
@@ -49,12 +44,14 @@ def __init__(self, estimator, role, inputs, s3_bucket, client=None, **kwargs):
* (list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data.
s3_bucket (str): S3 bucket under which the output artifacts from the training job will be stored. The parent path used is built using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``. In this format, `pipeline_name` refers to the keyword argument provided for TrainingPipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. Also, in the format, `job_name` refers to the job name provided when calling the :meth:`TrainingPipeline.run()` method.
client (SFN.Client, optional): boto3 client to use for creating and interacting with the training pipeline in Step Functions. (default: None)
retry (Retry): A retrier that defines the retry policy for each step in the pipeline. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details. (default: None)

Keyword Args:
pipeline_name (str, optional): Name of the pipeline. This name will be used to name jobs (if not provided when calling execute()), models, endpoints, and S3 objects created by the pipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. (default:None)
"""
self.estimator = estimator
self.inputs = inputs
self.retry = retry

for key in self.__class__.__allowed_kwargs:
setattr(self, key, kwargs.pop(key, None))
@@ -110,7 +107,13 @@ def build_workflow_definition(self):
endpoint_config_name=default_name,
)

return Chain([training_step, model_step, endpoint_config_step, deploy_step])
steps = [training_step, model_step, endpoint_config_step, deploy_step]

if self.retry:
for step in steps:
step.add_retry(self.retry)

return Chain(steps)

def execute(self, job_name=None, hyperparameters=None):
"""
6 changes: 4 additions & 2 deletions tests/integ/test_inference_pipeline.py
@@ -22,7 +22,7 @@

from stepfunctions.template.pipeline import InferencePipeline

from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
from tests.integ.timeout import timeout
from tests.integ.utils import (
state_machine_delete_wait,
@@ -36,6 +36,7 @@
BASE_NAME = 'inference-pipeline-integtest'
COMPRESSED_NPY_DATA = 'mnist.npy.gz'


# Fixtures
@pytest.fixture(scope="module")
def sklearn_preprocessor(sagemaker_role_arn, sagemaker_session):
@@ -100,7 +101,8 @@ def test_inference_pipeline_framework(
role=sfn_role_arn,
compression_type='Gzip',
content_type='application/x-npy',
pipeline_name=unique_name
pipeline_name=unique_name,
retry=SAGEMAKER_RETRY_STRATEGY
)

_ = pipeline.create()
8 changes: 5 additions & 3 deletions tests/integ/test_training_pipeline_estimators.py
@@ -30,7 +30,7 @@
# import StepFunctions
from stepfunctions.template.pipeline import TrainingPipeline

from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
from tests.integ.timeout import timeout
from tests.integ.utils import (
state_machine_delete_wait,
@@ -60,7 +60,8 @@ def pca_estimator(sagemaker_role_arn):
pca_estimator.mini_batch_size=128

return pca_estimator



@pytest.fixture(scope="module")
def inputs(pca_estimator):
data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
@@ -85,7 +86,8 @@ def test_pca_estimator(sfn_client, sagemaker_session, sagemaker_role_arn, sfn_ro
role=sfn_role_arn,
inputs=inputs,
s3_bucket=bucket_name,
pipeline_name = unique_name
pipeline_name=unique_name,
retry=SAGEMAKER_RETRY_STRATEGY
)
tp.create()

12 changes: 8 additions & 4 deletions tests/integ/test_training_pipeline_framework_estimator.py
@@ -16,7 +16,7 @@
import sagemaker
import os

from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
from tests.integ.timeout import timeout
from stepfunctions.template import TrainingPipeline
from sagemaker.pytorch import PyTorch
@@ -29,6 +29,7 @@
get_resource_name_from_arn
)


@pytest.fixture(scope="module")
def torch_estimator(sagemaker_role_arn):
script_path = os.path.join(DATA_DIR, "pytorch_mnist", "mnist.py")
@@ -45,6 +46,7 @@ def torch_estimator(sagemaker_role_arn):
}
)


@pytest.fixture(scope="module")
def sklearn_estimator(sagemaker_role_arn):
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -103,7 +105,8 @@ def test_torch_training_pipeline(sfn_client, sagemaker_client, torch_estimator,
sfn_role_arn,
inputs,
sagemaker_session.default_bucket(),
sfn_client
sfn_client,
retry=SAGEMAKER_RETRY_STRATEGY
)
pipeline.create()
# execute pipeline
@@ -138,7 +141,8 @@ def test_sklearn_training_pipeline(sfn_client, sagemaker_client, sklearn_estimat
sfn_role_arn,
inputs,
sagemaker_session.default_bucket(),
sfn_client
sfn_client,
retry=SAGEMAKER_RETRY_STRATEGY
)
pipeline.create()
# run pipeline
@@ -154,4 +158,4 @@ def test_sklearn_training_pipeline(sfn_client, sagemaker_client, sklearn_estimat
_pipeline_test_suite(sagemaker_client, training_job_name='estimator-'+endpoint_name, model_name=endpoint_name, endpoint_name=endpoint_name)

# teardown
_pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline)
_pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline)
56 changes: 56 additions & 0 deletions tests/unit/test_pipeline.py
@@ -19,6 +19,7 @@
from sagemaker.sklearn.estimator import SKLearn
from unittest.mock import MagicMock, patch
from stepfunctions.template import TrainingPipeline, InferencePipeline
from stepfunctions.steps import Retry
from sagemaker.debugger import DebuggerHookConfig

from tests.unit.utils import mock_boto_api_call
@@ -27,6 +28,16 @@
STEPFUNCTIONS_EXECUTION_ROLE = 'StepFunctionsExecutionRole'
PCA_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1'
LINEAR_LEARNER_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1'
SAGEMAKER_RETRY_STRATEGY = Retry(
error_equals=["SageMaker.AmazonSageMakerException"],
interval_seconds=5,
max_attempts=5,
backoff_rate=2
)
EXPECTED_RETRY = [{'BackoffRate': 2,
'ErrorEquals': ['SageMaker.AmazonSageMakerException'],
'IntervalSeconds': 5,
'MaxAttempts': 5}]


@pytest.fixture
Expand Down Expand Up @@ -235,6 +246,25 @@ def test_pca_training_pipeline(pca_estimator):
workflow.execute.assert_called_with(name=job_name, inputs=inputs)


@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_pca_training_pipeline_with_retry_adds_retry_to_each_step(pca_estimator):
s3_inputs = {
'train': 's3://sagemaker/pca/train'
}
s3_bucket = 'sagemaker-us-east-1'

pipeline = TrainingPipeline(pca_estimator, STEPFUNCTIONS_EXECUTION_ROLE, s3_inputs, s3_bucket,
retry=SAGEMAKER_RETRY_STRATEGY)
result = pipeline.workflow.definition.to_dict()

assert result['States']['Training']['Retry'] == EXPECTED_RETRY
assert result['States']['Create Model']['Retry'] == EXPECTED_RETRY
assert result['States']['Configure Endpoint']['Retry'] == EXPECTED_RETRY
assert result['States']['Deploy']['Retry'] == EXPECTED_RETRY



@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_inference_pipeline(sklearn_preprocessor, linear_learner_estimator):
Expand Down Expand Up @@ -474,3 +504,29 @@ def test_inference_pipeline(sklearn_preprocessor, linear_learner_estimator):
}

workflow.execute.assert_called_with(name=job_name, inputs=inputs)


@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_inference_pipeline_with_retry_adds_retry_to_each_step(sklearn_preprocessor, linear_learner_estimator):
s3_inputs = {
'train': 's3://sagemaker-us-east-1/inference/train'
}
s3_bucket = 'sagemaker-us-east-1'

pipeline = InferencePipeline(
preprocessor=sklearn_preprocessor,
estimator=linear_learner_estimator,
inputs=s3_inputs,
s3_bucket=s3_bucket,
role=STEPFUNCTIONS_EXECUTION_ROLE,
retry=SAGEMAKER_RETRY_STRATEGY
)
result = pipeline.get_workflow().definition.to_dict()

assert result['States']['Train Preprocessor']['Retry'] == EXPECTED_RETRY
assert result['States']['Create Preprocessor Model']['Retry'] == EXPECTED_RETRY
assert result['States']['Transform Input']['Retry'] == EXPECTED_RETRY
assert result['States']['Create Pipeline Model']['Retry'] == EXPECTED_RETRY
assert result['States']['Configure Endpoint']['Retry'] == EXPECTED_RETRY
assert result['States']['Deploy']['Retry'] == EXPECTED_RETRY