diff --git a/caribou/data_collector/components/provider/provider_retriever.py b/caribou/data_collector/components/provider/provider_retriever.py index dacf9ab9..cedea834 100644 --- a/caribou/data_collector/components/provider/provider_retriever.py +++ b/caribou/data_collector/components/provider/provider_retriever.py @@ -23,7 +23,10 @@ def __init__(self, client: RemoteClient) -> None: if self._google_api_key is None and not self._integration_test_on: raise ValueError("GOOGLE_API_KEY environment variable not set") - self._aws_ec2_client = boto3.client("ec2") # Should be available in most if not all regions + # Should be available in most if not all regions + # But just to be sure, we use us-east-1 (As we know it's available there) + self._aws_ec2_client = boto3.client("ec2", region_name="us-east-1") + self._aws_pricing_client = boto3.client("pricing", region_name="us-east-1") # Must be in us-east-1 self._aws_region_name_to_code: dict[str, str] = {} diff --git a/caribou/tests/data_collector/components/provider/test_provider_retriever.py b/caribou/tests/data_collector/components/provider/test_provider_retriever.py index 1a940646..9608b8b3 100644 --- a/caribou/tests/data_collector/components/provider/test_provider_retriever.py +++ b/caribou/tests/data_collector/components/provider/test_provider_retriever.py @@ -799,6 +799,7 @@ def test_get_aws_product_skus_empty_json(self): @patch("caribou.data_collector.components.provider.provider_retriever.boto3.client") def test_retrieve_enabled_aws_regions_success(self, mock_boto3_client): mock_ec2_client = MagicMock() + mock_boto3_client.return_value = mock_ec2_client mock_ec2_client.describe_regions.return_value = { @@ -809,7 +810,13 @@ def test_retrieve_enabled_aws_regions_success(self, mock_boto3_client): ] } - provider_retriever = ProviderRetriever(client=mock_boto3_client) + with patch("os.environ.get") as mock_os_environ_get, patch( + "caribou.common.utils.str_to_bool" + ) as mock_str_to_bool: + mock_os_environ_get.return_value = "test_key" + mock_str_to_bool.return_value = False + provider_retriever = ProviderRetriever(client=mock_boto3_client) + expected_regions = ["us-east-1", "us-west-2", "eu-west-1"] actual_regions = provider_retriever._retrieve_enabled_aws_regions() @@ -822,7 +829,13 @@ def test_retrieve_enabled_aws_regions_empty(self, mock_boto3_client): mock_ec2_client.describe_regions.return_value = {"Regions": []} - provider_retriever = ProviderRetriever(client=mock_boto3_client) + with patch("os.environ.get") as mock_os_environ_get, patch( + "caribou.common.utils.str_to_bool" + ) as mock_str_to_bool: + mock_os_environ_get.return_value = "test_key" + mock_str_to_bool.return_value = False + provider_retriever = ProviderRetriever(client=mock_boto3_client) + expected_regions = [] actual_regions = provider_retriever._retrieve_enabled_aws_regions() @@ -835,7 +848,12 @@ def test_retrieve_enabled_aws_regions_api_failure(self, mock_boto3_client): mock_ec2_client.describe_regions.side_effect = Exception("AWS API error") - provider_retriever = ProviderRetriever(client=mock_boto3_client) + with patch("os.environ.get") as mock_os_environ_get, patch( + "caribou.common.utils.str_to_bool" + ) as mock_str_to_bool: + mock_os_environ_get.return_value = "test_key" + mock_str_to_bool.return_value = False + provider_retriever = ProviderRetriever(client=mock_boto3_client) with self.assertRaises(Exception) as context: provider_retriever._retrieve_enabled_aws_regions()