From ef5515f8c24b78279223bca17e0848935073e551 Mon Sep 17 00:00:00 2001 From: Rory McStay Date: Thu, 15 Aug 2024 11:36:34 +0000 Subject: [PATCH] Handle timestamp and nans in removing multi index failure cases #1469 Signed-off-by: Rory Signed-off-by: Rory McStay --- pandera/backends/pandas/base.py | 22 ++++- tests/core/test_schemas.py | 139 ++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) diff --git a/pandera/backends/pandas/base.py b/pandera/backends/pandas/base.py index d4ba5da7b..2dbff3ca1 100644 --- a/pandera/backends/pandas/base.py +++ b/pandera/backends/pandas/base.py @@ -4,6 +4,7 @@ from collections import defaultdict from typing import List, Optional, TypeVar, Union +import numpy as np import pandas as pd from pandera.api.base.checks import CheckResult @@ -28,6 +29,14 @@ SchemaWarning, ) + +_MULTIINDEX_HANDLED_TYPES = { + "Timestamp": pd.Timestamp, + "NaT": pd.NaT, + "nan": np.nan, +} + + FieldCheckObj = Union[pd.Series, pd.DataFrame] T = TypeVar( @@ -196,7 +205,18 @@ def drop_invalid_rows(self, check_obj, error_handler: ErrorHandler): if isinstance(check_obj.index, pd.MultiIndex): # MultiIndex values are saved on the error as strings so need to be cast back # to their original types - index_tuples = err.failure_cases["index"].apply(eval) + # fmt: off + index_tuples = ( + err.failure_cases["index"] + .astype(str) + .apply( + lambda i: eval(i, _MULTIINDEX_HANDLED_TYPES) # pylint: disable=eval-used + ) + ) + # fmt: on + # type check on a column of index. 
+ if len(index_tuples) == 1 and index_tuples[0] is None: + continue index_values = pd.MultiIndex.from_tuples(index_tuples) mask = ~check_obj.index.isin(index_values) diff --git a/tests/core/test_schemas.py b/tests/core/test_schemas.py index e5fb6c60c..8b21804c1 100644 --- a/tests/core/test_schemas.py +++ b/tests/core/test_schemas.py @@ -2648,3 +2648,142 @@ def test_schema_column_default_handle_nans( df = pd.DataFrame({"column1": [input_value]}) schema.validate(df, inplace=True) assert df.iloc[0]["column1"] == default + + +@pytest.mark.parametrize( + "schema, obj, expected_obj, check_dtype", + [ + ( + DataFrameSchema( + columns={ + "temperature": Column(float, nullable=False), + }, + index=MultiIndex( + [ + Index(pd.Timestamp, name="timestamp"), + Index(str, name="city"), + ] + ), + drop_invalid_rows=True, + ), + pd.DataFrame( + { + "temperature": [ + 3.0, + 4.0, + 5.0, + 5.0, + np.nan, + 2.0, + ], + }, + index=pd.MultiIndex.from_tuples( + ( + (pd.Timestamp("2022-01-01"), "Paris"), + (pd.Timestamp("2023-01-01"), "Paris"), + (pd.Timestamp("2024-01-01"), "Paris"), + (pd.Timestamp("2022-01-01"), "Oslo"), + (pd.Timestamp("2023-01-01"), "Oslo"), + (pd.Timestamp("2024-01-01"), "Oslo"), + ), + names=["timestamp", "city"], + ), + ), + pd.DataFrame( + { + "temperature": [3.0, 4.0, 5.0, 5.0, 2.0], + }, + index=pd.MultiIndex.from_tuples( + ( + (pd.Timestamp("2022-01-01"), "Paris"), + (pd.Timestamp("2023-01-01"), "Paris"), + (pd.Timestamp("2024-01-01"), "Paris"), + (pd.Timestamp("2022-01-01"), "Oslo"), + (pd.Timestamp("2024-01-01"), "Oslo"), + ), + names=["timestamp", "city"], + ), + ), + True, + ), + ( + DataFrameSchema( + columns={ + "temperature": Column(float, nullable=False), + }, + index=MultiIndex( + [ + Index(pd.Timestamp, name="timestamp"), + Index(str, name="city"), + ] + ), + drop_invalid_rows=True, + ), + pd.DataFrame( + { + "temperature": [ + 3.0, + 4.0, + 5.0, + -1.0, + np.nan, + -2.0, + 4.0, + 5.0, + 2.0, + ], + }, + index=pd.MultiIndex.from_tuples( + ( + 
(pd.Timestamp("2022-01-01"), "Paris"), + (pd.Timestamp("2023-01-01"), "Paris"), + (pd.Timestamp("2024-01-01"), "Paris"), + (pd.Timestamp("2022-01-01"), "Oslo"), + (pd.Timestamp("2023-01-01"), "Oslo"), + (pd.Timestamp("2024-01-01"), "Oslo"), + ( + pd.Timestamp("2024-01-01", tz="Europe/London"), + "London", + ), + (pd.Timestamp(pd.NaT), "Frankfurt"), + (pd.Timestamp("2024-01-01"), 6), + ), + names=["timestamp", "city"], + ), + ), + pd.DataFrame( + { + "temperature": [3.0, 4.0, 5.0, -1.0, -2.0, 4], + }, + index=pd.MultiIndex.from_tuples( + ( + (pd.Timestamp("2022-01-01"), "Paris"), + (pd.Timestamp("2023-01-01"), "Paris"), + (pd.Timestamp("2024-01-01"), "Paris"), + (pd.Timestamp("2022-01-01"), "Oslo"), + (pd.Timestamp("2024-01-01"), "Oslo"), + ( + pd.Timestamp("2024-01-01", tz="Europe/London"), + "London", + ), + ), + names=["timestamp", "city"], + ), + ), + False, + ), + ], +) +def test_drop_invalid_for_multi_index_with_datetime( + schema, obj, expected_obj, check_dtype +): + """Test drop_invalid_rows works as expected on multi-index dataframes""" + actual_obj = schema.validate(obj, lazy=True) + + # the datatype of the index is not cast; in this case it's an object + pd.testing.assert_frame_equal( + actual_obj, + expected_obj, + check_dtype=check_dtype, + check_index_type=check_dtype, + )