Skip to content

Commit

Permalink
Add support for Step Functions
Browse files Browse the repository at this point in the history
  • Loading branch information
amancevice committed May 29, 2021
1 parent 9011f9b commit 00bf81c
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 21 deletions.
8 changes: 7 additions & 1 deletion main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ resource "aws_cloudwatch_event_rule" "post" {
description = local.events.post.rule_description

event_pattern = jsonencode({
detail-type = [{ prefix = "api/" }]
detail-type = ["post"]
source = [local.events.source]
})
}
Expand Down Expand Up @@ -254,6 +254,12 @@ data "aws_iam_policy_document" "inline" {

resources = ["*"]
}

statement {
sid = "SendTaskStatus"
actions = ["states:SendTask*"]
resources = ["*"]
}
}

resource "aws_iam_role" "role" {
Expand Down
4 changes: 4 additions & 0 deletions src/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def detail(self):
def detail_type(self):
return self.event['detail-type']

@property
def task_token(self):
return self.detail.get('task-token')


class HttpEvent(Event):
@property
Expand Down
11 changes: 9 additions & 2 deletions src/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from logger import logger
from secrets import export
from slack import Slack
from states import States

export(SecretId=os.getenv('SECRET_ID'))
EVENTS_BUS_NAME = os.getenv('EVENTS_BUS_NAME')
Expand Down Expand Up @@ -35,6 +36,7 @@
token=SLACK_TOKEN,
verify=not SLACK_DISABLE_VERIFICATION,
)
states = States()


@slack.route('GET /health')
Expand Down Expand Up @@ -130,8 +132,13 @@ def post_slash_cmd(event):
@logger.bind
def post(event, context=None):
event = EventBridgeEvent(event)
result = slack.post(event.detail_type, **event.detail)
events.publish(f'result/{ event.detail_type }', result)
result = slack.post(**event.detail)
if result['ok']:
events.publish('result', result)
if result['ok'] and event.task_token:
states.succeed(event.task_token, json.dumps(result))
elif event.task_token:
states.fail(event.task_token, result['error'], json.dumps(result))
return result


Expand Down
17 changes: 8 additions & 9 deletions src/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,27 @@ def install_url(self):
})
return urlunsplit(url + [query, fragment])

def post(self, path, body=None, headers=None):
def post(self, url, body=None, headers=None, **_):
# Prepare request
data = body.encode('utf-8')
headers = {k.lower(): v for k, v in (headers or {}).items()}

# Execute request
url = f'https://slack.com/{ path }'
logger.info('POST %s %s', url, body)
req = Request(url=url, data=data, headers=headers, method='POST')
res = urlopen(req)

# Parse response
resdata = res.read().decode()
ok = False
if res.headers['content-type'].startswith('application/json'):
resdata = json.loads(resdata)
ok = resdata['ok']
try:
resjson = json.loads(resdata)
except Exception: # pragma: no cover
resjson = {'ok': False}

# Log response & return
log = logger.info if ok else logger.error
log('RESPONSE [%d] %s', res.status, json.dumps(resdata))
return resdata
log = logger.info if resjson['ok'] else logger.error
log('RESPONSE [%d] %s', res.status, resdata)
return resjson

def randstate(self):
chars = string.ascii_letters + '1234567890'
Expand Down
26 changes: 26 additions & 0 deletions src/states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import json

import boto3

from logger import logger


class States:
def __init__(self, boto3_session=None):
self.boto3_session = boto3_session or boto3.Session()
self.client = self.boto3_session.client('stepfunctions')

def fail(self, task_token, error, cause):
params = dict(taskToken=task_token, error=error, cause=cause)
logger.info('SEND TASK FAILURE %s', json.dumps(params))
return self.client.send_task_failure(**params)

def heartbeat(self, task_token):
params = dict(taskToken=task_token)
logger.info('SEND TASK HEARTBEAT %s', json.dumps(params))
return self.client.send_task_heartbeat(**params)

def succeed(self, task_token, output):
params = dict(taskToken=task_token, output=output)
logger.info('SEND TASK SUCCESS %s', json.dumps(params))
return self.client.send_task_success(**params)
39 changes: 31 additions & 8 deletions tests/index_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def setup(self):

index.events.publish = mock.MagicMock()
index.slack.install = mock.MagicMock()
index.slack.post = mock.MagicMock()
index.slack.install.return_value = (
{'ok': True},
'https://example.com/success',
Expand Down Expand Up @@ -242,17 +243,39 @@ def test_500(self):
assert index.proxy(event) == exp

def test_post(self):
index.slack.post.return_value = {'ok': True}
event = {
'detail-type': 'api/chat.postMessage',
'detail-type': 'post',
'detail': {
'url': 'https://slack.com/api/chat.postMessage',
'body': json.dumps({'text': 'FIZZ'}),
'headers': {'content-type': 'application/json'}
'headers': {'content-type': 'application/json; charset=utf-8'},
'task-token': '<token>',
},
}
index.slack.post = mock.MagicMock()
index.post(event)
index.slack.post.assert_called_once_with(
'api/chat.postMessage',
body=json.dumps({'text': 'FIZZ'}),
headers={'content-type': 'application/json'}
)
index.slack.post.assert_called_once_with(**{
'url': 'https://slack.com/api/chat.postMessage',
'body': json.dumps({'text': 'FIZZ'}),
'headers': {'content-type': 'application/json; charset=utf-8'},
'task-token': '<token>',
})

def test_post_fail(self):
index.slack.post.return_value = {'ok': False, 'error': 'fizz'}
event = {
'detail-type': 'post',
'detail': {
'url': 'https://slack.com/api/chat.postMessage',
'body': json.dumps({'text': 'FIZZ'}),
'headers': {'content-type': 'application/json; charset=utf-8'},
'task-token': '<token>',
},
}
index.post(event)
index.slack.post.assert_called_once_with(**{
'url': 'https://slack.com/api/chat.postMessage',
'body': json.dumps({'text': 'FIZZ'}),
'headers': {'content-type': 'application/json; charset=utf-8'},
'task-token': '<token>',
})
2 changes: 1 addition & 1 deletion tests/slack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_install_url(self, state, oauth_install_uri, exp):

def test_post(self):
ret = self.subject.post(
'api/chat.postMessage',
'https://slack.com/api/chat.postMessage',
json.dumps({'text': 'FIZZ'}),
{'content-type': 'application/json; charset=utf-8'},
)
Expand Down
30 changes: 30 additions & 0 deletions tests/states_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from unittest.mock import MagicMock

from src.states import States


class TestStates:
def setup(self):
self.boto3_session = MagicMock()
self.subject = States(boto3_session=self.boto3_session)

def test_fail(self):
self.subject.fail('<token>', 'error', '{}')
self.subject.client.send_task_failure.assert_called_once_with(
taskToken='<token>',
error='error',
cause='{}',
)

def test_heartbeat(self):
self.subject.heartbeat('<token>')
self.subject.client.send_task_heartbeat.assert_called_once_with(
taskToken='<token>',
)

def test_succeed(self):
self.subject.succeed('<token>', {'fizz': 'buzz'})
self.subject.client.send_task_success.assert_called_once_with(
taskToken='<token>',
output={'fizz': 'buzz'},
)

0 comments on commit 00bf81c

Please sign in to comment.