From bc690e686cb90ca23ac27089fe907b604246916f Mon Sep 17 00:00:00 2001 From: Roman Bredehoft Date: Fri, 21 Jun 2024 11:18:59 +0200 Subject: [PATCH] chore: fix pandas schema flaky test --- tests/pandas/test_pandas.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/tests/pandas/test_pandas.py b/tests/pandas/test_pandas.py index bb2d5f713..78af798aa 100644 --- a/tests/pandas/test_pandas.py +++ b/tests/pandas/test_pandas.py @@ -24,6 +24,7 @@ def generate_pandas_dataframe( n_features: int = 1, index_name: Optional[str] = None, indexes: Optional[Union[int, List]] = None, + str_values: Optional[List[str]] = None, index_position: int = 0, include_nan: bool = True, float_min: float = -10.0, @@ -46,6 +47,7 @@ def generate_pandas_dataframe( index_name (Optional[str]): The index's name. Default to None ("index"). indexes (Optional[Union[int, List]]): Custom indexes to consider. Default to None (5 rows, indexed from 1 to 5). + str_values (Optional[List[str]]): The list of string values to consider. Default to None. index_position (int): The index's column position in the data-frame. Default to 0. include_nan (bool): If NaN values should be put in the data-frame. If True, they are inserted in the first row. Default to True. @@ -75,6 +77,10 @@ def generate_pandas_dataframe( if isinstance(indexes, int): indexes = list(range(1, indexes + 1)) + assert str_values is None or len(str_values) == len( + indexes + ), "Parameter 'str_values' must either be None or a list of length equal to 'indexes'." + if index_name is None: index_name = "index" @@ -102,11 +108,16 @@ def generate_pandas_dataframe( # Add a column with string values (including NaN or not) if dtype in ["str", "mixed"]: - str_values = ["apple", "orange", "watermelon", "cherry", "banana"] + str_values_default = ["apple", "orange", "watermelon", "cherry", "banana"] for i in range(1, n_features + 1): column_name = f"{feat_name}_str_{i}" - columns[column_name] = list(numpy.random.choice(str_values, size=(len(indexes),))) + rand_str_values = ( + list(numpy.random.choice(str_values_default, size=(len(indexes),))) + if str_values is None + else str_values + ) + columns[column_name] = rand_str_values if include_nan: columns[column_name][0] = numpy.nan @@ -694,8 +705,17 @@ def check_invalid_schema_values(): client = ClientEngine(keys_path=keys_path) + # Fix the string values to consider in order to avoid flaky tests (one of the checks requires to + # have a data-frame with at least 2 unique string values) + str_values = ["apple", "orange", "watermelon", "cherry", "banana"] + pandas_df = generate_pandas_dataframe( - feat_name=feat_name, index_name=selected_column, float_min=float_min, float_max=float_max + feat_name=feat_name, + index_name=selected_column, + indexes=len(str_values), + str_values=str_values, + float_min=float_min, + float_max=float_max, ) schema_int_column = {f"{feat_name}_int_1": {None: None}}