Skip to content

Commit

Permalink
Merge pull request #152 from ibm-granite/services_support
Browse files: browse the repository at this point in the history
improved error messages for incorrect inputs
  • Loading branch information
wgifford authored Oct 10, 2024
2 parents dd44f73 + d551722 commit c869df8
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions tsfm_public/toolkit/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,16 +53,20 @@ def __init__(
y_cols = [y_cols]

if len(x_cols) > 0:
assert is_cols_in_df(data_df, x_cols), f"one or more {x_cols} is not in the list of data_df columns"
there, missing = is_cols_in_df(data_df, x_cols)
assert there, f"{missing} given in {x_cols} is not a valid column identifier in the data."

if len(y_cols) > 0:
assert is_cols_in_df(data_df, y_cols), f"one or more {y_cols} is not in the list of data_df columns"
there, missing = is_cols_in_df(data_df, y_cols)
assert there, f"{missing} given in {y_cols} is not a valid column identifier in the data."

if timestamp_column:
assert timestamp_column in list(
data_df.columns
), f"{timestamp_column} is not in the list of data_df columns"
assert timestamp_column not in x_cols, f"{timestamp_column} should not be in the list of x_cols"
), f"{timestamp_column} is not in the list of data column names provided {data_df.columns}"
assert (
timestamp_column not in x_cols
), f"{timestamp_column} can not be used as a timestamp column as it also appears in provided collection:{x_cols}."

self.data_df = data_df
self.datetime_col = timestamp_column
Expand Down Expand Up @@ -162,7 +166,8 @@ def __init__(
**kwargs,
):
if len(id_columns) > 0:
assert is_cols_in_df(data_df, id_columns), f"{id_columns} is not in the data_df columns"
there, missing = is_cols_in_df(data_df, id_columns)
assert there, f"{missing} given in {id_columns} is not a valid column in the data."

self.timestamp_column = timestamp_column
self.id_columns = id_columns
Expand Down Expand Up @@ -896,8 +901,8 @@ def is_cols_in_df(df: pd.DataFrame, cols: List[str]) -> bool:
"""
for col in cols:
if col not in list(df.columns):
return False
return True
return False, col
return True, None


if __name__ == "__main__":
Expand Down

0 comments on commit c869df8

Please sign in to comment.