From b3a864645274473baf0b2437c03a191f2792c934 Mon Sep 17 00:00:00 2001 From: Josep Maria Salvia Hornos Date: Sun, 12 Nov 2023 21:40:20 +0100 Subject: [PATCH 1/2] Upgrade pamegranate to v1 --- sdmetrics/single_table/bayesian_network.py | 30 +++++++++++++--------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index d2f03151..4b582f56 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -7,6 +7,8 @@ from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric +from sklearn.preprocessing import LabelEncoder +from pomegranate import bayesian_network LOGGER = logging.getLogger(__name__) @@ -16,12 +18,6 @@ class BNLikelihoodBase(SingleTableMetric): @classmethod def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): - try: - from pomegranate import BayesianNetwork - except ImportError: - raise ImportError( - 'Please install pomegranate with `pip install pomegranate` on a version of python ' - '< 3.11. This metric is not supported on python versions >= 3.11.') real_data, synthetic_data, metadata = cls._validate_inputs( real_data, synthetic_data, metadata) @@ -30,19 +26,29 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): if not fields: return np.full(len(real_data), np.nan) + + encoders = {field: LabelEncoder() for field in fields} + + real_data_encoded = real_data.copy() + synthetic_data_encoded = synthetic_data.copy() + + for field in fields: + real_data_encoded[field] = encoders[field].fit_transform(real_data_encoded[field]) + + for field in fields: + synthetic_data_encoded[field] = encoders[field].transform(synthetic_data_encoded[field]) LOGGER.debug('Fitting the BayesianNetwork to the real data') if structure: - if isinstance(structure, dict): - structure = BayesianNetwork.from_json(json.dumps(structure)).structure - - bn = BayesianNetwork.from_structure(real_data[fields].to_numpy(), structure) + bn = bayesian_network.BayesianNetwork(structure=structure, algorithm='chow-liu') else: - bn = BayesianNetwork.from_samples(real_data[fields].to_numpy(), algorithm='chow-liu') + bn = bayesian_network.BayesianNetwork(algorithm='chow-liu') + + bn.fit(real_data_encoded[fields].to_numpy()) LOGGER.debug('Evaluating likelihood of the synthetic data') probabilities = [] - for _, row in synthetic_data[fields].iterrows(): + for _, row in synthetic_data_encoded[fields].iterrows(): try: probabilities.append(bn.probability([row.to_numpy()])) except ValueError: From 7a84f499871ba1438cba4e2986bfb8d3c9448e3d Mon Sep 17 00:00:00 2001 From: Josep Maria Salvia Hornos Date: Tue, 14 Nov 2023 10:31:04 +0100 Subject: [PATCH 2/2] Fit transform twice for missing labels --- sdmetrics/single_table/bayesian_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index 4b582f56..cc2ed187 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -36,7 +36,7 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): real_data_encoded[field] = encoders[field].fit_transform(real_data_encoded[field]) for field in fields: - synthetic_data_encoded[field] = encoders[field].transform(synthetic_data_encoded[field]) + synthetic_data_encoded[field] = encoders[field].fit_transform(synthetic_data_encoded[field]) LOGGER.debug('Fitting the BayesianNetwork to the real data') if structure: