diff --git a/caribou/common/setup/setup_tables.py b/caribou/common/setup/setup_tables.py index 4c74f80e..5e8ed1c9 100644 --- a/caribou/common/setup/setup_tables.py +++ b/caribou/common/setup/setup_tables.py @@ -2,6 +2,7 @@ import os import boto3 +from botocore.exceptions import ClientError from caribou.common import constants @@ -20,8 +21,9 @@ def create_table(dynamodb, table_name): dynamodb.describe_table(TableName=table_name) logger.info("Table %s already exists", table_name) return - except dynamodb.exceptions.ResourceNotFoundException: - pass + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + raise if table_name not in [constants.SYNC_MESSAGES_TABLE, constants.SYNC_PREDECESSOR_COUNTER_TABLE]: # Create all non sync tables with on-demand billing mode @@ -41,7 +43,7 @@ def create_bucket(s3, bucket_name): s3.head_bucket(Bucket=bucket_name) logger.info("Bucket %s already exists", bucket_name) return - except s3.exceptions.ClientError as e: + except ClientError as e: if e.response["Error"]["Code"] != "404" and e.response["Error"]["Code"] != "403": raise s3.create_bucket( diff --git a/caribou/common/teardown/teardown_tables.py b/caribou/common/teardown/teardown_tables.py index 19713aa4..bedb988d 100644 --- a/caribou/common/teardown/teardown_tables.py +++ b/caribou/common/teardown/teardown_tables.py @@ -2,7 +2,7 @@ from typing import Any import boto3 -import botocore +from botocore.exceptions import ClientError from caribou.common import constants from caribou.common.models.endpoints import Endpoints @@ -15,7 +15,10 @@ def remove_table(dynamodb: Any, table_name: str, verbose: bool = True) -> None: # If the table exists, delete it dynamodb.delete_table(TableName=table_name) - except dynamodb.exceptions.ResourceNotFoundException: + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + # If the error is not ResourceNotFoundException, raise the exception and notify the user + raise if verbose: print(f"Table '{table_name}' does not exists (Or already removed)") @@ -34,7 +37,7 @@ def remove_bucket(s3: Any, s3_resource: Any, bucket_name: str) -> None: s3.delete_bucket(Bucket=bucket_name) print(f"Removed legacy bucket: {bucket_name}") - except botocore.exceptions.ClientError as e: + except ClientError as e: if e.response["Error"]["Code"] != "404" and e.response["Error"]["Code"] != "403": # If the error is not 403 forbidden or 404 not found, # raise the exception and notify the user @@ -106,7 +109,7 @@ def remove_sync_tables_all_regions() -> None: for table_name in sync_tables: try: remove_table(dynamodb, table_name, verbose=False) - except botocore.exceptions.ClientError as e: + except ClientError as e: # If not UnrecognizedClientException, log the error # As exception also appears if the user does not have a region enabled # Which means that there are no tables to remove anyways diff --git a/caribou/tests/common/models/remote_client/test_aws_remote_client.py b/caribou/tests/common/models/remote_client/test_aws_remote_client.py index 3c303d6d..16229b13 100644 --- a/caribou/tests/common/models/remote_client/test_aws_remote_client.py +++ b/caribou/tests/common/models/remote_client/test_aws_remote_client.py @@ -1,4 +1,6 @@ +from io import StringIO import os +import sys import unittest from unittest.mock import patch from unittest.mock import MagicMock @@ -1128,6 +1130,335 @@ def test_ecr_repository_exists(self, mock_client): # Assert that the describe_repositories method was called with the correct parameters mock_client.return_value.describe_repositories.assert_called_once_with(repositoryNames=["test_repository"]) + @patch.object(AWSRemoteClient, "_client") + def test_get_timer_rule_schedule_expression_exists(self, mock_client): + # Mocking the scenario where the timer rule exists + mock_events_client = MagicMock() + mock_client.return_value = mock_events_client + + client = AWSRemoteClient("region1") + + # Mock the return value of describe_rule + mock_events_client.describe_rule.return_value = {"ScheduleExpression": "rate(5 minutes)"} + + result = client.get_timer_rule_schedule_expression("test_rule") + + # Check that the return value is correct + self.assertEqual(result, "rate(5 minutes)") + + # Check that describe_rule was called with the correct arguments + mock_events_client.describe_rule.assert_called_once_with(Name="test_rule") + + @patch.object(AWSRemoteClient, "_client") + def test_get_timer_rule_schedule_expression_not_exists(self, mock_client): + # Mocking the scenario where the timer rule does not exist + mock_events_client = MagicMock() + mock_client.return_value = mock_events_client + + client = AWSRemoteClient("region1") + + # Mock the side effect of describe_rule to raise a ResourceNotFoundException + mock_events_client.describe_rule.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException"}}, "describe_rule" + ) + + result = client.get_timer_rule_schedule_expression("test_rule") + + # Check that the return value is None + self.assertIsNone(result) + + # Check that describe_rule was called with the correct arguments + mock_events_client.describe_rule.assert_called_once_with(Name="test_rule") + + @patch.object(AWSRemoteClient, "_client") + def test_get_timer_rule_schedule_expression_other_error(self, mock_client): + # Mocking the scenario where another error occurs + mock_events_client = MagicMock() + mock_client.return_value = mock_events_client + + client = AWSRemoteClient("region1") + + # Mock the side effect of describe_rule to raise a different ClientError + mock_events_client.describe_rule.side_effect = ClientError( + {"Error": {"Code": "InternalError"}}, "describe_rule" + ) + + # Capture stdout + captured_output = StringIO() + sys.stdout = captured_output + + result = client.get_timer_rule_schedule_expression("test_rule") + + # Reset redirect. + sys.stdout = sys.__stdout__ + + # Check that the return value is None + self.assertIsNone(result) + + # Check that describe_rule was called with the correct arguments + mock_events_client.describe_rule.assert_called_once_with(Name="test_rule") + + # Check that the error message was logged + self.assertIn( + "Error removing the EventBridge rule test_rule: An error occurred (InternalError)", + captured_output.getvalue(), + ) + + @patch.object(AWSRemoteClient, "_client") + def test_remove_timer_rule_success(self, mock_client): + # Mocking the scenario where the timer rule is removed successfully + mock_events_client = MagicMock() + mock_client.return_value = mock_events_client + + client = AWSRemoteClient("region1") + + # Call the method with test values + client.remove_timer_rule("lambda_function_name", "rule_name") + + # Check that the remove_targets and delete_rule methods were called + mock_events_client.remove_targets.assert_called_once_with(Rule="rule_name", Ids=["lambda_function_name-target"]) + mock_events_client.delete_rule.assert_called_once_with(Name="rule_name", Force=True) + + @patch.object(AWSRemoteClient, "_client") + def test_remove_timer_rule_not_found(self, mock_client): + # Mocking the scenario where the timer rule does not exist + mock_events_client = MagicMock() + mock_client.return_value = mock_events_client + + client = AWSRemoteClient("region1") + + # Mock the side effect of remove_targets to raise a ResourceNotFoundException + mock_events_client.remove_targets.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException"}}, "remove_targets" + ) + + # Call the method with test values + client.remove_timer_rule("lambda_function_name", "rule_name") + + # Check that the remove_targets method was called + mock_events_client.remove_targets.assert_called_once_with(Rule="rule_name", Ids=["lambda_function_name-target"]) + + # Check that the delete_rule method was not called + mock_events_client.delete_rule.assert_not_called() + + @patch.object(AWSRemoteClient, "_client") + def test_remove_timer_rule_other_error(self, mock_client): + # Mocking the scenario where another error occurs + mock_events_client = MagicMock() + mock_client.return_value = mock_events_client + + client = AWSRemoteClient("region1") + + # Mock the side effect of remove_targets to raise a different ClientError + mock_events_client.remove_targets.side_effect = ClientError( + {"Error": {"Code": "InternalError"}}, "remove_targets" + ) + + # Capture stdout + captured_output = StringIO() + sys.stdout = captured_output + + # Call the method with test values + client.remove_timer_rule("lambda_function_name", "rule_name") + + # Reset redirect. + sys.stdout = sys.__stdout__ + + # Check that the remove_targets method was called + mock_events_client.remove_targets.assert_called_once_with(Rule="rule_name", Ids=["lambda_function_name-target"]) + + # Check that the delete_rule method was not called + mock_events_client.delete_rule.assert_not_called() + + # Check that the error message was logged + self.assertIn( + "Error removing the EventBridge rule rule_name: An error occurred (InternalError)", + captured_output.getvalue(), + ) + + @patch.object(AWSRemoteClient, "_client") + def test_event_bridge_permission_exists_true(self, mock_client): + # Mocking the scenario where the permission exists + mock_lambda_client = MagicMock() + mock_client.return_value = mock_lambda_client + + client = AWSRemoteClient("region1") + + # Mock the return value of get_policy + mock_lambda_client.get_policy.return_value = { + "Policy": json.dumps({"Statement": [{"Sid": "existing_statement_id"}]}) + } + + result = client.event_bridge_permission_exists("lambda_function_name", "existing_statement_id") + + # Check that the return value is True + self.assertTrue(result) + + # Check that get_policy was called with the correct arguments + mock_lambda_client.get_policy.assert_called_once_with(FunctionName="lambda_function_name") + + @patch.object(AWSRemoteClient, "_client") + def test_event_bridge_permission_exists_false(self, mock_client): + # Mocking the scenario where the permission does not exist + mock_lambda_client = MagicMock() + mock_client.return_value = mock_lambda_client + + client = AWSRemoteClient("region1") + + # Mock the return value of get_policy + mock_lambda_client.get_policy.return_value = { + "Policy": json.dumps({"Statement": [{"Sid": "different_statement_id"}]}) + } + + result = client.event_bridge_permission_exists("lambda_function_name", "non_existing_statement_id") + + # Check that the return value is False + self.assertFalse(result) + + # Check that get_policy was called with the correct arguments + mock_lambda_client.get_policy.assert_called_once_with(FunctionName="lambda_function_name") + + @patch.object(AWSRemoteClient, "_client") + def test_event_bridge_permission_exists_client_error(self, mock_client): + # Mocking the scenario where a ClientError occurs + mock_lambda_client = MagicMock() + mock_client.return_value = mock_lambda_client + + client = AWSRemoteClient("region1") + + # Mock the side effect of get_policy to raise a ClientError + mock_lambda_client.get_policy.side_effect = ClientError({"Error": {"Code": "InternalError"}}, "get_policy") + + # Capture stdout + captured_output = StringIO() + sys.stdout = captured_output + + result = client.event_bridge_permission_exists("lambda_function_name", "statement_id") + + # Reset redirect. + sys.stdout = sys.__stdout__ + + # Check that the return value is False + self.assertFalse(result) + + # Check that get_policy was called with the correct arguments + mock_lambda_client.get_policy.assert_called_once_with(FunctionName="lambda_function_name") + + # Check that the error message was logged + self.assertIn( + "Error in asserting if permission exists lambda_function_name - statement_id", captured_output.getvalue() + ) + + @patch.object(AWSRemoteClient, "_client") + @patch.object(AWSRemoteClient, "event_bridge_permission_exists") + @patch.object(AWSRemoteClient, "get_lambda_function") + def test_create_timer_rule(self, mock_get_lambda_function, mock_event_bridge_permission_exists, mock_client): + # Mocking the scenario where the timer rule is created successfully + mock_events_client = MagicMock() + mock_lambda_client = MagicMock() + mock_client.side_effect = [mock_events_client, mock_lambda_client] + + client = AWSRemoteClient("region1") + + # Define the input + lambda_function_name = "test_lambda_function" + schedule_expression = "rate(5 minutes)" + rule_name = "test_rule" + event_payload = '{"key": "value"}' + + # Mock the return value of put_rule + mock_events_client.put_rule.return_value = {"RuleArn": "arn:aws:events:region:123456789012:rule/test_rule"} + + # Mock the return value of event_bridge_permission_exists + mock_event_bridge_permission_exists.return_value = False + + # Mock the return value of get_lambda_function + mock_get_lambda_function.return_value = { + "FunctionArn": "arn:aws:lambda:region:123456789012:function:test_lambda_function" + } + + # Call the method with test values + client.create_timer_rule(lambda_function_name, schedule_expression, rule_name, event_payload) + + # Check that put_rule was called with the correct arguments + mock_events_client.put_rule.assert_called_once_with( + Name=rule_name, ScheduleExpression=schedule_expression, State="ENABLED" + ) + + # Check that add_permission was called with the correct arguments + mock_lambda_client.add_permission.assert_called_once_with( + FunctionName=lambda_function_name, + StatementId=f"{rule_name}-invoke-lambda", + Action="lambda:InvokeFunction", + Principal="events.amazonaws.com", + SourceArn="arn:aws:events:region:123456789012:rule/test_rule", + ) + + # Check that put_targets was called with the correct arguments + mock_events_client.put_targets.assert_called_once_with( + Rule=rule_name, + Targets=[ + { + "Id": f"{lambda_function_name}-target", + "Arn": "arn:aws:lambda:region:123456789012:function:test_lambda_function", + "Input": event_payload, + } + ], + ) + + @patch.object(AWSRemoteClient, "_client") + @patch.object(AWSRemoteClient, "event_bridge_permission_exists") + @patch.object(AWSRemoteClient, "get_lambda_function") + def test_create_timer_rule_permission_exists( + self, mock_get_lambda_function, mock_event_bridge_permission_exists, mock_client + ): + # Mocking the scenario where the permission already exists + mock_events_client = MagicMock() + mock_lambda_client = MagicMock() + mock_client.side_effect = [mock_events_client, mock_lambda_client] + + client = AWSRemoteClient("region1") + + # Define the input + lambda_function_name = "test_lambda_function" + schedule_expression = "rate(5 minutes)" + rule_name = "test_rule" + event_payload = '{"key": "value"}' + + # Mock the return value of put_rule + mock_events_client.put_rule.return_value = {"RuleArn": "arn:aws:events:region:123456789012:rule/test_rule"} + + # Mock the return value of event_bridge_permission_exists + mock_event_bridge_permission_exists.return_value = True + + # Mock the return value of get_lambda_function + mock_get_lambda_function.return_value = { + "FunctionArn": "arn:aws:lambda:region:123456789012:function:test_lambda_function" + } + + # Call the method with test values + client.create_timer_rule(lambda_function_name, schedule_expression, rule_name, event_payload) + + # Check that put_rule was called with the correct arguments + mock_events_client.put_rule.assert_called_once_with( + Name=rule_name, ScheduleExpression=schedule_expression, State="ENABLED" + ) + + # Check that add_permission was not called since the permission already exists + mock_lambda_client.add_permission.assert_not_called() + + # Check that put_targets was called with the correct arguments + mock_events_client.put_targets.assert_called_once_with( + Rule=rule_name, + Targets=[ + { + "Id": f"{lambda_function_name}-target", + "Arn": "arn:aws:lambda:region:123456789012:function:test_lambda_function", + "Input": event_payload, + } + ], + ) + if __name__ == "__main__": unittest.main() diff --git a/caribou/tests/common/setup/test_setup_tables.py b/caribou/tests/common/setup/test_setup_tables.py new file mode 100644 index 00000000..192b3d25 --- /dev/null +++ b/caribou/tests/common/setup/test_setup_tables.py @@ -0,0 +1,93 @@ +import unittest +from unittest.mock import patch, MagicMock +from botocore.exceptions import ClientError +from caribou.common.setup import setup_tables +from caribou.common import constants + + +class TestSetupTables(unittest.TestCase): + @patch("boto3.client") + def test_create_table_already_exists(self, mock_boto_client): + mock_dynamodb = MagicMock() + mock_boto_client.return_value = mock_dynamodb + mock_dynamodb.describe_table.return_value = {"Table": {"TableName": "existing_table"}} + + with self.assertLogs(setup_tables.logger, level="INFO") as log: + setup_tables.create_table(mock_dynamodb, "existing_table") + + mock_dynamodb.describe_table.assert_called_once_with(TableName="existing_table") + self.assertIn("Table existing_table already exists", log.output[0]) + + @patch("boto3.client") + def test_create_table_not_exists(self, mock_boto_client): + mock_dynamodb = MagicMock() + mock_boto_client.return_value = mock_dynamodb + mock_dynamodb.describe_table.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException"}}, "describe_table" + ) + + setup_tables.create_table(mock_dynamodb, "new_table") + + mock_dynamodb.create_table.assert_called_once_with( + TableName="new_table", + AttributeDefinitions=[{"AttributeName": "key", "AttributeType": "S"}], + KeySchema=[{"AttributeName": "key", "KeyType": "HASH"}], + BillingMode="PAY_PER_REQUEST", + ) + + @patch("boto3.client") + def test_create_bucket_already_exists(self, mock_boto_client): + mock_s3 = MagicMock() + mock_boto_client.return_value = mock_s3 + mock_s3.head_bucket.return_value = {} + + with self.assertLogs(setup_tables.logger, level="INFO") as log: + setup_tables.create_bucket(mock_s3, "existing_bucket") + + mock_s3.head_bucket.assert_called_once_with(Bucket="existing_bucket") + self.assertIn("Bucket existing_bucket already exists", log.output[0]) + + @patch("boto3.client") + def test_create_bucket_new_bucket(self, mock_boto_client): + mock_s3 = MagicMock() + mock_boto_client.return_value = mock_s3 + mock_s3.head_bucket.side_effect = ClientError({"Error": {"Code": "404"}}, "head_bucket") + + setup_tables.create_bucket(mock_s3, "new_bucket") + + mock_s3.head_bucket.assert_called_once_with(Bucket="new_bucket") + mock_s3.create_bucket.assert_called_once() + + @patch("boto3.client") + def test_create_bucket_other_error(self, mock_boto_client): + mock_s3 = MagicMock() + mock_boto_client.return_value = mock_s3 + mock_s3.head_bucket.side_effect = ClientError({"Error": {"Code": "500"}}, "head_bucket") + + with self.assertRaises(ClientError): + setup_tables.create_bucket(mock_s3, "error_bucket") + + mock_s3.head_bucket.assert_called_once_with(Bucket="error_bucket") + mock_s3.create_bucket.assert_not_called() + + @patch("boto3.client") + @patch("caribou.common.setup.setup_tables.create_table") + def test_main_create_tables(self, mock_create_table, mock_boto_client): + mock_dynamodb = MagicMock() + mock_boto_client.return_value = mock_dynamodb + + with patch("caribou.common.setup.setup_tables.constants") as mock_constants: + mock_constants.GLOBAL_SYSTEM_REGION = "us-west-2" + mock_constants.SYNC_MESSAGES_TABLE = "sync_messages_table" + mock_constants.SYNC_PREDECESSOR_COUNTER_TABLE = "sync_predecessor_counter_table" + mock_constants.OTHER_TABLE = "other_table" + + setup_tables.main() + + mock_create_table.assert_any_call(mock_dynamodb, "sync_messages_table") + mock_create_table.assert_any_call(mock_dynamodb, "sync_predecessor_counter_table") + mock_create_table.assert_any_call(mock_dynamodb, "other_table") + + +if __name__ == "__main__": + unittest.main() diff --git a/caribou/tests/common/teardown/test_teardown_tables.py b/caribou/tests/common/teardown/test_teardown_tables.py new file mode 100644 index 00000000..359cde14 --- /dev/null +++ b/caribou/tests/common/teardown/test_teardown_tables.py @@ -0,0 +1,151 @@ +import os +import unittest +from unittest.mock import patch, MagicMock +from botocore.exceptions import ClientError +from caribou.common import constants + +from caribou.common.teardown.teardown_tables import ( + remove_table, + remove_bucket, + teardown_framework_tables, + teardown_framework_buckets, + remove_sync_tables_all_regions, +) + + +class TestTeardownTables(unittest.TestCase): + @patch("boto3.client") + def test_remove_table_exists(self, mock_boto_client): + mock_dynamodb = MagicMock() + mock_boto_client.return_value = mock_dynamodb + + remove_table(mock_dynamodb, "test_table") + + mock_dynamodb.describe_table.assert_called_once_with(TableName="test_table") + mock_dynamodb.delete_table.assert_called_once_with(TableName="test_table") + + @patch("boto3.client") + def test_remove_table_not_exists(self, mock_boto_client): + mock_dynamodb = MagicMock() + mock_boto_client.return_value = mock_dynamodb + mock_dynamodb.describe_table.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException"}}, "describe_table" + ) + + remove_table(mock_dynamodb, "test_table") + + mock_dynamodb.describe_table.assert_called_once_with(TableName="test_table") + mock_dynamodb.delete_table.assert_not_called() + + @patch("boto3.client") + def test_remove_table_other_error(self, mock_boto_client): + mock_dynamodb = MagicMock() + mock_boto_client.return_value = mock_dynamodb + mock_dynamodb.describe_table.side_effect = ClientError( + {"Error": {"Code": "SomeOtherException"}}, "describe_table" + ) + + with self.assertRaises(ClientError): + remove_table(mock_dynamodb, "test_table") + + @patch("boto3.client") + @patch("boto3.resource") + def test_remove_bucket_exists(self, mock_boto_resource, mock_boto_client): + mock_s3 = MagicMock() + mock_s3_resource = MagicMock() + mock_boto_client.return_value = mock_s3 + mock_boto_resource.return_value = mock_s3_resource + + remove_bucket(mock_s3, mock_s3_resource, "test_bucket") + + mock_s3.head_bucket.assert_called_once_with(Bucket="test_bucket") + mock_s3_resource.Bucket.assert_called_once_with("test_bucket") + mock_s3_resource.Bucket().objects.all().delete.assert_called_once() + mock_s3.delete_bucket.assert_called_once_with(Bucket="test_bucket") + + @patch("boto3.client") + @patch("boto3.resource") + def test_remove_bucket_not_exists(self, mock_boto_resource, mock_boto_client): + mock_s3 = MagicMock() + mock_s3_resource = MagicMock() + mock_boto_client.return_value = mock_s3 + mock_boto_resource.return_value = mock_s3_resource + mock_s3.head_bucket.side_effect = ClientError({"Error": {"Code": "404"}}, "head_bucket") + + remove_bucket(mock_s3, mock_s3_resource, "test_bucket") + + mock_s3.head_bucket.assert_called_once_with(Bucket="test_bucket") + mock_s3_resource.Bucket.assert_not_called() + mock_s3.delete_bucket.assert_not_called() + + @patch("boto3.client") + @patch("boto3.resource") + def test_remove_bucket_other_error(self, mock_boto_resource, mock_boto_client): + mock_s3 = MagicMock() + mock_s3_resource = MagicMock() + mock_boto_client.return_value = mock_s3 + mock_boto_resource.return_value = mock_s3_resource + mock_s3.head_bucket.side_effect = ClientError({"Error": {"Code": "SomeOtherException"}}, "head_bucket") + + with self.assertRaises(ClientError): + remove_bucket(mock_s3, mock_s3_resource, "test_bucket") + + @patch("boto3.client") + @patch("caribou.common.teardown.teardown_tables.constants") + def test_teardown_framework_tables(self, mock_constants, mock_boto_client): + mock_dynamodb = MagicMock() + mock_boto_client.return_value = mock_dynamodb + mock_constants.GLOBAL_SYSTEM_REGION = "us-west-2" + mock_constants.SYNC_MESSAGES_TABLE = "sync_messages" + mock_constants.SYNC_PREDECESSOR_COUNTER_TABLE = "sync_predecessor_counter" + mock_constants.TEST_TABLE = "test_table" + + teardown_framework_tables() + + mock_dynamodb.describe_table.assert_called_once_with(TableName="test_table") + mock_dynamodb.delete_table.assert_called_once_with(TableName="test_table") + + @patch("boto3.client") + @patch("boto3.resource") + @patch("caribou.common.teardown.teardown_tables.constants") + def test_teardown_framework_buckets(self, mock_constants, mock_boto_resource, mock_boto_client): + mock_s3 = MagicMock() + mock_s3_resource = MagicMock() + mock_boto_client.return_value = mock_s3 + mock_boto_resource.return_value = mock_s3_resource + mock_constants.GLOBAL_SYSTEM_REGION = "us-west-2" + mock_constants.TEST_BUCKET = "test_bucket" + + teardown_framework_buckets() + + mock_s3.head_bucket.assert_called_once_with(Bucket="test_bucket") + mock_s3_resource.Bucket.assert_called_once_with("test_bucket") + mock_s3_resource.Bucket().objects.all().delete.assert_called_once() + mock_s3.delete_bucket.assert_called_once_with(Bucket="test_bucket") + + @patch("boto3.client") + @patch("caribou.common.teardown.teardown_tables.constants") + @patch("caribou.common.teardown.teardown_tables.Endpoints") + def test_remove_sync_tables_all_regions(self, mock_endpoints, mock_constants, mock_boto_client): + mock_dynamodb = MagicMock() + mock_boto_client.return_value = mock_dynamodb + mock_constants.GLOBAL_SYSTEM_REGION = "us-west-2" + mock_constants.SYNC_MESSAGES_TABLE = "sync_messages" + mock_constants.SYNC_PREDECESSOR_COUNTER_TABLE = "sync_predecessor_counter" + mock_constants.AVAILABLE_REGIONS_TABLE = "available_regions" + mock_endpoints().get_data_collector_client().get_all_values_from_table.return_value = { + "aws:us-east-1": {}, + "aws:us-west-1": {}, + } + + remove_sync_tables_all_regions() + + self.assertEqual(mock_boto_client.call_count, 3) + mock_dynamodb.describe_table.assert_any_call(TableName="sync_messages") + mock_dynamodb.describe_table.assert_any_call(TableName="sync_predecessor_counter") + mock_dynamodb.delete_table.assert_any_call(TableName="sync_messages") + mock_dynamodb.delete_table.assert_any_call(TableName="sync_predecessor_counter") + + +if __name__ == "__main__": + unittest.main()