Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Inequality CAG #2405

Open
wants to merge 9 commits into
base: feature/single-table-CAG
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdv/cag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

from sdv.cag.fixed_combinations import FixedCombinations
from sdv.cag.fixed_increments import FixedIncrements
from sdv.cag.inequality import Inequality

__all__ = ('FixedCombinations', 'FixedIncrements')
__all__ = ('FixedCombinations', 'FixedIncrements', 'Inequality')
7 changes: 4 additions & 3 deletions sdv/cag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ def validate(self, data=None, metadata=None):
metadata = self.metadata

self._validate_pattern_with_metadata(metadata)
if data is not None:

if isinstance(data, pd.DataFrame):
if self._single_table:
data = {self._table_name: data}

elif isinstance(data, pd.DataFrame):
else:
table_name = self._get_single_table_name(metadata)
data = {table_name: data}

if data is not None:
self._validate_pattern_with_data(data, metadata)

def _get_updated_metadata(self, metadata):
Expand Down
320 changes: 320 additions & 0 deletions sdv/cag/inequality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
"""Inequality CAG pattern."""

import numpy as np
import pandas as pd

from sdv._utils import _convert_to_timedelta, _create_unique_name
from sdv.cag._errors import PatternNotMetError
from sdv.cag._utils import _validate_table_and_column_names
from sdv.cag.base import BasePattern
from sdv.constraints.utils import (
cast_to_datetime64,
compute_nans_column,
get_datetime_diff,
match_datetime_precision,
revert_nans_columns,
)
from sdv.metadata import Metadata


class Inequality(BasePattern):
"""Pattern that ensures `high_column_name` is greater than `low_column_name` .

The transformation works by creating a column with the difference between the
`high_column_name` and `low_column_name` columns and storing it in the
`high_column_name`'s place. The reverse transform adds the difference column
to the `low_column_name` to reconstruct the `high_column_name`.

Args:
low_column_name (str):
Name of the column that contains the low values.
high_column_name (str):
Name of the column that contains the high values.
strict_boundaries (bool):
Whether the comparison of the values should be strict ``>=`` or
not ``>``. Defaults to False.
table_name (str, optional):
The name of the table that contains the columns. Optional if the
data is only a single table. Defaults to None.
"""

@staticmethod
def _validate_init_inputs(low_column_name, high_column_name, strict_boundaries, table_name):
if not (isinstance(low_column_name, str) and isinstance(high_column_name, str)):
raise ValueError('`low_column_name` and `high_column_name` must be strings.')

if not isinstance(strict_boundaries, bool):
raise ValueError('`strict_boundaries` must be a boolean.')

if table_name and not isinstance(table_name, str):
raise ValueError('`table_name` must be a string or None.')

def __init__(self, low_column_name, high_column_name, strict_boundaries=False, table_name=None):
super().__init__()
self._validate_init_inputs(low_column_name, high_column_name, strict_boundaries, table_name)
self._low_column_name = low_column_name
self._high_column_name = high_column_name
self._diff_column_name = f'{self._low_column_name}#{self._high_column_name}'
self._operator = np.greater if strict_boundaries else np.greater_equal
self.table_name = table_name

# Set during fit
self._is_datetime = None
self._dtype = None
self._low_datetime_format = None
self._high_datetime_format = None

# Set during transform
self._nan_column_name = None

def _validate_pattern_with_metadata(self, metadata):
"""Validate the pattern is compatible with the provided metadata.

Validates that:
- If no table_name is provided, the metadata contains a single table
- Both the low_column_name and high_column_name columns exist in the table in the metadata
- Both the low_column_name and high_column_name columns have the same sdtype,
and that it is either numerical or datetime

Args:
metadata (Metadata):
The metadata to validate against.

Raises:
ValueError:
If any of the validations fail.
"""
columns = [self._low_column_name, self._high_column_name]
_validate_table_and_column_names(self.table_name, columns, metadata)
table_name = self._get_single_table_name(metadata)
for column in columns:
col_sdtype = metadata.tables[table_name].columns[column]['sdtype']
if col_sdtype not in ['numerical', 'datetime']:
raise PatternNotMetError(
f"Column '{column}' has an incompatible sdtype '{col_sdtype}'. The column "
"sdtype must be either 'numerical' or 'datetime'."
)

low_column_sdtype = metadata.tables[table_name].columns[self._low_column_name]['sdtype']
high_column_sdtype = metadata.tables[table_name].columns[self._high_column_name]['sdtype']
if low_column_sdtype != high_column_sdtype:
raise PatternNotMetError(
f"Columns '{self._low_column_name}' and '{self._high_column_name}' must have the "
f"same sdtype. Found '{low_column_sdtype}' and '{high_column_sdtype}'."
)

def _get_data(self, data):
low = data[self._low_column_name].to_numpy()
high = data[self._high_column_name].to_numpy()
return low, high

def _get_is_datetime(self, metadata, table_name):
return metadata.tables[table_name].columns[self._low_column_name]['sdtype'] == 'datetime'

def _get_datetime_format(self, metadata, table_name, column_name):
return metadata.tables[table_name].columns[column_name].get('datetime_format')

def _validate_pattern_with_data(self, data, metadata):
"""Validate the data is compatible with the pattern.

Validate that the inequality requirement is met between the high and low columns.
"""
table_name = self._get_single_table_name(metadata)
data = data[table_name]
low, high = self._get_data(data)
is_datetime = self._get_is_datetime(metadata, table_name)
if is_datetime and data[self._high_column_name].dtypes == 'O':
low_format = self._get_datetime_format(metadata, table_name, self._low_column_name)
high_format = self._get_datetime_format(metadata, table_name, self._high_column_name)
low = cast_to_datetime64(low, low_format)
high = cast_to_datetime64(high, high_format)

format_matches = bool(low_format == high_format)
if not format_matches:
low, high = match_datetime_precision(
low=low,
high=high,
low_datetime_format=low_format,
high_datetime_format=high_format,
)

valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)

if not valid.all():
invalid_rows = np.where(~valid)[0]
if len(invalid_rows) <= 5:
invalid_rows_str = ', '.join(str(i) for i in invalid_rows)
else:
first_five = ', '.join(str(i) for i in invalid_rows[:5])
remaining = len(invalid_rows) - 5
invalid_rows_str = f'{first_five}, +{remaining} more'

raise PatternNotMetError(
f'The inequality requirement is not met for row indices: [{invalid_rows_str}]'
)

def _get_updated_metadata(self, metadata):
"""Get the new output metadata after applying the pattern to the input metadata."""
table_name = self._get_single_table_name(metadata)
diff_column = _create_unique_name(
self._diff_column_name, metadata.tables[table_name].columns.keys()
)

metadata = metadata.to_dict()
metadata['tables'][table_name]['columns'][diff_column] = {'sdtype': 'numerical'}
del metadata['tables'][table_name]['columns'][self._high_column_name]

metadata['tables'][table_name]['column_relationships'] = [
rel
for rel in metadata['tables'][table_name].get('column_relationships', [])
if self._high_column_name not in rel['column_names']
]

return Metadata.load_from_dict(metadata)

def _fit(self, data, metadata):
"""Fit the pattern.

Args:
data (dict[str, pd.DataFrame]):
Table data.
metadata (Metadata):
Metadata.
"""
table_name = self._get_single_table_name(metadata)
table_data = data[table_name]
self._dtype = table_data[self._high_column_name].dtypes
self._is_datetime = self._get_is_datetime(metadata, table_name)
if self._is_datetime:
self._low_datetime_format = self._get_datetime_format(
metadata, table_name, self._low_column_name
)
self._high_datetime_format = self._get_datetime_format(
metadata, table_name, self._high_column_name
)

def _transform(self, data):
"""Transform the data.

The transformation consists on replacing the `high_column_name` values with the
difference between it and the `low_column_name` values.

Afterwards, a logarithm is applied to the difference + 1 to ensure that the
value stays positive when reverted afterwards using an exponential.

Args:
data (dict[str, pd.DataFrame]):
Table data.

Returns:
dict[str, pd.DataFrame]:
Transformed data.
"""
table_name = self._get_single_table_name(self.metadata)
table_data = data[table_name]
low, high = self._get_data(table_data)
if self._is_datetime:
diff_column = get_datetime_diff(
high=high,
low=low,
high_datetime_format=self._high_datetime_format,
low_datetime_format=self._low_datetime_format,
)
else:
diff_column = high - low

self._diff_column_name = _create_unique_name(self._diff_column_name, table_data.columns)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we move this in fit()?

table_data[self._diff_column_name] = np.log(diff_column + 1)

nan_col = compute_nans_column(table_data, [self._low_column_name, self._high_column_name])
if nan_col is not None:
self._nan_column_name = _create_unique_name(nan_col.name, table_data.columns)
Comment on lines +230 to +231
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also move this in fit()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As was discussed here, this PR only converts the existing constraint to a CAG. I agree this is a good suggestion but I don't think it should be implemented in this PR.

table_data[self._nan_column_name] = nan_col
if self._is_datetime:
mean_value_low = table_data[self._low_column_name].mode()[0]
else:
mean_value_low = table_data[self._low_column_name].mean()
table_data = table_data.fillna({
self._low_column_name: mean_value_low,
self._diff_column_name: table_data[self._diff_column_name].mean(),
})

data[table_name] = table_data.drop(self._high_column_name, axis=1)

return data

def _reverse_transform(self, data):
"""Reverse transform the table data.

The transformation is reversed by computing an exponential of the difference value,
subtracting 1 and converting it to the original dtype. Finally, the obtained column
is added to the `low_column_name` column to get back the original `high_column_name`
value.

Args:
data (dict[str, pd.DataFrame]):
Table data.

Returns:
dict[str, pd.DataFrame]:
Transformed data.
"""
table_name = self._get_single_table_name(self.metadata)
table_data = data[table_name]
diff_column = np.exp(table_data[self._diff_column_name]) - 1
if self._dtype != np.dtype('float'):
diff_column = diff_column.round()

if self._is_datetime:
diff_column = _convert_to_timedelta(diff_column)

low = table_data[self._low_column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
low = cast_to_datetime64(low)

table_data[self._high_column_name] = pd.Series(diff_column + low).astype(self._dtype)

if self._nan_column_name and self._nan_column_name in table_data.columns:
table_data = revert_nans_columns(table_data, self._nan_column_name)

data[table_name] = table_data.drop(self._diff_column_name, axis=1)

return data

def _is_valid(self, data):
"""Check whether `high` is greater than `low` in each row.

Args:
data (dict[str, pd.DataFrame]):
Table data.

Returns:
dict[str, pd.Series]:
Whether each row is valid.
"""
table_name = self._get_single_table_name(self.metadata)
is_valid = {
table: pd.Series(True, index=table_data.index)
for table, table_data in data.items()
if table != table_name
}

table_data = data[table_name]
low, high = self._get_data(table_data)
if self._is_datetime and self._dtype == 'O':
low = cast_to_datetime64(low, self._low_datetime_format)
high = cast_to_datetime64(high, self._high_datetime_format)

format_matches = bool(self._low_datetime_format == self._high_datetime_format)
if not format_matches:
low, high = match_datetime_precision(
low=low,
high=high,
low_datetime_format=self._low_datetime_format,
high_datetime_format=self._high_datetime_format,
)

valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)
is_valid[table_name] = valid

return is_valid
4 changes: 2 additions & 2 deletions sdv/constraints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def get_nan_component_value(row):

if columns_with_nans:
return ', '.join(columns_with_nans)
else:
return 'None'

return 'None'


def compute_nans_column(table_data, list_column_names):
Expand Down
Loading
Loading