[ENH] Tidy, deprecation actions and testing for sklearn estimators #1701

Merged
3 commits merged on Jun 24, 2024
4 changes: 1 addition & 3 deletions aeon/classification/feature_based/_fresh_prince.py
@@ -4,10 +4,9 @@
RotationForestClassifier.
"""

__maintainer__ = []
__maintainer__ = ["MatthewMiddlehurst"]
__all__ = ["FreshPRINCEClassifier"]


import numpy as np

from aeon.classification.base import BaseClassifier
@@ -98,7 +97,6 @@ def __init__(
self.n_cases_ = 0
self.n_channels_ = 0
self.n_timepoints_ = 0
self.transformed_data_ = []

self._rotf = None
self._tsfresh = None
68 changes: 43 additions & 25 deletions aeon/classification/sklearn/_continuous_interval_tree.py
@@ -13,11 +13,14 @@
from typing import List, Tuple, Type, Union

import numpy as np
import pandas as pd
from numba import njit
from scipy.sparse import issparse
from sklearn import preprocessing
from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.utils import check_random_state
from sklearn.utils.multiclass import check_classification_targets


class _TreeNode:
@@ -273,7 +276,7 @@ def remaining_classes(distribution) -> bool:
return remaining_classes > 1


class ContinuousIntervalTree(BaseEstimator):
class ContinuousIntervalTree(ClassifierMixin, BaseEstimator):
"""Continuous interval tree (CIT) vector classifier (aka Time Series Tree).

The `Time Series Tree` described in the Time Series Forest (TSF) [1]_. A simple
@@ -369,17 +372,16 @@ def fit(self, X, y):
Changes state by creating a fitted model that updates attributes
ending in "_".
"""
if isinstance(X, np.ndarray) and len(X.shape) == 3 and X.shape[1] == 1:
X = np.reshape(X, (X.shape[0], -1))
elif not isinstance(X, np.ndarray) or len(X.shape) > 2:
raise ValueError(
"ContinuousIntervalTree is not a time series classifier. "
"A valid sklearn input such as a 2d numpy array is required."
"Sparse input formats are currently not supported."
)
# data processing
X = self._check_X(X)
X, y = self._validate_data(
X=X, y=y, ensure_min_samples=2, force_all_finite="allow-nan"
X=X,
y=y,
ensure_min_samples=2,
force_all_finite="allow-nan",
accept_sparse=False,
)
check_classification_targets(y)

self.n_cases_, self.n_atts_ = X.shape
self.classes_ = np.unique(y)
@@ -435,12 +437,8 @@ def predict(self, X):
y : array-like, shape = [n_cases]
Predicted class labels.
"""
rng = check_random_state(self.random_state)
return np.array(
[
self.classes_[int(rng.choice(np.flatnonzero(prob == prob.max())))]
for prob in self.predict_proba(X)
]
[self.classes_[int(np.argmax(prob))] for prob in self.predict_proba(X)]
)

def predict_proba(self, X):
@@ -466,15 +464,11 @@ def predict_proba(self, X):
if self.n_classes_ == 1:
return np.repeat([[1]], X.shape[0], axis=0)

if isinstance(X, np.ndarray) and len(X.shape) == 3 and X.shape[1] == 1:
X = np.reshape(X, (X.shape[0], -1))
elif not isinstance(X, np.ndarray) or len(X.shape) > 2:
raise ValueError(
"ContinuousIntervalTree is not a time series classifier. "
"A valid sklearn input such as a 2d numpy array is required."
"Sparse input formats are currently not supported."
)
X = self._validate_data(X=X, reset=False, force_all_finite="allow-nan")
# data processing
X = self._check_X(X)
X = self._validate_data(
X=X, reset=False, force_all_finite="allow-nan", accept_sparse=False
)

dists = np.zeros((X.shape[0], self.n_classes_))
for i in range(X.shape[0]):
@@ -500,6 +494,30 @@ def _find_splits_gain(self, node: Type[_TreeNode], splits: list, gains: list):
if next_node.best_split > -1:
self._find_splits_gain(next_node, splits, gains)

def _check_X(self, X):
if issparse(X):
return X

msg = (
"ContinuousIntervalTree is not a time series classifier. "
"A valid sklearn input such as a 2d numpy array is required."
"Sparse input formats are currently not supported."
)
if isinstance(X, pd.DataFrame):
X = X.to_numpy()
else:
try:
X = np.array(X)
except Exception:
raise ValueError(msg)

if isinstance(X, np.ndarray) and len(X.shape) == 3 and X.shape[1] == 1:
X = np.reshape(X, (X.shape[0], -1))
elif not isinstance(X, np.ndarray) or len(X.shape) > 2:
raise ValueError(msg)

return X


@njit(fastmath=True, cache=True)
def _entropy(x, s: int) -> float:
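
For context, a minimal sketch of how the reworked input handling above is expected to behave; the public import path aeon.classification.sklearn and the toy data are assumptions for illustration, not part of the diff:

# Illustration only: reworked input handling in ContinuousIntervalTree.
import numpy as np
from aeon.classification.sklearn import ContinuousIntervalTree

y = np.array([0, 1] * 5)
X_2d = np.random.random((10, 20))      # standard 2D sklearn-style input
X_3d = np.random.random((10, 1, 20))   # single-channel 3D collection

ContinuousIntervalTree().fit(X_2d, y)  # accepted as before
ContinuousIntervalTree().fit(X_3d, y)  # _check_X flattens (n, 1, m) to (n, m)

# Multivariate 3D input is still rejected with the error message above:
# ContinuousIntervalTree().fit(np.random.random((10, 3, 20)), y)  # ValueError
# Sparse matrices pass through _check_X but are rejected by _validate_data,
# since accept_sparse=False.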
92 changes: 41 additions & 51 deletions aeon/classification/sklearn/_rotation_forest_classifier.py
@@ -8,22 +8,23 @@
__all__ = ["RotationForestClassifier"]

import time
import warnings
from typing import Type, Union
from typing import Optional, Type, Union

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator
from scipy.sparse import issparse
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import check_random_state
from sklearn.utils.multiclass import check_classification_targets

from aeon.base._base import _clone_estimator
from aeon.utils.validation import check_n_jobs


class RotationForestClassifier(BaseEstimator):
class RotationForestClassifier(ClassifierMixin, BaseEstimator):
"""
A rotation forest (RotF) vector classifier.

@@ -52,11 +53,6 @@ class RotationForestClassifier(BaseEstimator):
Default of `0` means ``n_estimators`` is used.
contract_max_n_estimators : int, default=500
Max number of estimators to build when ``time_limit_in_minutes`` is set.
save_transformed_data : bool, default=False
Save the data transformed in fit.

Deprecated and will be removed in v0.10.0. Use fit_predict and fit_predict_proba
to generate train estimates instead. transformed_data_ will also be removed.
n_jobs : int, default=1
The number of jobs to run in parallel for both ``fit`` and ``predict``.
`-1` means using all processors.
@@ -106,10 +102,9 @@ def __init__(
min_group: int = 3,
max_group: int = 3,
remove_proportion: float = 0.5,
base_estimator: Union[Type[BaseEstimator], None] = None,
time_limit_in_minutes: int = 0.0,
base_estimator: Optional[Type[BaseEstimator]] = None,
time_limit_in_minutes: float = 0.0,
contract_max_n_estimators: int = 500,
save_transformed_data: bool = "deprecated",
n_jobs: int = 1,
random_state: Union[int, Type[np.random.RandomState], None] = None,
):
@@ -123,15 +118,6 @@
self.n_jobs = n_jobs
self.random_state = random_state

# TODO remove 'save_transformed_data' and 'transformed_data_' in v0.10.0
self.save_transformed_data = save_transformed_data
if save_transformed_data != "deprecated":
warnings.warn(
"the save_transformed_data parameter is deprecated and will be"
"removed in v0.10.0. transformed_data_ will also be removed.",
stacklevel=2,
)

super().__init__()

def fit(self, X, y):
@@ -169,12 +155,8 @@ def predict(self, X) -> np.ndarray:
y : array-like, shape = [n_cases]
Predicted class labels.
"""
rng = check_random_state(self.random_state)
return np.array(
[
self.classes_[int(rng.choice(np.flatnonzero(prob == prob.max())))]
for prob in self.predict_proba(X)
]
[self.classes_[int(np.argmax(prob))] for prob in self.predict_proba(X)]
)

def predict_proba(self, X) -> np.ndarray:
@@ -190,7 +172,7 @@ def predict_proba(self, X) -> np.ndarray:
y : array-like, shape = [n_cases, n_classes_]
Predicted probabilities using the ordering in classes_.
"""
if not self._is_fitted:
if not hasattr(self, "_is_fitted") or not self._is_fitted:
from sklearn.exceptions import NotFittedError

raise NotFittedError(
@@ -202,17 +184,9 @@
if self.n_classes_ == 1:
return np.repeat([[1]], X.shape[0], axis=0)

if isinstance(X, np.ndarray) and len(X.shape) == 3 and X.shape[1] == 1:
X = np.reshape(X, (X.shape[0], -1))
elif isinstance(X, pd.DataFrame) and len(X.shape) == 2:
X = X.to_numpy()
elif not isinstance(X, np.ndarray) or len(X.shape) > 2:
raise ValueError(
"RotationForestClassifier is not a time series classifier. "
"A valid sklearn input such as a 2d numpy array is required."
"Sparse input formats are currently not supported."
)
X = self._validate_data(X=X, reset=False)
# data processing
X = self._check_X(X)
X = self._validate_data(X=X, reset=False, accept_sparse=False)

# replace missing values with 0 and remove useless attributes
X = X[:, self._useful_atts]
@@ -257,10 +231,9 @@ def fit_predict(self, X, y) -> np.ndarray:
-----
Changes state by creating a fitted model that updates attributes ending in "_".
"""
rng = check_random_state(self.random_state)
return np.array(
[
self.classes_[int(rng.choice(np.flatnonzero(prob == prob.max())))]
self.classes_[int(np.argmax(prob))]
for prob in self.fit_predict_proba(X, y)
]
)
@@ -318,17 +291,10 @@ def fit_predict_proba(self, X, y) -> np.ndarray:
return results

def _fit_rotf(self, X, y, save_transformed_data: bool = False):
if isinstance(X, np.ndarray) and len(X.shape) == 3 and X.shape[1] == 1:
X = np.reshape(X, (X.shape[0], -1))
elif isinstance(X, pd.DataFrame) and len(X.shape) == 2:
X = X.to_numpy()
elif not isinstance(X, np.ndarray) or len(X.shape) > 2:
raise ValueError(
"RotationForestClassifier is not a time series classifier. "
"A valid sklearn input such as a 2d numpy array is required."
"Sparse input formats are currently not supported."
)
X, y = self._validate_data(X=X, y=y, ensure_min_samples=2)
# data processing
X = self._check_X(X)
X, y = self._validate_data(X=X, y=y, ensure_min_samples=2, accept_sparse=False)
check_classification_targets(y)

self._n_jobs = check_n_jobs(self.n_jobs)

@@ -558,3 +524,27 @@ def _generate_groups(self, rng: Type[np.random.RandomState]):
current_attribute += 1

return groups

def _check_X(self, X):
if issparse(X):
return X

msg = (
"RotationForestClassifier is not a time series classifier. "
"A valid sklearn input such as a 2d numpy array is required."
"Sparse input formats are currently not supported."
)
if isinstance(X, pd.DataFrame):
X = X.to_numpy()
else:
try:
X = np.array(X)
except Exception:
raise ValueError(msg)

if isinstance(X, np.ndarray) and len(X.shape) == 3 and X.shape[1] == 1:
X = np.reshape(X, (X.shape[0], -1))
elif not isinstance(X, np.ndarray) or len(X.shape) > 2:
raise ValueError(msg)

return X
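
Two behavioural notes on the classifier changes above, shown as a short usage sketch; the import path and toy data are assumptions for illustration rather than part of the diff:

# Illustration only: deterministic prediction and train estimates in RotationForestClassifier.
import numpy as np
from aeon.classification.sklearn import RotationForestClassifier

X = np.random.random((20, 10))
y = np.array([0, 1] * 10)

rotf = RotationForestClassifier(n_estimators=10, random_state=0)

# Train estimates come from fit_predict / fit_predict_proba; the removed
# save_transformed_data option and transformed_data_ attribute are no longer needed.
train_proba = rotf.fit_predict_proba(X, y)

# predict now resolves tied probabilities with np.argmax (first class with the
# maximal probability) instead of a random draw, so repeated calls give
# identical labels.
preds = rotf.predict(X)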
18 changes: 13 additions & 5 deletions aeon/regression/feature_based/_fresh_prince.py
@@ -37,8 +37,11 @@ class FreshPRINCERegressor(BaseRegressor):
chunksize : int or None, default=None
Number of series processed in each parallel TSFresh job, should be optimised
for efficient parallelisation.
random_state : int or None, default=None
Seed for random, integer.
random_state : int, RandomState instance or None, default=None
If `int`, random_state is the seed used by the random number generator;
If `RandomState` instance, random_state is the random number generator;
If `None`, the random number generator is the `RandomState` instance used
by `np.random`.

See Also
--------
@@ -50,6 +53,12 @@
scalable hypothesis tests (tsfresh-a python package)." Neurocomputing 307
(2018): 72-77.
https://www.sciencedirect.com/science/article/pii/S0925231218304843
.. [2] Middlehurst, M., Bagnall, A. "The FreshPRINCE: A Simple Transformation
Based Pipeline Time Series Classifier." In: El Yacoubi, M., Granger, E.,
Yuen, P.C., Pal, U., Vincent, N. (eds) Pattern Recognition and Artificial
Intelligence. ICPRAI 2022. Lecture Notes in Computer Science, vol 13364.
Springer, Cham. (2022).
https://link.springer.com/chapter/10.1007/978-3-031-09282-4_13

Examples
--------
@@ -116,9 +125,8 @@ def _fit(self, X, y):
Changes state by creating a fitted model that updates attributes
ending in "_" and sets is_fitted flag to True.
"""
self.transformed_data_ = self._fit_fp_shared(X, y)
self._rotf.fit(self.transformed_data_, y)

X_t = self._fit_fp_shared(X, y)
self._rotf.fit(X_t, y)
return self

def _predict(self, X) -> np.ndarray:
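
Finally, the expanded random_state description above follows the usual scikit-learn convention; a minimal sketch of the three accepted forms (the FreshPRINCERegressor import path is assumed from the file location):

# Illustration only: the three accepted random_state forms.
import numpy as np
from aeon.regression.feature_based import FreshPRINCERegressor

reg_seed = FreshPRINCERegressor(random_state=42)                         # int seed, reproducible runs
reg_gen = FreshPRINCERegressor(random_state=np.random.RandomState(42))  # explicit generator instance
reg_global = FreshPRINCERegressor(random_state=None)                    # falls back to the global np.random state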