diff --git a/metalearners/_typing.py b/metalearners/_typing.py index 10a6f4d..5643f19 100644 --- a/metalearners/_typing.py +++ b/metalearners/_typing.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +import polars as pl import scipy.sparse as sps PredictMethod = Literal["predict", "predict_proba"] @@ -21,8 +22,8 @@ Features = Collection[str] | Collection[int] | None # ruff is not happy about the usage of Union. -Vector = Union[pd.Series, np.ndarray] # noqa -Matrix = Union[pd.DataFrame, np.ndarray, sps.csr_matrix] # noqa +Vector = Union[pd.Series, np.ndarray, pl.Series] # noqa +Matrix = Union[pd.DataFrame, np.ndarray, sps.csr_matrix, pl.DataFrame] # noqa class _ScikitModel(Protocol):