diff --git a/src/fklearn/training/classification.py b/src/fklearn/training/classification.py index 6cb4784f..3e30b377 100644 --- a/src/fklearn/training/classification.py +++ b/src/fklearn/training/classification.py @@ -1,4 +1,4 @@ -from typing import List, Any +from typing import List, Any, Union import numpy as np import pandas as pd @@ -502,7 +502,7 @@ def lgbm_classification_learner(df: pd.DataFrame, learning_rate: float = 0.1, num_estimators: int = 100, extra_params: LogType = None, - categorical_features: List[str] = "auto", + categorical_features: Union[List[str], "auto"] = "auto", prediction_column: str = "prediction", weight_column: str = None, encode_extra_cols: bool = True) -> LearnerReturnType: