Update tests
fealho committed Mar 6, 2025
1 parent 029406b commit 74151e2
Showing 6 changed files with 36 additions and 708 deletions.
259 changes: 0 additions & 259 deletions sdv/cag/_utils.py
@@ -1,37 +1,5 @@
import re

import numpy as np
import pandas as pd

from sdv.cag._errors import PatternNotMetError

PRECISION_LEVELS = {
'%Y': 1, # Year
'%y': 1, # Year without century (same precision as %Y)
'%B': 2, # Full month name
'%b': 2, # Abbreviated month name
'%m': 2, # Month as a number
'%d': 3, # Day of the month
'%j': 3, # Day of the year
'%U': 3, # Week number (Sunday-starting)
'%W': 3, # Week number (Monday-starting)
'%A': 3, # Full weekday name
'%a': 3, # Abbreviated weekday name
'%w': 3, # Weekday as a decimal
'%H': 4, # Hour (24-hour clock)
'%I': 4, # Hour (12-hour clock)
'%M': 5, # Minute
'%S': 6, # Second
'%f': 7, # Microsecond
# Formats that don't add precision
'%p': 0, # AM/PM
'%z': 0, # UTC offset
'%Z': 0, # Time zone name
'%c': 0, # Locale-based date/time
'%x': 0, # Locale-based date
'%X': 0, # Locale-based time
}


def _validate_table_and_column_names(table_name, columns, metadata):
"""Validate the table and column names for the pattern."""
@@ -48,230 +16,3 @@ def _validate_table_and_column_names(table_name, columns, metadata):
missing_columns = columns - set(metadata.tables[table_name].columns)
missing_columns = "', '".join(sorted(missing_columns))
raise PatternNotMetError(f"Table '{table_name}' is missing columns '{missing_columns}'.")


def cast_to_datetime64(value, datetime_format=None):
"""Cast a given value to a ``numpy.datetime64`` format.
Args:
value (pandas.Series, np.ndarray, list, or str):
Input data to convert to ``numpy.datetime64``.
datetime_format (str):
Datetime format of the `value`.
Returns:
``numpy.datetime64`` value or values.
"""
if datetime_format:
datetime_format = datetime_format.replace('%-', '%')

if isinstance(value, str):
value = pd.to_datetime(value, format=datetime_format).to_datetime64()
elif isinstance(value, pd.Series):
value = value.astype('datetime64[ns]')
elif isinstance(value, (np.ndarray, list)):
value = np.array([
pd.to_datetime(item, format=datetime_format).to_datetime64()
if not pd.isna(item)
else pd.NaT.to_datetime64()
for item in value
])

return value


def get_nan_component_value(row):
"""Check for NaNs in a pandas row.
Outputs a concatenated string of the column names with NaNs.
Args:
row (pandas.Series):
A pandas row.
Returns:
A concatenated string of the column names with NaNs.
"""
columns_with_nans = []
for column, value in row.items():
if pd.isna(value):
columns_with_nans.append(column)

if columns_with_nans:
return ', '.join(columns_with_nans)
else:
return 'None'


def compute_nans_column(table_data, list_column_names):
"""Compute a categorical column to the table_data indicating where NaNs are.
Args:
table_data (pandas.DataFrame):
The table data.
list_column_names (list):
The list of column names to check for NaNs.
Returns:
pandas.Series or None:
A categorical column, named after the checked columns, indicating where the NaNs are.
None if there are no NaNs.
"""
nan_column_name = '#'.join(list_column_names) + '.nan_component'
column = table_data[list_column_names].apply(get_nan_component_value, axis=1)
if not (column == 'None').all():
return pd.Series(column, name=nan_column_name)

return None


def revert_nans_columns(table_data, nan_column_name):
"""Reverts the NaNs in the table_data based on the categorical column.
Args:
table_data (pandas.DataFrame):
The table data.
nan_column_name (str):
The name of the categorical column indicating where the NaNs are.
Returns:
pandas.DataFrame:
The table data with the NaNs restored and the nan component column dropped.
"""
combinations = table_data[nan_column_name].unique()
for combination in combinations:
if combination != 'None':
column_names = [column_name.strip() for column_name in combination.split(',')]
table_data.loc[table_data[nan_column_name] == combination, column_names] = np.nan

return table_data.drop(columns=nan_column_name)


def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=None, dtype='O'):
"""Calculate the difference between two datetime columns.
When casting datetimes to float using ``astype``, NaT values are not automatically
converted to NaN values. This method calculates the difference between the high
and low column values, preserving missing values as NaNs.
Args:
high (numpy.ndarray):
The high column values.
low (numpy.ndarray):
The low column values.
high_datetime_format (str):
Datetime format of the `high` column.
low_datetime_format (str):
Datetime format of the `low` column.
dtype (str):
The dtype of the `high` and `low` values. Defaults to 'O' (object).
Returns:
numpy.ndarray:
The difference between the high and low column values.
"""
if dtype == 'O':
low = cast_to_datetime64(low, low_datetime_format)
high = cast_to_datetime64(high, high_datetime_format)

if low_datetime_format != high_datetime_format:
low, high = match_datetime_precision(
low=low,
high=high,
low_datetime_format=low_datetime_format,
high_datetime_format=high_datetime_format,
)

diff_column = high - low
nan_mask = pd.isna(diff_column)
diff_column = diff_column.astype(np.float64)
diff_column[nan_mask] = np.nan
return diff_column


def format_datetime_array(datetime_array, target_format):
"""Format each element in a numpy datetime64 array to a specified string format.
Args:
datetime_array (np.ndarray):
Array of datetime64[ns] elements.
target_format (str):
The datetime format to cast each element to.
Returns:
np.ndarray: Array of formatted datetime strings.
"""
return np.array([
pd.to_datetime(date).strftime(target_format) if not pd.isna(date) else pd.NaT
for date in datetime_array
])


def downcast_datetime_to_lower_precision(data, target_format):
"""Convert a datetime string from a higher-precision format to a lower-precision format.
Args:
data (np.array):
The data to cast to the `target_format`.
target_format (str):
The lower-precision datetime format to cast the data to.
Returns:
np.ndarray: The data cast to ``numpy.datetime64`` in the lower precision format.
"""
downcasted_data = format_datetime_array(data, target_format)
return cast_to_datetime64(downcasted_data, target_format)


def get_datetime_format_precision(format_str):
"""Return the precision level of a datetime format string."""
# Find all format codes in the format string
found_formats = re.findall(r'%[A-Za-z]', format_str)
found_levels = (
PRECISION_LEVELS.get(found_format)
for found_format in found_formats
if found_format in PRECISION_LEVELS
)

return max(found_levels, default=0)


def get_lower_precision_format(primary_format, secondary_format):
"""Compare two datetime format strings and return the one with lower precision.
Args:
primary_format (str):
The first datetime format string to compare.
secondary_format (str):
The second datetime format string to compare.
Returns:
str:
The datetime format string with the lower precision level.
"""
primary_level = get_datetime_format_precision(primary_format)
secondary_level = get_datetime_format_precision(secondary_format)
if primary_level >= secondary_level:
return secondary_format

return primary_format


def match_datetime_precision(low, high, low_datetime_format, high_datetime_format):
"""Match `low` or `high` datetime array to the lower precision format.
Args:
low (np.ndarray):
Array of datetime values for the low column.
high (np.ndarray):
Array of datetime values for the high column.
low_datetime_format (str):
The datetime format of the `low` column.
high_datetime_format (str):
The datetime format of the `high` column.
Returns:
Tuple[np.ndarray, np.ndarray]:
Adjusted `low` and `high` arrays where the higher precision format is
downcasted to the lower precision format.
"""
lower_precision_format = get_lower_precision_format(low_datetime_format, high_datetime_format)
if lower_precision_format == high_datetime_format:
low = downcast_datetime_to_lower_precision(low, lower_precision_format)
else:
high = downcast_datetime_to_lower_precision(high, lower_precision_format)

return low, high
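
The helpers above are deleted from sdv/cag/_utils.py; per the inequality.py diff below, same-named helpers are imported from sdv.constraints.utils instead. As a rough sketch of how the datetime-precision pieces fit together (illustrative only, not part of this commit, and assuming the relocated get_datetime_diff keeps the signature shown above):

# Illustrative usage sketch, not part of the diff.
import numpy as np

from sdv.constraints.utils import get_datetime_diff

low = np.array(['2025-03-06'], dtype='O')          # day precision, '%Y-%m-%d'
high = np.array(['2025-03-06 12:30'], dtype='O')   # minute precision, '%Y-%m-%d %H:%M'

# The formats differ, so the higher-precision column (`high`) is downcast to
# day precision via match_datetime_precision before subtracting; the result is
# the difference in nanoseconds as float, with NaT rows preserved as NaN.
diff = get_datetime_diff(
    high,
    low,
    high_datetime_format='%Y-%m-%d %H:%M',
    low_datetime_format='%Y-%m-%d',
)
print(diff)  # [0.] -- the minute component is dropped before the subtraction
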
15 changes: 11 additions & 4 deletions sdv/cag/inequality.py
@@ -5,15 +5,15 @@

from sdv._utils import _convert_to_timedelta, _create_unique_name
from sdv.cag._errors import PatternNotMetError
-from sdv.cag._utils import (
-    _validate_table_and_column_names,
+from sdv.cag._utils import _validate_table_and_column_names
+from sdv.cag.base import BasePattern
+from sdv.constraints.utils import (
    cast_to_datetime64,
    compute_nans_column,
    get_datetime_diff,
    match_datetime_precision,
    revert_nans_columns,
)
-from sdv.cag.base import BasePattern
from sdv.metadata import Metadata


@@ -173,7 +173,14 @@ def _get_updated_metadata(self, metadata):
return Metadata.load_from_dict(metadata)

def _fit(self, data, metadata):
"""Fit the pattern."""
"""Fit the pattern.
Args:
data (dict[str, pd.DataFrame]):
Table data.
metadata (Metadata):
Metadata.
"""
table_name = self._get_single_table_name(metadata)
table_data = data[table_name]
self._dtype = table_data[self._high_column_name].dtypes
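
The inequality.py changes above move the datetime and NaN helpers over to sdv.constraints.utils (only _validate_table_and_column_names still comes from sdv.cag._utils) and expand the _fit docstring. A quick smoke check of the relocated imports (illustrative only; names taken from the import block above, and the cast_to_datetime64 signature is assumed to match the removed module):

# Illustrative only, not part of the diff.
from sdv.constraints.utils import (
    cast_to_datetime64,
    compute_nans_column,
    get_datetime_diff,
    match_datetime_precision,
    revert_nans_columns,
)

# cast_to_datetime64 turns a string plus its format into numpy.datetime64
# (assuming the (value, datetime_format) signature shown above still applies).
print(cast_to_datetime64('2025-03-06', datetime_format='%Y-%m-%d'))
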
3 changes: 1 addition & 2 deletions sdv/constraints/utils.py
@@ -159,8 +159,7 @@ def get_nan_component_value(row):

if columns_with_nans:
return ', '.join(columns_with_nans)
-    else:
-        return 'None'
+    return 'None'


def compute_nans_column(table_data, list_column_names):
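
The get_nan_component_value tweak above is the producer side of the NaN-component mechanism used by compute_nans_column and revert_nans_columns (both imported from sdv.constraints.utils in inequality.py). A small round-trip sketch, assuming those helpers keep the signatures shown in the removed sdv/cag/_utils.py code:

# Illustrative round trip, not part of the diff.
import numpy as np
import pandas as pd

from sdv.constraints.utils import compute_nans_column, revert_nans_columns

data = pd.DataFrame({'low': [1.0, np.nan], 'high': [2.0, 3.0]})

# The helper column records which of the checked columns were NaN in each row.
nan_component = compute_nans_column(data, ['low', 'high'])
data[nan_component.name] = nan_component
print(data[nan_component.name].tolist())  # ['None', 'low']

# After transforming (here just filling), the recorded NaNs are restored and
# the helper column is dropped again.
restored = revert_nans_columns(data.fillna(0), nan_component.name)
print(restored['low'].tolist())  # [1.0, nan]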