Skip to content

Commit

Permalink
chore: fix pandas schema flaky test
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Jun 21, 2024
1 parent 5269297 commit bc690e6
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions tests/pandas/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def generate_pandas_dataframe(
n_features: int = 1,
index_name: Optional[str] = None,
indexes: Optional[Union[int, List]] = None,
str_values: Optional[List[str]] = None,
index_position: int = 0,
include_nan: bool = True,
float_min: float = -10.0,
Expand All @@ -46,6 +47,7 @@ def generate_pandas_dataframe(
index_name (Optional[str]): The index's name. Default to None ("index").
indexes (Optional[Union[int, List]]): Custom indexes to consider. Default to None (5 rows,
indexed from 1 to 5).
str_values (Optional[List[str]]): The list of string values to consider. Default to None.
index_position (int): The index's column position in the data-frame. Default to 0.
include_nan (bool): If NaN values should be put in the data-frame. If True, they are
inserted in the first row. Default to True.
Expand Down Expand Up @@ -75,6 +77,10 @@ def generate_pandas_dataframe(
if isinstance(indexes, int):
indexes = list(range(1, indexes + 1))

assert str_values is None or len(str_values) == len(
indexes
), "Parameter 'str_values' must either be None or a list of length equal to 'indexes'."

if index_name is None:
index_name = "index"

Expand Down Expand Up @@ -102,11 +108,16 @@ def generate_pandas_dataframe(

# Add a column with string values (including NaN or not)
if dtype in ["str", "mixed"]:
str_values = ["apple", "orange", "watermelon", "cherry", "banana"]
str_values_default = ["apple", "orange", "watermelon", "cherry", "banana"]

for i in range(1, n_features + 1):
column_name = f"{feat_name}_str_{i}"
columns[column_name] = list(numpy.random.choice(str_values, size=(len(indexes),)))
rand_str_values = (
list(numpy.random.choice(str_values_default, size=(len(indexes),)))
if str_values is None
else str_values
)
columns[column_name] = rand_str_values

if include_nan:
columns[column_name][0] = numpy.nan
Expand Down Expand Up @@ -694,8 +705,17 @@ def check_invalid_schema_values():

client = ClientEngine(keys_path=keys_path)

# Fix the string values to consider in order to avoid flaky tests (one of the checks requires to
# have a data-frame with at least 2 unique string values)
str_values = ["apple", "orange", "watermelon", "cherry", "banana"]

pandas_df = generate_pandas_dataframe(
feat_name=feat_name, index_name=selected_column, float_min=float_min, float_max=float_max
feat_name=feat_name,
index_name=selected_column,
indexes=len(str_values),
str_values=str_values,
float_min=float_min,
float_max=float_max,
)

schema_int_column = {f"{feat_name}_int_1": {None: None}}
Expand Down

0 comments on commit bc690e6

Please sign in to comment.