Skip to content

Commit

Permalink
Added mocks to several provider retriever test, specified a region in…
Browse files Browse the repository at this point in the history
… provide retriever for integration test and just a region that for sure has ec2 client.
  • Loading branch information
Danidite committed Oct 7, 2024
1 parent c80337d commit eca66ab
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def __init__(self, client: RemoteClient) -> None:
if self._google_api_key is None and not self._integration_test_on:
raise ValueError("GOOGLE_API_KEY environment variable not set")

self._aws_ec2_client = boto3.client("ec2") # Should be available in most if not all regions
# Should be available in most if not all regions
# But just to be sure, we use us-east-1 (As we know it's available there)
self._aws_ec2_client = boto3.client("ec2", region_name="us-east-1")

self._aws_pricing_client = boto3.client("pricing", region_name="us-east-1") # Must be in us-east-1
self._aws_region_name_to_code: dict[str, str] = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ def test_get_aws_product_skus_empty_json(self):
@patch("caribou.data_collector.components.provider.provider_retriever.boto3.client")
def test_retrieve_enabled_aws_regions_success(self, mock_boto3_client):
mock_ec2_client = MagicMock()

mock_boto3_client.return_value = mock_ec2_client

mock_ec2_client.describe_regions.return_value = {
Expand All @@ -809,7 +810,13 @@ def test_retrieve_enabled_aws_regions_success(self, mock_boto3_client):
]
}

provider_retriever = ProviderRetriever(client=mock_boto3_client)
with patch("os.environ.get") as mock_os_environ_get, patch(
"caribou.common.utils.str_to_bool"
) as mock_str_to_bool:
mock_os_environ_get.return_value = "test_key"
mock_str_to_bool.return_value = False
provider_retriever = ProviderRetriever(client=mock_boto3_client)

expected_regions = ["us-east-1", "us-west-2", "eu-west-1"]

actual_regions = provider_retriever._retrieve_enabled_aws_regions()
Expand All @@ -822,7 +829,13 @@ def test_retrieve_enabled_aws_regions_empty(self, mock_boto3_client):

mock_ec2_client.describe_regions.return_value = {"Regions": []}

provider_retriever = ProviderRetriever(client=mock_boto3_client)
with patch("os.environ.get") as mock_os_environ_get, patch(
"caribou.common.utils.str_to_bool"
) as mock_str_to_bool:
mock_os_environ_get.return_value = "test_key"
mock_str_to_bool.return_value = False
provider_retriever = ProviderRetriever(client=mock_boto3_client)

expected_regions = []

actual_regions = provider_retriever._retrieve_enabled_aws_regions()
Expand All @@ -835,7 +848,12 @@ def test_retrieve_enabled_aws_regions_api_failure(self, mock_boto3_client):

mock_ec2_client.describe_regions.side_effect = Exception("AWS API error")

provider_retriever = ProviderRetriever(client=mock_boto3_client)
with patch("os.environ.get") as mock_os_environ_get, patch(
"caribou.common.utils.str_to_bool"
) as mock_str_to_bool:
mock_os_environ_get.return_value = "test_key"
mock_str_to_bool.return_value = False
provider_retriever = ProviderRetriever(client=mock_boto3_client)

with self.assertRaises(Exception) as context:
provider_retriever._retrieve_enabled_aws_regions()
Expand Down

0 comments on commit eca66ab

Please sign in to comment.