diff --git a/sdmetrics/single_table/detection/base.py b/sdmetrics/single_table/detection/base.py index da8e0eb8..ac4d9506 100644 --- a/sdmetrics/single_table/detection/base.py +++ b/sdmetrics/single_table/detection/base.py @@ -52,22 +52,17 @@ def _drop_non_compute_columns(real_data, synthetic_data, metadata): if metadata is not None: drop_columns = [] drop_columns.extend(get_alternate_keys(metadata)) - if 'columns' in metadata: - for column in metadata['columns']: - if ('primary_key' in metadata and - (column == metadata['primary_key'] or - column in metadata['primary_key'])): - drop_columns.append(column) - - for field in metadata['columns'][column]: - if field == 'sdtype': - sdtype = metadata['columns'][column][field] - if sdtype not in ['numerical', 'datetime', 'categorical']: - drop_columns.append(column) - - if field == 'pii': - if metadata['columns'][column][field]: - drop_columns.append(column) + for column in metadata.get('columns', []): + if ('primary_key' in metadata and + (column == metadata['primary_key'] or + column in metadata['primary_key'])): + drop_columns.append(column) + + column_info = metadata['columns'].get(column, {}) + sdtype = column_info.get('sdtype') + pii = column_info.get('pii') + if sdtype not in ['numerical', 'datetime', 'categorical'] or pii: + drop_columns.append(column) if drop_columns: transformed_real_data = real_data.drop(drop_columns, axis=1)