Skip to content

Commit e00da33

Browse files
committed
implement sts_token_buffer_time attribute for transport_options to update token earlier than expiration time
1 parent a0175b0 commit e00da33

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

kombu/transport/SQS.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@
7676
},
7777
}
7878
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
79-
'sts_token_timeout': 900 # optional
79+
'sts_token_timeout': 900, # optional
80+
'sts_token_buffer_time': 0 # optional
8081
}
8182
8283
Note that FIFO and standard queues must be named accordingly (the name of
@@ -91,6 +92,9 @@
9192
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
9293
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
9394
to 900 seconds. After the mentioned period, a new token will be created.
95+
sts_token_buffer_time (seconds) is the time by which you want to refresh your token
96+
earlier than its actual expiration time, defaults to 0 (no time buffer will be added),
97+
should be less than sts_token_timeout.
9498
9599
96100
@@ -136,7 +140,7 @@
136140
import socket
137141
import string
138142
import uuid
139-
from datetime import datetime
143+
from datetime import datetime, timedelta
140144
from queue import Empty
141145

142146
from botocore.client import Config
@@ -777,10 +781,18 @@ def _handle_sts_session(self, queue, q):
777781
return self._new_predefined_queue_client_with_sts_session(queue, region)
778782
return self._predefined_queue_clients[queue]
779783

784+
def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0):
785+
credentials = self.generate_sts_session_token(role_arn, token_expiry_seconds)
786+
if token_buffer_seconds and token_buffer_seconds < token_expiry_seconds:
787+
credentials["Expiration"] -= timedelta(seconds=token_buffer_seconds)
788+
return credentials
789+
780790
def _new_predefined_queue_client_with_sts_session(self, queue, region):
781-
sts_creds = self.generate_sts_session_token(
791+
sts_creds = self.generate_sts_session_token_with_buffer(
782792
self.transport_options.get('sts_role_arn'),
783-
self.transport_options.get('sts_token_timeout', 900))
793+
self.transport_options.get('sts_token_timeout', 900),
794+
self.transport_options.get('sts_token_buffer_time', 0),
795+
)
784796
self.sts_expiration = sts_creds['Expiration']
785797
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
786798
region=region,

t/unit/transport/test_SQS.py

+75
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,43 @@ def test_sts_new_session(self):
936936
# Assert
937937
mock_generate_sts_session_token.assert_called_once()
938938

939+
def test_sts_new_session_with_buffer_time(self):
940+
# Arrange
941+
sts_token_timeout = 900
942+
sts_token_buffer_time = 60
943+
connection = Connection(transport=SQS.Transport, transport_options={
944+
'predefined_queues': example_predefined_queues,
945+
'sts_role_arn': 'test::arn',
946+
'sts_token_timeout': sts_token_timeout,
947+
'sts_token_buffer_time': sts_token_buffer_time,
948+
})
949+
channel = connection.channel()
950+
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
951+
queue_name = 'queue-1'
952+
953+
mock_generate_sts_session_token = Mock()
954+
mock_new_sqs_client = Mock()
955+
channel.new_sqs_client = mock_new_sqs_client
956+
957+
expiration_time = datetime.utcnow() + timedelta(seconds=sts_token_timeout)
958+
959+
mock_generate_sts_session_token.side_effect = [
960+
{
961+
'Expiration': expiration_time,
962+
'SessionToken': 123,
963+
'AccessKeyId': 123,
964+
'SecretAccessKey': 123
965+
}
966+
]
967+
channel.generate_sts_session_token = mock_generate_sts_session_token
968+
969+
# Act
970+
sqs(queue=queue_name)
971+
972+
# Assert
973+
mock_generate_sts_session_token.assert_called_once()
974+
assert channel.sts_expiration == expiration_time - timedelta(seconds=sts_token_buffer_time)
975+
939976
def test_sts_session_expired(self):
940977
# Arrange
941978
connection = Connection(transport=SQS.Transport, transport_options={
@@ -966,6 +1003,44 @@ def test_sts_session_expired(self):
9661003
# Assert
9671004
mock_generate_sts_session_token.assert_called_once()
9681005

1006+
def test_sts_session_expired_with_buffer_time(self):
1007+
# Arrange
1008+
sts_token_timeout = 900
1009+
sts_token_buffer_time = 60
1010+
connection = Connection(transport=SQS.Transport, transport_options={
1011+
'predefined_queues': example_predefined_queues,
1012+
'sts_role_arn': 'test::arn',
1013+
'sts_token_timeout': sts_token_timeout,
1014+
'sts_token_buffer_time': sts_token_buffer_time,
1015+
})
1016+
channel = connection.channel()
1017+
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
1018+
channel.sts_expiration = datetime.utcnow() - timedelta(days=1)
1019+
queue_name = 'queue-1'
1020+
1021+
mock_generate_sts_session_token = Mock()
1022+
mock_new_sqs_client = Mock()
1023+
channel.new_sqs_client = mock_new_sqs_client
1024+
1025+
expiration_time = datetime.utcnow() + timedelta(seconds=sts_token_timeout)
1026+
1027+
mock_generate_sts_session_token.side_effect = [
1028+
{
1029+
'Expiration': expiration_time,
1030+
'SessionToken': 123,
1031+
'AccessKeyId': 123,
1032+
'SecretAccessKey': 123
1033+
}
1034+
]
1035+
channel.generate_sts_session_token = mock_generate_sts_session_token
1036+
1037+
# Act
1038+
sqs(queue=queue_name)
1039+
1040+
# Assert
1041+
mock_generate_sts_session_token.assert_called_once()
1042+
assert channel.sts_expiration == expiration_time - timedelta(seconds=sts_token_buffer_time)
1043+
9691044
def test_sts_session_not_expired(self):
9701045
# Arrange
9711046
connection = Connection(transport=SQS.Transport, transport_options={

0 commit comments

Comments
 (0)