Skip to content

Commit

Permalink
Validate that the metadata is always a dict (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo authored Oct 20, 2023
1 parent bc92a5e commit 99cb1e4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
6 changes: 3 additions & 3 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def validate(self, real_data, synthetic_data, metadata):
metadata (dict):
The metadata of the table.
"""
if not isinstance(metadata, dict):
metadata = metadata.to_dict()

self._validate_metadata_matches_data(real_data, synthetic_data, metadata)

def _handle_results(self, verbose):
Expand Down Expand Up @@ -101,6 +98,9 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
verbose (bool):
Whether or not to print report summary and progress.
"""
if not isinstance(metadata, dict):
raise TypeError('The provided metadata is not a dictionary.')

self.validate(real_data, synthetic_data, metadata)
self.convert_datetimes(real_data, synthetic_data, metadata)

Expand Down
21 changes: 21 additions & 0 deletions tests/unit/reports/test_base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,27 @@ def test_convert_datetimes(self):
pd.testing.assert_frame_equal(real_data, expected_real_data)
pd.testing.assert_frame_equal(synthetic_data, expected_synthetic_data)

def test_generate_metadata_not_dict(self):
"""Test the ``generate`` method with metadata not being a dict."""
# Setup
base_report = BaseReport()
real_data = pd.DataFrame({
'column1': [1, 2, 3],
'column2': ['a', 'b', 'c']
})
synthetic_data = pd.DataFrame({
'column1': [1, 2, 3],
'column2': ['a', 'b', 'c']
})
metadata = 'metadata'

# Run and Assert
expected_message = (
'The provided metadata is not a dictionary.'
)
with pytest.raises(TypeError, match=expected_message):
base_report.generate(real_data, synthetic_data, metadata, verbose=False)

def test_generate(self):
"""Test the ``generate`` method.
Expand Down

0 comments on commit 99cb1e4

Please sign in to comment.