From 94e830214653132ded4685f2f4bc96ec63c21812 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 17 Oct 2023 10:47:52 -0400 Subject: [PATCH] save basic report info --- sdmetrics/reports/base_report.py | 29 ++++ .../multi_table/test_quality_report.py | 28 ++++ .../single_table/test_quality_report.py | 16 ++ tests/unit/reports/test_base_report.py | 157 +++++++++++++++++- 4 files changed, 228 insertions(+), 2 deletions(-) diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py index 1f67d009..472cd739 100644 --- a/sdmetrics/reports/base_report.py +++ b/sdmetrics/reports/base_report.py @@ -1,7 +1,11 @@ """Single table base report.""" import pickle import sys +import time import warnings +from copy import deepcopy +from datetime import datetime +from importlib.metadata import version import numpy as np import pandas as pd @@ -22,6 +26,11 @@ def __init__(self): self.is_generated = False self._properties = {} self._results_handler = None + self.report_info = { + 'report_type': self.__class__.__name__, + 'generated_date': None, + 'sdmetrics_version': version('sdmetrics') + } def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata): """Validate that the metadata matches the data. @@ -104,11 +113,25 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): self.validate(real_data, synthetic_data, metadata) self.convert_datetimes(real_data, synthetic_data, metadata) + self.report_info['generated_date'] = datetime.today().strftime('%Y-%m-%d') + if 'tables' in metadata: + self.report_info['num_tables'] = len(metadata['tables']) + self.report_info['num_rows_real_data'] = { + name: len(table) for name, table in real_data.items() + } + self.report_info['num_rows_synthetic_data'] = { + name: len(table) for name, table in synthetic_data.items() + } + else: + self.report_info['num_rows_real_data'] = len(real_data) + self.report_info['num_rows_synthetic_data'] = len(synthetic_data) + scores = [] progress_bar = None if verbose: sys.stdout.write('Generating report ...\n') + start_time = time.time() for ind, (property_name, property_instance) in enumerate(self._properties.items()): if verbose: num_iterations = int(property_instance._get_num_iterations(metadata)) @@ -126,6 +149,8 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): self._overall_score = np.nanmean(scores) self.is_generated = True + end_time = time.time() + self.report_info['generation_time'] = end_time - start_time self._handle_results(verbose) @@ -143,6 +168,10 @@ def _check_property_name(self, property_name): f" Valid property names are '{valid_property_names}'." ) + def get_info(self): + """Get the information about the report.""" + return deepcopy(self.report_info) + def _check_report_generated(self): if not self.is_generated: raise ValueError('The report has not been generated. Please call `generate` first.') diff --git a/tests/integration/reports/multi_table/test_quality_report.py b/tests/integration/reports/multi_table/test_quality_report.py index 910a4ba5..888dbd06 100644 --- a/tests/integration/reports/multi_table/test_quality_report.py +++ b/tests/integration/reports/multi_table/test_quality_report.py @@ -1,3 +1,4 @@ +import time from datetime import date, datetime import numpy as np @@ -90,7 +91,9 @@ def test_multi_table_quality_report(): # Run `generate`, `get_properties` and `get_score`, # as well as `get_visualization` and `get_details` for every property: # 'Column Shapes', 'Column Pair Trends', 'Cardinality' + generate_start_time = time.time() report.generate(real_data, synthetic_data, metadata) + generate_end_time = time.time() properties = report.get_properties() property_names = list(properties['Property']) score = report.get_score() @@ -183,6 +186,21 @@ def test_multi_table_quality_report(): }) pd.testing.assert_frame_equal(details[5], expected_df_4) + # Assert report info saved + report_info = report.get_info() + assert report_info == report.report_info + + expected_info_keys = { + 'report_type', 'generated_date', 'sdmetrics_version', 'num_tables', 'num_rows_real_data', + 'num_rows_synthetic_data', 'generation_time' + } + assert report_info.keys() == expected_info_keys + assert report_info['report_type'] == 'QualityReport' + assert report_info['num_tables'] == 2 + assert report_info['num_rows_real_data'] == {'table1': 4, 'table2': 4} + assert report_info['num_rows_synthetic_data'] == {'table1': 4, 'table2': 4} + assert report_info['generation_time'] <= generate_end_time - generate_start_time + def test_quality_report_end_to_end(): """Test the multi table QualityReport end to end.""" @@ -194,6 +212,7 @@ def test_quality_report_end_to_end(): report.generate(real_data, synthetic_data, metadata) score = report.get_score() properties = report.get_properties() + info = report.get_info() # Assert expected_properties = pd.DataFrame({ @@ -202,6 +221,15 @@ def test_quality_report_end_to_end(): }) assert score == 0.6249089638729638 pd.testing.assert_frame_equal(properties, expected_properties) + expected_info_keys = { + 'report_type', 'generated_date', 'sdmetrics_version', 'num_tables', 'num_rows_real_data', + 'num_rows_synthetic_data', 'generation_time' + } + assert info.keys() == expected_info_keys + assert info['report_type'] == 'QualityReport' + assert info['num_tables'] == 3 + assert info['num_rows_real_data'] == {'sessions': 10, 'users': 10, 'transactions': 10} + assert info['num_rows_synthetic_data'] == {'sessions': 9, 'users': 10, 'transactions': 10} def test_quality_report_with_object_datetimes(): diff --git a/tests/integration/reports/single_table/test_quality_report.py b/tests/integration/reports/single_table/test_quality_report.py index b8e87915..ede5ac5f 100644 --- a/tests/integration/reports/single_table/test_quality_report.py +++ b/tests/integration/reports/single_table/test_quality_report.py @@ -1,6 +1,7 @@ import contextlib import io import re +import time from datetime import date, datetime import numpy as np @@ -80,7 +81,9 @@ def test_report_end_to_end(self): report = QualityReport() # Run + generate_start_time = time.time() report.generate(real_data[column_names], synthetic_data[column_names], metadata) + generate_end_time = time.time() # Assert expected_details_column_shapes_dict = { @@ -126,6 +129,19 @@ def test_report_end_to_end(self): ) assert report.get_score() == 0.7804181608907237 + report_info = report.get_info() + assert report_info == report.report_info + + expected_info_keys = { + 'report_type', 'generated_date', 'sdmetrics_version', 'num_rows_real_data', + 'num_rows_synthetic_data', 'generation_time' + } + assert report_info.keys() == expected_info_keys + assert report_info['report_type'] == 'QualityReport' + assert report_info['num_rows_real_data'] == 215 + assert report_info['num_rows_synthetic_data'] == 215 + assert report_info['generation_time'] <= generate_end_time - generate_start_time + def test_quality_report_with_object_datetimes(self): """Test the quality report with object datetimes. diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index 90275f07..8caa59e1 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -152,13 +152,21 @@ def test_convert_datetimes(self): pd.testing.assert_frame_equal(real_data, expected_real_data) pd.testing.assert_frame_equal(synthetic_data, expected_synthetic_data) - def test_generate(self): + @patch('sdmetrics.reports.base_report.datetime') + @patch('sdmetrics.reports.base_report.time') + @patch('sdmetrics.reports.base_report.version') + def test_generate(self, version_mock, time_mock, datetime_mock): """Test the ``generate`` method. This test checks that the method calls the ``validate`` method and the ``get_score`` - method for each property. + method for each property. Also tests that the ``details`` property is correctly + populated. """ # Setup + datetime_mock.today.return_value = pd.to_datetime('2020-01-05') + time_mock.time.side_effect = [5, 10] + version_mock.return_value = 'version' + base_report = BaseReport() mock_validate = Mock() mock_handle_results = Mock() @@ -196,6 +204,102 @@ def test_generate(self): base_report._properties['Property 2'].get_score.assert_called_with( real_data, synthetic_data, metadata, progress_bar=None ) + expected_info = { + 'report_type': 'BaseReport', + 'generated_date': '2020-01-05', + 'sdmetrics_version': 'version', + 'num_rows_real_data': 3, + 'num_rows_synthetic_data': 3, + 'generation_time': 5 + } + assert base_report.report_info == expected_info + + @patch('sdmetrics.reports.base_report.datetime') + @patch('sdmetrics.reports.base_report.time') + @patch('sdmetrics.reports.base_report.version') + def test_generate_multi_table_details(self, version_mock, time_mock, datetime_mock): + """Test the ``generate`` method with multi-table data. + + This test checks that the ``details`` property is correctly populated with + multi-table data. + """ + # Setup + datetime_mock.today.return_value = pd.to_datetime('2020-01-05') + time_mock.time.side_effect = [5, 10] + version_mock.return_value = 'version' + + base_report = BaseReport() + base_report._handle_results = Mock() + base_report.validate = Mock() + base_report.convert_datetimes = Mock() + base_report._properties['Property 1'] = Mock() + base_report._properties['Property 1'].get_score.return_value = 1.0 + base_report._properties['Property 2'] = Mock() + base_report._properties['Property 2'].get_score.return_value = 1.0 + + real_data = { + 'table1': pd.DataFrame({ + 'column1': [1, 2, 3], + 'column2': ['a', 'b', 'c'] + }), + 'table2': pd.DataFrame({ + 'column3': ['x', 'y', 'z'], + 'column4': [10, 9, 8] + }) + } + synthetic_data = { + 'table1': pd.DataFrame({ + 'column1': [1, 2, 3], + 'column2': ['a', 'b', 'c'] + }), + 'table2': pd.DataFrame({ + 'column3': ['x', 'y', 'z'], + 'column4': [10, 9, 8] + }) + } + metadata = { + 'tables': { + 'table1': { + 'columns': { + 'column1': {'sdtype': 'numerical'}, + 'column2': {'sdtype': 'categorical'} + } + }, + 'table2': { + 'columns': { + 'column3': {'sdtype': 'categorical'}, + 'column4': {'sdtype': 'numerical'} + } + } + } + } + + # Run + base_report.generate(real_data, synthetic_data, metadata, verbose=False) + + # Assert + base_report._properties['Property 1'].get_score.assert_called_with( + real_data, synthetic_data, metadata, progress_bar=None + ) + base_report._properties['Property 2'].get_score.assert_called_with( + real_data, synthetic_data, metadata, progress_bar=None + ) + expected_info = { + 'report_type': 'BaseReport', + 'generated_date': '2020-01-05', + 'sdmetrics_version': 'version', + 'num_tables': 2, + 'num_rows_real_data': { + 'table1': 3, + 'table2': 3 + }, + 'num_rows_synthetic_data': { + 'table1': 3, + 'table2': 3 + }, + 'generation_time': 5 + } + assert base_report.report_info == expected_info def test__handle_results(self): """Test the ``_handle_results`` method.""" @@ -316,6 +420,55 @@ def test_get_properties(self): }), ) + @patch('sdmetrics.reports.base_report.datetime') + @patch('sdmetrics.reports.base_report.time') + @patch('sdmetrics.reports.base_report.version') + def test_get_info(self, version_mock, time_mock, datetime_mock): + """Test the ``get_info`` method.""" + # Setup + datetime_mock.today.return_value = pd.to_datetime('2020-01-05') + time_mock.time.side_effect = [5, 10] + version_mock.return_value = 'version' + + base_report = BaseReport() + mock_validate = Mock() + mock_handle_results = Mock() + base_report._handle_results = mock_handle_results + base_report.validate = mock_validate + base_report._properties['Property 1'] = Mock() + base_report._properties['Property 1'].get_score.return_value = 1.0 + base_report._properties['Property 2'] = Mock() + base_report._properties['Property 2'].get_score.return_value = 1.0 + + real_data = pd.DataFrame({ + 'column1': [1, 2, 3], + 'column2': ['a', 'b', 'c'] + }) + synthetic_data = pd.DataFrame({ + 'column1': [1, 2, 3], + 'column2': ['a', 'b', 'c'] + }) + metadata = { + 'columns': { + 'column1': {'sdtype': 'numerical'}, + 'column2': {'sdtype': 'categorical'} + } + } + + # Run + base_report.generate(real_data, synthetic_data, metadata, verbose=False) + + # Assert + expected_info = { + 'report_type': 'BaseReport', + 'generated_date': '2020-01-05', + 'sdmetrics_version': 'version', + 'num_rows_real_data': 3, + 'num_rows_synthetic_data': 3, + 'generation_time': 5 + } + assert base_report.get_info() == expected_info + def test_get_visualization(self): """Test the ``get_visualization`` method.""" # Setup