From 74151e22193240f4de0b66c04b9e54eb1e544623 Mon Sep 17 00:00:00 2001 From: Felipe Date: Thu, 6 Mar 2025 08:25:50 -0800 Subject: [PATCH] Update tests --- sdv/cag/_utils.py | 259 -------------- sdv/cag/inequality.py | 15 +- sdv/constraints/utils.py | 3 +- tests/integration/cag/test_inequality.py | 43 +-- tests/unit/cag/test__utils.py | 409 +---------------------- tests/unit/cag/test_inequality.py | 15 +- 6 files changed, 36 insertions(+), 708 deletions(-) diff --git a/sdv/cag/_utils.py b/sdv/cag/_utils.py index 7959fbcca..be4b709a2 100644 --- a/sdv/cag/_utils.py +++ b/sdv/cag/_utils.py @@ -1,37 +1,5 @@ -import re - -import numpy as np -import pandas as pd - from sdv.cag._errors import PatternNotMetError -PRECISION_LEVELS = { - '%Y': 1, # Year - '%y': 1, # Year without century (same precision as %Y) - '%B': 2, # Full month name - '%b': 2, # Abbreviated month name - '%m': 2, # Month as a number - '%d': 3, # Day of the month - '%j': 3, # Day of the year - '%U': 3, # Week number (Sunday-starting) - '%W': 3, # Week number (Monday-starting) - '%A': 3, # Full weekday name - '%a': 3, # Abbreviated weekday name - '%w': 3, # Weekday as a decimal - '%H': 4, # Hour (24-hour clock) - '%I': 4, # Hour (12-hour clock) - '%M': 5, # Minute - '%S': 6, # Second - '%f': 7, # Microsecond - # Formats that don't add precision - '%p': 0, # AM/PM - '%z': 0, # UTC offset - '%Z': 0, # Time zone name - '%c': 0, # Locale-based date/time - '%x': 0, # Locale-based date - '%X': 0, # Locale-based time -} - def _validate_table_and_column_names(table_name, columns, metadata): """Validate the table and column names for the pattern.""" @@ -48,230 +16,3 @@ def _validate_table_and_column_names(table_name, columns, metadata): missing_columns = columns - set(metadata.tables[table_name].columns) missing_columns = "', '".join(sorted(missing_columns)) raise PatternNotMetError(f"Table '{table_name}' is missing columns '{missing_columns}'.") - - -def cast_to_datetime64(value, datetime_format=None): - """Cast a given value to a ``numpy.datetime64`` format. - - Args: - value (pandas.Series, np.ndarray, list, or str): - Input data to convert to ``numpy.datetime64``. - datetime_format (str): - Datetime format of the `value`. - - Return: - ``numpy.datetime64`` value or values. - """ - if datetime_format: - datetime_format = datetime_format.replace('%-', '%') - - if isinstance(value, str): - value = pd.to_datetime(value, format=datetime_format).to_datetime64() - elif isinstance(value, pd.Series): - value = value.astype('datetime64[ns]') - elif isinstance(value, (np.ndarray, list)): - value = np.array([ - pd.to_datetime(item, format=datetime_format).to_datetime64() - if not pd.isna(item) - else pd.NaT.to_datetime64() - for item in value - ]) - - return value - - -def get_nan_component_value(row): - """Check for NaNs in a pandas row. - - Outputs a concatenated string of the column names with NaNs. - - Args: - row (pandas.Series): - A pandas row. - - Returns: - A concatenated string of the column names with NaNs. - """ - columns_with_nans = [] - for column, value in row.items(): - if pd.isna(value): - columns_with_nans.append(column) - - if columns_with_nans: - return ', '.join(columns_with_nans) - else: - return 'None' - - -def compute_nans_column(table_data, list_column_names): - """Compute a categorical column to the table_data indicating where NaNs are. - - Args: - table_data (pandas.DataFrame): - The table data. - list_column_names (list): - The list of column names to check for NaNs. - - Returns: - A dict with the column name as key and the column indicating where NaNs are as value. - Empty dict if there are no NaNs. - """ - nan_column_name = '#'.join(list_column_names) + '.nan_component' - column = table_data[list_column_names].apply(get_nan_component_value, axis=1) - if not (column == 'None').all(): - return pd.Series(column, name=nan_column_name) - - return None - - -def revert_nans_columns(table_data, nan_column_name): - """Reverts the NaNs in the table_data based on the categorical column. - - Args: - table_data (pandas.DataFrame): - The table data. - nan_column (pandas.Series): - The categorical columns indicating where the NaNs are. - """ - combinations = table_data[nan_column_name].unique() - for combination in combinations: - if combination != 'None': - column_names = [column_name.strip() for column_name in combination.split(',')] - table_data.loc[table_data[nan_column_name] == combination, column_names] = np.nan - - return table_data.drop(columns=nan_column_name) - - -def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=None, dtype='O'): - """Calculate the difference between two datetime columns. - - When casting datetimes to float using ``astype``, NaT values are not automatically - converted to NaN values. This method calculates the difference between the high - and low column values, preserving missing values as NaNs. - - Args: - high (numpy.ndarray): - The high column values. - low (numpy.ndarray): - The low column values. - high_datetime_format (str): - Datetime format of the `high` column. - low_datetime_format (str): - Datetime format of the `low` column. - - Returns: - numpy.ndarray: - The difference between the high and low column values. - """ - if dtype == 'O': - low = cast_to_datetime64(low, low_datetime_format) - high = cast_to_datetime64(high, high_datetime_format) - - if low_datetime_format != high_datetime_format: - low, high = match_datetime_precision( - low=low, - high=high, - low_datetime_format=low_datetime_format, - high_datetime_format=high_datetime_format, - ) - - diff_column = high - low - nan_mask = pd.isna(diff_column) - diff_column = diff_column.astype(np.float64) - diff_column[nan_mask] = np.nan - return diff_column - - -def format_datetime_array(datetime_array, target_format): - """Format each element in a numpy datetime64 array to a specified string format. - - Args: - datetime_array (np.ndarray): - Array of datetime64[ns] elements. - target_format (str): - The datetime format to cast each element to. - - Returns: - np.ndarray: Array of formatted datetime strings. - """ - return np.array([ - pd.to_datetime(date).strftime(target_format) if not pd.isna(date) else pd.NaT - for date in datetime_array - ]) - - -def downcast_datetime_to_lower_precision(data, target_format): - """Convert a datetime string from a higher-precision format to a lower-precision format. - - Args: - data (np.array): - The data to cast to the `target_format`. - target_format (str): - The datetime string to downcast. - - Returns: - str: The datetime string in the lower precision format. - """ - downcasted_data = format_datetime_array(data, target_format) - return cast_to_datetime64(downcasted_data, target_format) - - -def get_datetime_format_precision(format_str): - """Return the precision level of a datetime format string.""" - # Find all format codes in the format string - found_formats = re.findall(r'%[A-Za-z]', format_str) - found_levels = ( - PRECISION_LEVELS.get(found_format) - for found_format in found_formats - if found_format in PRECISION_LEVELS - ) - - return max(found_levels, default=0) - - -def get_lower_precision_format(primary_format, secondary_format): - """Compare two datetime format strings and return the one with lower precision. - - Args: - primary_format (str): - The first datetime format string to compare. - low_precision_format (str): - The second datetime format string to compare. - - Returns: - str: - The datetime format string with the lower precision level. - """ - primary_level = get_datetime_format_precision(primary_format) - secondary_level = get_datetime_format_precision(secondary_format) - if primary_level >= secondary_level: - return secondary_format - - return primary_format - - -def match_datetime_precision(low, high, low_datetime_format, high_datetime_format): - """Match `low` or `high` datetime array to the lower precision format. - - Args: - low (np.ndarray): - Array of datetime values for the low column. - high (np.ndarray): - Array of datetime values for the high column. - low_datetime_format (str): - The datetime format of the `low` column. - high_datetime_format (str): - The datetime format of the `high` column. - - Returns: - Tuple[np.ndarray, np.ndarray]: - Adjusted `low` and `high` arrays where the higher precision format is - downcasted to the lower precision format. - """ - lower_precision_format = get_lower_precision_format(low_datetime_format, high_datetime_format) - if lower_precision_format == high_datetime_format: - low = downcast_datetime_to_lower_precision(low, lower_precision_format) - else: - high = downcast_datetime_to_lower_precision(high, lower_precision_format) - - return low, high diff --git a/sdv/cag/inequality.py b/sdv/cag/inequality.py index b086e08cb..be29befc6 100644 --- a/sdv/cag/inequality.py +++ b/sdv/cag/inequality.py @@ -5,15 +5,15 @@ from sdv._utils import _convert_to_timedelta, _create_unique_name from sdv.cag._errors import PatternNotMetError -from sdv.cag._utils import ( - _validate_table_and_column_names, +from sdv.cag._utils import _validate_table_and_column_names +from sdv.cag.base import BasePattern +from sdv.constraints.utils import ( cast_to_datetime64, compute_nans_column, get_datetime_diff, match_datetime_precision, revert_nans_columns, ) -from sdv.cag.base import BasePattern from sdv.metadata import Metadata @@ -173,7 +173,14 @@ def _get_updated_metadata(self, metadata): return Metadata.load_from_dict(metadata) def _fit(self, data, metadata): - """Fit the pattern.""" + """Fit the pattern. + + Args: + data (dict[str, pd.DataFrame]): + Table data. + metadata (Metadata): + Metadata. + """ table_name = self._get_single_table_name(metadata) table_data = data[table_name] self._dtype = table_data[self._high_column_name].dtypes diff --git a/sdv/constraints/utils.py b/sdv/constraints/utils.py index f7275a49d..95598adcb 100644 --- a/sdv/constraints/utils.py +++ b/sdv/constraints/utils.py @@ -159,8 +159,7 @@ def get_nan_component_value(row): if columns_with_nans: return ', '.join(columns_with_nans) - else: - return 'None' + return 'None' def compute_nans_column(table_data, list_column_names): diff --git a/tests/integration/cag/test_inequality.py b/tests/integration/cag/test_inequality.py index 4e65211a5..8e4101f11 100644 --- a/tests/integration/cag/test_inequality.py +++ b/tests/integration/cag/test_inequality.py @@ -5,9 +5,21 @@ from sdv.cag import Inequality from sdv.cag._errors import PatternNotMetError from sdv.metadata import Metadata +from sdv.multi_table.hma import HMASynthesizer from sdv.single_table import GaussianCopulaSynthesizer +def run_pattern(pattern, data, metadata): + """Run a pattern.""" + pattern.validate(data, metadata) + updated_metadata = pattern.get_updated_metadata(metadata) + pattern.fit(data, metadata) + transformed = pattern.transform(data) + reverse_transformed = pattern.reverse_transform(transformed) + + return updated_metadata, transformed, reverse_transformed + + def test_inequality_pattern_integers(): """Test that Inequality pattern works with integer columns.""" # Setup @@ -27,11 +39,7 @@ def test_inequality_pattern_integers(): ) # Run - pattern.validate(data, metadata) - updated_metadata = pattern.get_updated_metadata(metadata) - pattern.fit(data, metadata) - transformed = pattern.transform(data) - reverse_transformed = pattern.reverse_transform(transformed) + updated_metadata, transformed, reverse_transformed = run_pattern(pattern, data, metadata) # Assert expected_updated_metadata = Metadata.load_from_dict({ @@ -65,11 +73,7 @@ def test_inequality_pattern_with_nans(): ) # Run - pattern.validate(data, metadata) - updated_metadata = pattern.get_updated_metadata(metadata) - pattern.fit(data, metadata) - transformed = pattern.transform(data) - reverse_transformed = pattern.reverse_transform(transformed) + updated_metadata, transformed, reverse_transformed = run_pattern(pattern, data, metadata) # Assert expected_updated_metadata = Metadata.load_from_dict({ @@ -121,11 +125,7 @@ def test_inequality_pattern_datetime(): ) # Run - pattern.validate(data, metadata) - updated_metadata = pattern.get_updated_metadata(metadata) - pattern.fit(data, metadata) - transformed = pattern.transform(data) - reverse_transformed = pattern.reverse_transform(transformed) + updated_metadata, transformed, reverse_transformed = run_pattern(pattern, data, metadata) # Assert expected_updated_metadata = Metadata.load_from_dict({ @@ -176,11 +176,7 @@ def test_inequality_pattern_datetime_nans(): ) # Run - pattern.validate(data, metadata) - updated_metadata = pattern.get_updated_metadata(metadata) - pattern.fit(data, metadata) - transformed = pattern.transform(data) - reverse_transformed = pattern.reverse_transform(transformed) + updated_metadata, transformed, reverse_transformed = run_pattern(pattern, data, metadata) # Assert expected_updated_metadata = Metadata.load_from_dict({ @@ -233,11 +229,7 @@ def test_inequality_pattern_with_multi_table(): ) # Run - pattern.validate(data, metadata) - updated_metadata = pattern.get_updated_metadata(metadata) - pattern.fit(data, metadata) - transformed = pattern.transform(data) - reverse_transformed = pattern.reverse_transform(transformed) + updated_metadata, transformed, reverse_transformed = run_pattern(pattern, data, metadata) # Assert expected_updated_metadata = Metadata.load_from_dict({ @@ -300,7 +292,6 @@ def test_inequality_with_timestamp_and_date(): } }) synthesizer = GaussianCopulaSynthesizer(metadata) - pattern = Inequality( low_column_name='SUBMISSION_TIMESTAMP', high_column_name='DUE_DATE', diff --git a/tests/unit/cag/test__utils.py b/tests/unit/cag/test__utils.py index 25915e669..532f8d97b 100644 --- a/tests/unit/cag/test__utils.py +++ b/tests/unit/cag/test__utils.py @@ -1,25 +1,13 @@ """CAG _utils unit tests.""" import re -from unittest.mock import Mock, patch +from unittest.mock import Mock -import numpy as np -import pandas as pd import pytest from sdv.cag._errors import PatternNotMetError from sdv.cag._utils import ( _validate_table_and_column_names, - cast_to_datetime64, - compute_nans_column, - downcast_datetime_to_lower_precision, - format_datetime_array, - get_datetime_diff, - get_datetime_format_precision, - get_lower_precision_format, - get_nan_component_value, - match_datetime_precision, - revert_nans_columns, ) @@ -67,398 +55,3 @@ def test__validate_table_and_column_names_single_table(): # Assert metadata._get_single_table_name.assert_called_once() - - -def test_cast_to_datetime64(): - """Test the ``cast_to_datetime64`` function. - - Setup: - - String value representing a datetime - - List value with a ``np.nan`` and string values. - - pd.Series with datetime values. - Output: - - A single np.datetime64 - - A list of np.datetime64 - - A series of np.datetime64 - """ - # Setup - string_value = '2021-02-02' - list_value = [None, np.nan, '2021-02-02'] - series_value = pd.Series(['2021-02-02', None, pd.NaT]) - - # Run - string_out = cast_to_datetime64(string_value) - list_out = cast_to_datetime64(list_value) - series_out = cast_to_datetime64(series_value) - - # Assert - expected_string_output = np.datetime64('2021-02-02') - expected_series_output = pd.Series([ - np.datetime64('2021-02-02'), - np.datetime64('NaT'), - np.datetime64('NaT'), - ]) - expected_list_output = np.array( - [np.datetime64('NaT'), np.datetime64('NaT'), '2021-02-02'], dtype='datetime64[ns]' - ) - np.testing.assert_array_equal(expected_list_output, list_out) - pd.testing.assert_series_equal(expected_series_output, series_out) - assert expected_string_output == string_out - - -def test_cast_to_datetime64_datetime_format(): - """Test it when `datetime_format` is passed.""" - # Setup - string_value = '2021-02-02' - list_value = [None, np.nan, '2021-02-02'] - series_value = pd.Series(['2021-02-02', None, pd.NaT]) - - # Run - string_out = cast_to_datetime64(string_value, datetime_format='%Y-%m-%d') - list_out = cast_to_datetime64(list_value, datetime_format='%Y-%m-%d') - series_out = cast_to_datetime64(series_value, datetime_format='%Y-%m-%d') - - # Assert - expected_string_output = np.datetime64('2021-02-02') - expected_series_output = pd.Series([ - np.datetime64('2021-02-02'), - np.datetime64('NaT'), - np.datetime64('NaT'), - ]) - expected_list_output = np.array( - [np.datetime64('NaT'), np.datetime64('NaT'), '2021-02-02'], dtype='datetime64[ns]' - ) - np.testing.assert_array_equal(expected_list_output, list_out) - pd.testing.assert_series_equal(expected_series_output, series_out) - assert expected_string_output == string_out - - -def test_get_nan_component_value(): - """Test the ``get_nan_component_value`` method.""" - # Setup - row = pd.Series([np.nan, 2, np.nan, 4], index=['a', 'b', 'c', 'd']) - - # Run - result = get_nan_component_value(row) - - # Assert - assert result == 'a, c' - - -def test_compute_nans_columns(): - """Test the ``compute_nans_columns`` method.""" - # Setup - data = pd.DataFrame({ - 'a': [1, np.nan, 3, np.nan], - 'b': [np.nan, 2, 3, np.nan], - 'c': [1, np.nan, 3, np.nan], - }) - - # Run - output = compute_nans_column(data, ['a', 'b', 'c']) - expected_output = pd.Series(['b', 'a, c', 'None', 'a, b, c'], name='a#b#c.nan_component') - - # Assert - pd.testing.assert_series_equal(output, expected_output) - - -def test_compute_nans_columns_without_nan(): - """Test the ``compute_nans_columns`` method when there are no nans.""" - # Setup - data = pd.DataFrame({'a': [1, 2, 3, 2], 'b': [2.5, 2, 3, 2.5], 'c': [1, 2, 3, 2]}) - - # Run - output = compute_nans_column(data, ['a', 'b', 'c']) - - # Assert - assert output is None - - -def test_revert_nans_columns(): - """Test the ``revert_nans_columns`` method.""" - # Setup - data = pd.DataFrame({ - 'a': [1, 2, 3, 2], - 'b': [2.5, 2, 3, 2.5], - 'c': [1, 2, 3, 2], - 'a#b#c.nan_component': ['b', 'a, c', 'None', 'a, b, c'], - }) - nan_column_name = 'a#b#c.nan_component' - - # Run - result = revert_nans_columns(data, nan_column_name) - - expected_data = pd.DataFrame({ - 'a': [1, np.nan, 3, np.nan], - 'b': [np.nan, 2, 3, np.nan], - 'c': [1, np.nan, 3, np.nan], - }) - - # Assert - pd.testing.assert_frame_equal(result, expected_data) - - -def test_get_datetime_diff(): - """Test the ``_get_datetime_diff`` method. - - The method is expected to compute the difference between the high and low - datetime columns, treating missing values as NaN. - """ - # Setup - high = pd.Series(['2022-02-02', '', '2023-01-02']).to_numpy() - low = pd.Series(['2022-02-01', '2022-02-02', '2023-01-01']).to_numpy() - expected = np.array([8.64e13, np.nan, 8.64e13]) - - # Run - diff = get_datetime_diff(high, low, dtype='O') - - # Assert - assert np.array_equal(expected, diff, equal_nan=True) - - -def test_get_datetime_diff_with_format_precision_missmatch(): - """Test `get_datetime_diff` with miss matching datetime formats.""" - # Setup - high = np.array(['2024-11-13 12:00:00.123', '2024-11-13 13:00:00.456'], dtype='O') - low = np.array(['2024-11-13 12:00:00', '2024-11-13 13:00:00'], dtype='O') - high_format = '%Y-%m-%d %H:%M:%S.%f' - low_format = '%Y-%m-%d %H:%M:%S' - expected_diff = np.array([0.0, 0.0], dtype=np.float64) - - # Run - result = get_datetime_diff( - high, low, high_datetime_format=high_format, low_datetime_format=low_format - ) - - # Assert - np.testing.assert_array_almost_equal(result, expected_diff) - - -def test_get_datetime_format_precision_seconds(): - """Test `get_datetime_format_precision` with second-level precision.""" - # Setup - format_str = '%Y-%m-%d %H:%M:%S' - expected_precision = 6 - - # Run - result = get_datetime_format_precision(format_str) - - # Assert - assert result == expected_precision - - -def test_get_datetime_format_precision_microseconds(): - """Test `get_datetime_format_precision` with microsecond-level precision.""" - # Setup - format_str = '%Y-%m-%d %H:%M:%S.%f' - expected_precision = 7 - - # Run - result = get_datetime_format_precision(format_str) - - # Assert - assert result == expected_precision - - -def test_get_datetime_format_precision_minutes(): - """Test `get_datetime_format_precision` with minute-level precision.""" - # Setup - format_str = '%Y-%m-%d %H:%M' - expected_precision = 5 - - # Run - result = get_datetime_format_precision(format_str) - - # Assert - assert result == expected_precision - - -def test_get_datetime_format_precision_days(): - """Test `get_datetime_format_precision` with day-level precision.""" - # Setup - format_str = '%Y-%m-%d' - expected_precision = 3 - - # Run - result = get_datetime_format_precision(format_str) - - # Assert - assert result == expected_precision - - -def test_get_datetime_format_precision_no_precision(): - """Test `get_datetime_format_precision` with no precision format.""" - # Setup - format_str = '%Y' - expected_precision = 1 - - # Run - result = get_datetime_format_precision(format_str) - - # Assert - assert result == expected_precision - - -def test_get_datetime_format_precision_mixed_format_higher_precision(): - """Test `get_datetime_format_precision` with mixed higher-precision format.""" - # Setup - format_str = '%Y-%m-%d %H:%M:%S.%f %z' - expected_precision = 7 - - # Run - result = get_datetime_format_precision(format_str) - - # Assert - assert result == expected_precision - - -def test_get_lower_precision_format_with_different_precision(): - """Test `get_lower_precision_format` with different precision levels.""" - # Setup - primary_format = '%Y-%m-%d %H:%M:%S' - secondary_format = '%Y-%m-%d %H:%M:%S.%f' - - # Run - result = get_lower_precision_format(primary_format, secondary_format) - - # Assert - assert result == primary_format - - -def test_get_lower_precision_format_with_equal_precision(): - """Test `get_lower_precision_format` when both formats have the same precision.""" - # Setup - primary_format = '%Y-%m-%d %H:%M:%S' - secondary_format = '%Y-%m-%d %H:%M:%S' - - # Run - result = get_lower_precision_format(primary_format, secondary_format) - - # Assert - assert result == secondary_format == primary_format - - -def test_get_lower_precision_format_with_date_only(): - """Test `get_lower_precision_format` with date-only formats.""" - # Setup - primary_format = '%Y-%m-%d' - secondary_format = '%Y-%m' - - # Run - result = get_lower_precision_format(primary_format, secondary_format) - - # Assert - assert result == secondary_format - - -def test_get_lower_precision_format_with_week_and_day_formats(): - """Test `get_lower_precision_format` with week and day level formats.""" - # Setup - primary_format = '%Y-%W' - secondary_format = '%Y-%m-%d' - - # Run - result = get_lower_precision_format(primary_format, secondary_format) - - # Assert - assert result == secondary_format - - -def test_downcast_datetime_to_lower_precision(): - """Test `downcast_datetime_to_lower_precision` to ensure datetime downcasting.""" - # Setup - data = np.array( - ['2024-11-13 12:30:45.123456789', '2024-11-13 13:45:30.987654321'], dtype='datetime64[ns]' - ) - target_format = '%Y-%m-%d %H:%M:%S' - expected_result = np.array(['2024-11-13 12:30:45', '2024-11-13 13:45:30'], dtype='O') - - # Run - result = downcast_datetime_to_lower_precision(data, target_format) - - # Assert - np.testing.assert_array_equal(result, cast_to_datetime64(expected_result)) - - -def test_downcast_datetime_to_lower_precision_to_day(): - """Test `downcast_datetime_to_lower_precision` to downcast datetime to day precision.""" - # Setup - data = np.array( - ['2024-11-13 12:30:45.123456789', '2024-11-14 13:45:30.987654321'], dtype='datetime64[ns]' - ) - target_format = '%Y-%m-%d' # Downcasting to day precision - expected_result = np.array(['2024-11-13', '2024-11-14'], dtype='O') - - # Run - result = downcast_datetime_to_lower_precision(data, target_format) - - # Assert - np.testing.assert_array_equal(result, cast_to_datetime64(expected_result)) - - -def test_format_datetime_array_with_lower_precision_format(): - """Test `format_datetime_array` formatting datetime array to a lower-precision format.""" - # Setup - datetime_array = np.array( - ['2024-11-13 12:30:45.123456789', '2024-11-13 13:45:30.987654321'], dtype='datetime64[ns]' - ) - target_format = '%Y-%m-%d %H:%M:%S' - expected_result = np.array(['2024-11-13 12:30:45', '2024-11-13 13:45:30'], dtype='O') - - # Run - result = format_datetime_array(datetime_array, target_format) - - # Assert - np.testing.assert_array_equal(result, expected_result) - - -@patch('sdv.cag._utils.downcast_datetime_to_lower_precision') -def test_match_datetime_precision_low_has_higher_precision(mock_downcast): - """Test `match_datetime_precision` when `low` has higher precision than `high`. - - This test checks that if the `low` array has a more precise format than `high`, - `low` is downcasted to match the `high` format. - """ - # Setup - low = np.array(['2024-11-13 10:34:45.123456', '2024-11-14 12:20:10.654321'], dtype='O') - high = np.array(['2024-11-13 10:34:45', '2024-11-14 12:20:10'], dtype='O') - low_format = '%Y-%m-%d %H:%M:%S.%f' - high_format = '%Y-%m-%d %H:%M:%S' - expected_low = np.array(['2024-11-13 10:34:45', '2024-11-14 12:20:10'], dtype='O') - - # Set the return value of the mock to simulate downcasting - mock_downcast.return_value = expected_low - - # Run - result_low, result_high = match_datetime_precision(low, high, low_format, high_format) - - # Assert - mock_downcast.assert_called_once_with(low, high_format) - np.testing.assert_array_equal(result_low, expected_low) - np.testing.assert_array_equal(result_high, high) - - -@patch('sdv.cag._utils.downcast_datetime_to_lower_precision') -def test_match_datetime_precision_high_has_higher_precision(mock_downcast): - """Test `match_datetime_precision` when `high` has higher precision than `low`. - - This test checks that if the `high` array has a more precise format than `low`, - `high` is downcasted to match the `low` format. - """ - # Setup - low = np.array(['2024-11-13 10:34:45', '2024-11-14 12:20:10'], dtype='O') - high = np.array(['2024-11-13 10:34:45.123456', '2024-11-14 12:20:10.654321'], dtype='O') - low_format = '%Y-%m-%d %H:%M:%S' - high_format = '%Y-%m-%d %H:%M:%S.%f' - expected_high = np.array(['2024-11-13 10:34:45', '2024-11-14 12:20:10'], dtype='O') - - # Set the return value of the mock to simulate downcasting - mock_downcast.return_value = expected_high - - # Run - result_low, result_high = match_datetime_precision(low, high, low_format, high_format) - - # Assert - mock_downcast.assert_called_once_with(high, low_format) - np.testing.assert_array_equal(result_low, low) - np.testing.assert_array_equal(result_high, expected_high) diff --git a/tests/unit/cag/test_inequality.py b/tests/unit/cag/test_inequality.py index addb17b16..a0acf74db 100644 --- a/tests/unit/cag/test_inequality.py +++ b/tests/unit/cag/test_inequality.py @@ -14,17 +14,13 @@ class TestInequality: - def test___init___incorrect_low_column_name(self): - """Test it raises an error if low_column_name is not a string.""" + def test___init___incorrect_column_name(self): + """Test it raises an error if column_name is not a string.""" # Run and Assert err_msg = '`low_column_name` and `high_column_name` must be strings.' with pytest.raises(ValueError, match=err_msg): Inequality(low_column_name=1, high_column_name='b') - def test___init___incorrect_high_column_name(self): - """Test it raises an error if high_column_name is not a string.""" - # Run and Assert - err_msg = '`low_column_name` and `high_column_name` must be strings.' with pytest.raises(ValueError, match=err_msg): Inequality(low_column_name='a', high_column_name=1) @@ -470,10 +466,11 @@ def test__fit(self): assert instance._low_datetime_format == '%y %m, %d' assert instance._high_datetime_format == '%y %m %d' - def test__fit_numerical(self): + @pytest.mark.parametrize('dtype', ['Float64', 'Float32', 'Int64', 'Int32', 'Int16', 'Int8']) + def test__fit_numerical(self, dtype): """Test it for numerical columns.""" # Setup - table_data = {'table': pd.DataFrame({'a': [1, 2, 4], 'b': [4.0, 5.0, 6.0]})} + table_data = {'table': pd.DataFrame({'a': [1, 2, 4], 'b': [4, 5, 6]}, dtype=dtype)} metadata = Metadata.load_from_dict({ 'tables': { 'table': { @@ -490,7 +487,7 @@ def test__fit_numerical(self): instance._fit(table_data, metadata) # Assert - assert instance._dtype == np.dtype('float') + assert instance._dtype == dtype assert instance._is_datetime is False assert instance._low_datetime_format is None assert instance._high_datetime_format is None