Skip to content

Commit

Permalink
[COST-5627] Fix Azure default extension issue (#5394)
Browse the repository at this point in the history
* Making extension parameter optional.

Co-authored-by: Luke Couzens <[email protected]>
  • Loading branch information
bacciotti and lcouzens authored Dec 2, 2024
1 parent 3d8ded0 commit 3870f77
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from masu.util.aws.common import copy_local_report_file_to_s3_bucket
from masu.util.aws.common import get_or_clear_daily_s3_by_date
from masu.util.azure import common as utils
from masu.util.azure.common import AzureBlobExtension

DATA_DIR = Config.TMP_DIR
LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -272,7 +271,7 @@ def _get_manifest(self, date_time): # noqa: C901
"""
manifest = {}
compression_mode = AzureBlobExtension.csv.value # Default value
compression_mode = None
if self.ingress_reports:
reports = [report.split(f"{self.container_name}/")[1] for report in self.ingress_reports]
year = date_time.strftime("%Y")
Expand Down Expand Up @@ -305,7 +304,7 @@ def _get_manifest(self, date_time): # noqa: C901
except AzureCostReportNotFound as ex:
json_manifest = None
msg = f"No JSON manifest exists. {ex}"
LOG.debug(msg)
LOG.info(msg)
if json_manifest:
report_name = json_manifest.name
last_modified = json_manifest.last_modified
Expand All @@ -332,9 +331,7 @@ def _get_manifest(self, date_time): # noqa: C901
manifest["reportKeys"] = [blob["blobName"] for blob in manifest_json["blobs"]]
else:
try:
blob = self._azure_client.get_latest_cost_export_for_path(
report_path, self.container_name, compression_mode
)
blob = self._azure_client.get_latest_cost_export_for_path(report_path, self.container_name)
except AzureCostReportNotFound as ex:
msg = f"Unable to find manifest. Error: {ex}"
LOG.info(log_json(self.tracing_id, msg=msg, context=self.context))
Expand Down
39 changes: 14 additions & 25 deletions koku/masu/external/downloader/azure/azure_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,26 @@ def __init__(
raise AzureServiceError("Azure Service credentials are not configured.")

def _get_latest_blob(
self, report_path: str, blobs: list[BlobProperties], extension: str
self, report_path: str, blobs: list[BlobProperties], extension: t.Optional[str] = None
) -> t.Optional[BlobProperties]:
latest_blob = None
for blob in blobs:
if not blob.name.endswith(extension):
if extension and not blob.name.endswith(extension):
continue

if report_path in blob.name and not latest_blob:
latest_blob = blob
elif report_path in blob.name and blob.last_modified > latest_blob.last_modified:
latest_blob = blob
if report_path in blob.name:
if not latest_blob or blob.last_modified > latest_blob.last_modified:
latest_blob = blob
return latest_blob

def _get_latest_blob_for_path(
self,
report_path: str,
container_name: str,
extension: str,
extension: t.Optional[str] = None,
) -> BlobProperties:
"""Get the latest file with the specified extension from given storage account container."""

latest_report = None
latest_file = None
if not container_name:
message = "Unable to gather latest file as container name is not provided."
LOG.warning(message)
Expand Down Expand Up @@ -118,15 +116,12 @@ def _get_latest_blob_for_path(
LOG.warning(error_msg)
raise AzureCostReportNotFound(message)

latest_report = self._get_latest_blob(report_path, blobs, extension)
if not latest_report:
message = (
f"No file with extension '{extension}' found in container "
f"'{container_name}' for path '{report_path}'."
)
latest_file = self._get_latest_blob(report_path, blobs, extension)
if not latest_file:
message = f"No file found in container " f"'{container_name}' for path '{report_path}'."
raise AzureCostReportNotFound(message)

return latest_report
return latest_file

def _list_blobs(self, starts_with: str, container_name: str) -> list[BlobProperties]:
try:
Expand Down Expand Up @@ -158,9 +153,7 @@ def get_file_for_key(self, key: str, container_name: str) -> BlobProperties:

return report

def get_latest_cost_export_for_path(
self, report_path: str, container_name: str, compression: str
) -> BlobProperties:
def get_latest_cost_export_for_path(self, report_path: str, container_name: str) -> BlobProperties:
"""
Get the latest cost export for a given path and container based on the compression type.
Expand All @@ -176,14 +169,10 @@ def get_latest_cost_export_for_path(
ValueError: If the compression type is not 'gzip' or 'csv'.
AzureCostReportNotFound: If no blob is found for the given path and container.
"""
valid_compressions = [AzureBlobExtension.gzip.value, AzureBlobExtension.csv.value]
if compression not in valid_compressions:
raise ValueError(f"Invalid compression type: {compression}. Expected one of: {valid_compressions}.")

return self._get_latest_blob_for_path(report_path, container_name, compression)
return self._get_latest_blob_for_path(report_path, container_name)

def get_latest_manifest_for_path(self, report_path: str, container_name: str) -> BlobProperties:
return self._get_latest_blob_for_path(report_path, container_name, AzureBlobExtension.manifest.value)
return self._get_latest_blob_for_path(report_path, container_name, AzureBlobExtension.json.value)

def download_file(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_get_ingress_manifest(self):
manifest, _ = self.ingress_downloader._get_manifest(self.mock_data.test_date)

self.assertEqual(manifest.get("reportKeys"), [self.mock_data.ingress_report])
self.assertEqual(manifest.get("Compression"), AzureBlobExtension.csv.value)
self.assertEqual(manifest.get("Compression"), None)
self.assertEqual(manifest.get("billingPeriod").get("start"), expected_start)
self.assertEqual(manifest.get("billingPeriod").get("end"), expected_end)

Expand All @@ -249,7 +249,7 @@ def test_get_manifest(self):

self.assertEqual(manifest.get("assemblyId"), self.mock_data.export_uuid)
self.assertEqual(manifest.get("reportKeys"), [self.mock_data.export_file])
self.assertEqual(manifest.get("Compression"), AzureBlobExtension.csv.value)
self.assertEqual(manifest.get("Compression"), None)
self.assertEqual(manifest.get("billingPeriod").get("start"), expected_start)
self.assertEqual(manifest.get("billingPeriod").get("end"), expected_end)

Expand Down
56 changes: 7 additions & 49 deletions koku/masu/test/external/downloader/azure/test_azure_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ def test_get_latest_cost_export_for_path(self):
type(mock_blob).name = name_attr # kludge to set name attribute on Mock

svc = self.get_mock_client(blob_list=[mock_blob])
cost_export = svc.get_latest_cost_export_for_path(report_path, self.container_name, ".csv.gz")
cost_export = svc.get_latest_cost_export_for_path(report_path, self.container_name)
self.assertEqual(cost_export.last_modified.date(), self.current_date_time.date())

def test_get_latest_cost_export_for_path_missing(self):
"""Test that the no cost export is returned for a missing path."""
report_path = FAKE.word()
svc = self.get_mock_client()
with self.assertRaises(AzureCostReportNotFound):
svc.get_latest_cost_export_for_path(report_path, self.container_name, ".csv.gz")
svc.get_latest_cost_export_for_path(report_path, self.container_name)

def test_describe_cost_management_exports(self):
"""Test that cost management exports are returned for the account."""
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_get_latest_cost_export_http_error(self):
svc = self.get_mock_client(blob_list=[mock_blob])
svc._cloud_storage_account.get_container_client.side_effect = throw_azure_http_error
with self.assertRaises(AzureCostReportNotFound):
svc.get_latest_cost_export_for_path(report_path, self.container_name, ".csv.gz")
svc.get_latest_cost_export_for_path(report_path, self.container_name)

def test_get_latest_cost_export_http_error_403(self):
"""Test that the latest cost export catches the error for Azure HttpError 403."""
Expand All @@ -272,7 +272,7 @@ def test_get_latest_cost_export_http_error_403(self):
svc = self.get_mock_client(blob_list=[mock_blob])
svc._cloud_storage_account.get_container_client.side_effect = throw_azure_http_error_403
with self.assertRaises(AzureCostReportNotFound):
svc.get_latest_cost_export_for_path(report_path, self.container_name, ".csv.gz")
svc.get_latest_cost_export_for_path(report_path, self.container_name)

def test_get_latest_cost_export_no_container(self):
"""Test that the latest cost export catches the error for no container."""
Expand All @@ -285,7 +285,7 @@ def test_get_latest_cost_export_no_container(self):

svc = self.get_mock_client(blob_list=[mock_blob])
with self.assertRaises(AzureCostReportNotFound):
svc.get_latest_cost_export_for_path(report_path, container_name, ".csv.gz")
svc.get_latest_cost_export_for_path(report_path, container_name)

def test_get_latest_manifest_for_path(self):
"""Given a list of blobs with multiple manifests, ensure the latest one is returned"""
Expand Down Expand Up @@ -422,9 +422,7 @@ def test_get_latest_cost_export_for_path_exception(self, mock_factory):
service = AzureService(
self.tenant_id, self.client_id, self.client_secret, self.resource_group_name, self.storage_account_name
)
service.get_latest_cost_export_for_path(
report_path=FAKE.word(), container_name=FAKE.word(), compression=".csv.gz"
)
service.get_latest_cost_export_for_path(report_path=FAKE.word(), container_name=FAKE.word())

def test_describe_cost_management_exports_with_scope_and_name(self):
"""Test that cost management exports using scope and name are returned for the account."""
Expand Down Expand Up @@ -465,7 +463,6 @@ def test_get_latest_blob(self):
"""
report_path = "/container/report/path"
blobs = (
FakeBlob(f"{report_path}/_manifest.json", datetime(2022, 12, 18)),
FakeBlob(f"{report_path}/file01.csv", datetime(2022, 12, 16)),
FakeBlob(f"{report_path}/file02.csv", datetime(2022, 12, 15)),
FakeBlob("some/other/path/file01.csv", datetime(2022, 12, 1)),
Expand All @@ -487,7 +484,7 @@ def test_get_latest_cost_export_missing_container(self):
svc = self.get_mock_client(blob_list=[mock_blob])
svc._cloud_storage_account.get_container_client.side_effect = ResourceNotFoundError("Oops!")
with self.assertRaises(AzureCostReportNotFound):
svc.get_latest_cost_export_for_path(report_path, self.container_name, ".csv.gz")
svc.get_latest_cost_export_for_path(report_path, self.container_name)

@patch("masu.external.downloader.azure.azure_service.AzureClientFactory")
def test_azure_service_missing_credentials(self, mock_factory):
Expand All @@ -507,27 +504,6 @@ def test_azure_service_missing_credentials(self, mock_factory):

self.assertIn("Azure Service credentials are not configured.", str(context.exception))

@patch("masu.external.downloader.azure.azure_service.AzureClientFactory")
@patch.object(AzureService, "_get_latest_blob_for_path")
def test_get_latest_cost_export_for_path_invalid_compression(self, mock_get_latest_blob, mock_factory):
"""Test when an invalid compression type is provided and ValueError is raised."""
mock_get_latest_blob.return_value = None
mock_factory.return_value.credentials = Mock()

service = AzureService(
tenant_id="fake_tenant_id",
client_id="fake_client_id",
client_secret="fake_client_secret",
resource_group_name="fake_resource_group",
storage_account_name="fake_storage_account",
subscription_id="fake_subscription_id",
)

with self.assertRaises(ValueError) as context:
service.get_latest_cost_export_for_path("fake_report_path", "fake_container_name", "invalid_compression")

self.assertIn("Invalid compression type", str(context.exception))

@patch("masu.external.downloader.azure.azure_service.AzureService._list_blobs")
@patch("masu.external.downloader.azure.azure_service.AzureClientFactory")
@patch("masu.external.downloader.azure.azure_service.NamedTemporaryFile")
Expand Down Expand Up @@ -627,24 +603,6 @@ def test_download_file_raises_exception(self, mock_tempfile, mock_client_factory

self.assertIn("Failed to download cost export", str(context.exception))

@patch("masu.external.downloader.azure.azure_service.AzureClientFactory")
def test_invalid_compression_type(self, mock_factory):
"""Test that an invalid compression type raises an exception."""

service = AzureService(
tenant_id="fake_tenant_id",
client_id="fake_client_id",
client_secret="fake_client_secret",
resource_group_name="fake_resource_group",
storage_account_name="fake_storage_account",
subscription_id="fake_subscription_id",
)

with self.assertRaises(ValueError) as context:
service.get_latest_cost_export_for_path("fake_report_path", "fake_container_name", "invalid_compression")

self.assertIn("Invalid compression type", str(context.exception))

@patch("masu.external.downloader.azure.azure_service.AzureService._list_blobs")
@patch("masu.external.downloader.azure.azure_service.AzureClientFactory")
@patch("masu.external.downloader.azure.azure_service.NamedTemporaryFile")
Expand Down

0 comments on commit 3870f77

Please sign in to comment.