Skip to content

Commit

Permalink
Add Relationship Validity property (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Nov 6, 2023
1 parent 1c8d648 commit 383551b
Show file tree
Hide file tree
Showing 7 changed files with 535 additions and 5 deletions.
4 changes: 3 additions & 1 deletion sdmetrics/reports/multi_table/_properties/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sdmetrics.reports.multi_table._properties.coverage import Coverage
from sdmetrics.reports.multi_table._properties.data_validity import DataValidity
from sdmetrics.reports.multi_table._properties.inter_table_trends import InterTableTrends
from sdmetrics.reports.multi_table._properties.relationship_validity import RelationshipValidity
from sdmetrics.reports.multi_table._properties.structure import Structure
from sdmetrics.reports.multi_table._properties.synthesis import Synthesis

Expand All @@ -21,5 +22,6 @@
'InterTableTrends',
'Synthesis',
'Structure',
'DataValidity'
'DataValidity',
'RelationshipValidity',
]
8 changes: 8 additions & 0 deletions sdmetrics/reports/multi_table/_properties/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def _get_num_iterations(self, metadata):
iterations += (len(parent_columns) * len(child_columns))
return iterations

@staticmethod
def _extract_tuple(data, relation):
parent_data = data[relation['parent_table_name']]
child_data = data[relation['child_table_name']]
return (
parent_data[relation['parent_primary_key']], child_data[relation['child_foreign_key']]
)

def _compute_average(self):
"""Average the scores for each column."""
is_dataframe = isinstance(self.details, pd.DataFrame)
Expand Down
10 changes: 6 additions & 4 deletions sdmetrics/reports/multi_table/_properties/cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
"""Get the average score of cardinality shape similarity in the given tables.
Args:
real_data (pandas.DataFrame):
The real data.
synthetic_data (pandas.DataFrame):
The synthetic data.
real_data (dict[str, pandas.DataFrame]):
The tables from the real dataset, passed as a dictionary of
table names and pandas.DataFrames.
synthetic_data (dict[str, pandas.DataFrame]):
The tables from the synthetic dataset, passed as a dictionary of
table names and pandas.DataFrames.
metadata (dict):
The metadata, which contains each column's data type as well as relationships.
progress_bar (tqdm.tqdm or None):
Expand Down
137 changes: 137 additions & 0 deletions sdmetrics/reports/multi_table/_properties/relationship_validity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import numpy as np
import pandas as pd
import plotly.express as px

from sdmetrics.column_pairs.statistical import CardinalityBoundaryAdherence, ReferentialIntegrity
from sdmetrics.reports.multi_table._properties.base import BaseMultiTableProperty
from sdmetrics.reports.utils import PlotConfig


class RelationshipValidity(BaseMultiTableProperty):
"""``Relationship Validity`` property.
This property measures the validity of the relationship
from the primary key and the foreign key perspective.
"""

_num_iteration_case = 'relationship'

def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=None):
"""Generate the _details dataframe for the relationship validity property.
Args:
real_data (dict[str, pandas.DataFrame]):
The tables from the real dataset, passed as a dictionary of
table names and pandas.DataFrames.
synthetic_data (dict[str, pandas.DataFrame]):
The tables from the synthetic dataset, passed as a dictionary of
table names and pandas.DataFrames.
metadata (dict):
The metadata, which contains each column's data type as well as relationships.
progress_bar (tqdm.tqdm or None):
The progress bar object. Defaults to ``None``.
Returns:
float:
The average score for the property for all the individual metric scores computed.
"""
child_tables, parent_tables = [], []
primary_key, foreign_key = [], []
metric_names, scores, error_messages = [], [], []
metrics = [ReferentialIntegrity, CardinalityBoundaryAdherence]
for relation in metadata.get('relationships', []):
real_columns = self._extract_tuple(real_data, relation)
synthetic_columns = self._extract_tuple(synthetic_data, relation)
for metric in metrics:
try:
relation_score = metric.compute(
real_columns,
synthetic_columns,
)
error_message = None
except Exception as e:
relation_score = np.nan
error_message = f'{type(e).__name__}: {e}'

child_tables.append(relation['child_table_name'])
parent_tables.append(relation['parent_table_name'])
primary_key.append(relation['parent_primary_key'])
foreign_key.append(relation['child_foreign_key'])
metric_names.append(metric.__name__)
scores.append(relation_score)
error_messages.append(error_message)

if progress_bar:
progress_bar.update()

self.details = pd.DataFrame({
'Parent Table': parent_tables,
'Child Table': child_tables,
'Primary Key': primary_key,
'Foreign Key': foreign_key,
'Metric': metric_names,
'Score': scores,
'Error': error_messages,
})

def _get_table_relationships_plot(self, table_name):
"""Get the table relationships plot from the parent child relationship scores for a table.
Args:
table_name (str):
Table name to get details table for.
Returns:
plotly.graph_objects._figure.Figure
"""
plot_data = self.get_details(table_name).copy()
column_name = 'Child → Parent Relationship'
plot_data[column_name] = (
plot_data['Child Table'] + ' (' + plot_data['Foreign Key'] + ') → ' +
plot_data['Parent Table']
)
plot_data = plot_data.drop(['Child Table', 'Parent Table'], axis=1)

average_score = round(plot_data['Score'].mean(), 2)

fig = px.bar(
plot_data,
x='Child → Parent Relationship',
y='Score',
title=f'Data Diagnostic: Relationship Validity (Average Score={average_score})',
color='Metric',
color_discrete_sequence=[PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN],
pattern_shape='Metric',
pattern_shape_sequence=['', '/'],
hover_name='Child → Parent Relationship',
hover_data={
'Child → Parent Relationship': False,
'Metric': True,
'Score': True,
},
barmode='group'
)

fig.update_yaxes(range=[0, 1])

fig.update_layout(
xaxis_categoryorder='total ascending',
plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
font={'size': PlotConfig.FONT_SIZE}
)

return fig

def get_visualization(self, table_name):
"""Return a visualization for each score in the property.
Args:
table_name (str):
Table name to get the visualization for.
Returns:
plotly.graph_objects._figure.Figure
The visualization for the property.
"""
return self._get_table_relationships_plot(table_name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import sys

from tqdm import tqdm

from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table._properties import RelationshipValidity


class TestRelationshipValidity:

def test_end_to_end(self):
"""Test the ``RelationshipValidity`` multi-table property end to end."""
# Setup
real_data, synthetic_data, metadata = load_demo(modality='multi_table')
relationship_validity = RelationshipValidity()

# Run
result = relationship_validity.get_score(real_data, synthetic_data, metadata)

# Assert
assert result == 1.0

def test_with_progress_bar(self, capsys):
"""Test that the progress bar is correctly updated."""
# Setup
real_data, synthetic_data, metadata = load_demo(modality='multi_table')
relationship_validity = RelationshipValidity()
num_relationship = 2

progress_bar = tqdm(total=num_relationship, file=sys.stdout)

# Run
result = relationship_validity.get_score(real_data, synthetic_data, metadata, progress_bar)
progress_bar.close()
captured = capsys.readouterr()
output = captured.out

# Assert
assert result == 1.0
assert '100%' in output
assert f'{num_relationship}/{num_relationship}' in output
30 changes: 30 additions & 0 deletions tests/unit/reports/multi_table/_properties/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,36 @@ def test__get_num_iterations(self):
base_property._num_iteration_case = 'inter_table_column_pair'
assert base_property._get_num_iterations(metadata) == 11

def test__extract_tuple(self):
"""Test the ``_extract_tuple`` method."""
# Setup
base_property = BaseMultiTableProperty()
real_user_df = pd.DataFrame({
'user_id': ['user1', 'user2'],
'columnA': ['A', 'B'],
'columnB': [np.nan, 1.0]
})
real_session_df = pd.DataFrame({
'session_id': ['session1', 'session2', 'session3'],
'user_id': ['user1', 'user1', 'user2'],
'columnC': ['X', 'Y', 'Z'],
'columnD': [4.0, 6.0, 7.0]
})

real_data = {'users': real_user_df, 'sessions': real_session_df}
relation = {
'parent_table_name': 'users',
'child_table_name': 'sessions',
'parent_primary_key': 'user_id',
'child_foreign_key': 'user_id'
}

# Run
real_columns = base_property._extract_tuple(real_data, relation)

# Assert
assert real_columns == (real_data['users']['user_id'], real_data['sessions']['user_id'])

def test__generate_details_property(self):
"""Test the ``_generate_details`` method."""
# Setup
Expand Down
Loading

0 comments on commit 383551b

Please sign in to comment.