Skip to content

Commit

Permalink
_validate for multi table report
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Nov 20, 2023
1 parent c716f20 commit 3bb2392
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
4 changes: 0 additions & 4 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ def _validate(self, real_data, synthetic_data, metadata):
"""
self._validate_data_format(real_data, synthetic_data)
if self.__class__.__name__ == 'DiagnosticReport':
table_name = list(metadata.get('tables', {}).keys())
if table_name:
self.table_names = table_name

return

self._validate_metadata_matches_data(real_data, synthetic_data, metadata)
Expand Down
15 changes: 14 additions & 1 deletion sdmetrics/reports/multi_table/base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,27 @@ def _validate_relationships(self, real_data, synthetic_data, metadata):

def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
"""Validate that the metadata matches the data."""
self.table_names = list(metadata['tables'].keys())
for table in self.table_names:
super()._validate_metadata_matches_data(
real_data[table], synthetic_data[table], metadata['tables'][table]
)

self._validate_relationships(real_data, synthetic_data, metadata)

def _validate(self, real_data, synthetic_data, metadata):
"""Validate the inputs.
Args:
real_data (pandas.DataFrame):
The real data.
synthetic_data (pandas.DataFrame):
The synthetic data.
metadata (dict):
The metadata of the table.
"""
self.table_names = list(metadata['tables'].keys())
super()._validate(real_data, synthetic_data, metadata)

def _check_table_names(self, table_name):
if table_name not in self.table_names:
raise ValueError(f"Unknown table ('{table_name}'). Must be one of {self.table_names}.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test__validate_metadata_matches_data(self, mock__validate_metadata_matches_d

report = BaseMultiTableReport()
report._validate_relationships = mock__validate_relationships
report.table_names = ['Table_1', 'Table_2']

# Run
report._validate_metadata_matches_data(real_data, synthetic_data, metadata)
Expand Down

0 comments on commit 3bb2392

Please sign in to comment.