-
Notifications
You must be signed in to change notification settings - Fork 142
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Added ROCKAD anomaly detector to aeon (#2376)
* Added ROCKAD anomaly detector to aeon * Added ROCKAD to anomaly_detection.rst * Empty commit for CI * Automatic `pre-commit` fixes * Fix newline at end of file * adapted code to fit refactored Rocket arguments * added "capability:multithreading": True to _tags * Catch power transform and disable it if it results in error * Added fallback when power transform fails, added ValueError if number of windows is smaller than n_neighbors and other general improvements * Added private attribute for power_transform activation/deactivation * added tests for kneighbors check, adapted univariate and multivariate tests for changes in ROCKAD * Removed pandas, set rocket normalise default to False, use transposed data for power transform to prevent bracket error, cleaned up code. Set standardize=True on power transform so that StandardScaler can be removed. * Added test for power transform failure * removed transpose, added back user warning when power transform fails, added comments for clarity, other small improvements * removed noop * removed inf_columns_ check * moved parent class init and n_jobs to have consistent structure --------- Co-authored-by: MatthewMiddlehurst <[email protected]> Co-authored-by: pattplatt <[email protected]>
- Loading branch information
1 parent
39fe7b4
commit fa27298
Showing
4 changed files
with
340 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
"""ROCKAD anomaly detector.""" | ||
|
||
__all__ = ["ROCKAD"] | ||
|
||
import warnings | ||
from typing import Optional | ||
|
||
import numpy as np | ||
from sklearn.neighbors import NearestNeighbors | ||
from sklearn.preprocessing import PowerTransformer | ||
from sklearn.utils import resample | ||
|
||
from aeon.anomaly_detection.base import BaseAnomalyDetector | ||
from aeon.transformations.collection.convolution_based import Rocket | ||
from aeon.utils.windowing import reverse_windowing, sliding_windows | ||
|
||
|
||
class ROCKAD(BaseAnomalyDetector):
    """
    ROCKET-based Anomaly Detector (ROCKAD).

    ROCKAD leverages the ROCKET transformation for feature extraction from
    time series data and applies the scikit learn k-nearest neighbors (k-NN)
    approach with bootstrap aggregation for robust anomaly detection.
    After windowing, the data gets transformed into the ROCKET feature space.
    Then the windows are compared based on the feature space by
    finding the nearest neighbours.

    This class supports both univariate and multivariate time series and
    provides options for normalizing features, applying power transformations,
    and customizing the distance metric.

    Parameters
    ----------
    n_estimators : int, default=10
        Number of k-NN estimators to use in the bootstrap aggregation.
    n_kernels : int, default=100
        Number of kernels to use in the ROCKET transformation.
    normalise : bool, default=False
        Whether to normalize the ROCKET-transformed features.
    n_neighbors : int, default=5
        Number of neighbors to use for the k-NN algorithm.
    metric : str, default="euclidean"
        Distance metric to use for the k-NN algorithm.
    power_transform : bool, default=True
        Whether to apply a power transformation (Yeo-Johnson) to the features.
    window_size : int, default=10
        Size of the sliding window for segmenting input time series data.
    stride : int, default=1
        Step size for moving the sliding window over the time series data.
    n_jobs : int, default=1
        Number of parallel jobs to use for the k-NN algorithm and ROCKET
        transformation.
    random_state : int or None, default=42
        Random seed for reproducibility. Estimator ``i`` of the ensemble is
        bootstrapped with seed ``random_state + i``; ``None`` leaves the
        bootstrap unseeded.

    Attributes
    ----------
    rocket_transformer_ : Optional[Rocket]
        Instance of the ROCKET transformer used to extract features, set after
        fitting.
    list_baggers_ : Optional[list[NearestNeighbors]]
        List containing k-NN estimators used for anomaly scoring, set after
        fitting.
    power_transformer_ : Optional[PowerTransformer]
        Transformer used to apply power transformation to the features. ``None``
        when ``power_transform`` is False or when fitting the transformer
        failed.
    """

    _tags = {
        "capability:univariate": True,
        "capability:multivariate": True,
        "capability:missing_values": False,
        "capability:multithreading": True,
        "fit_is_empty": False,
    }

    def __init__(
        self,
        n_estimators=10,
        n_kernels=100,
        normalise=False,
        n_neighbors=5,
        metric="euclidean",
        power_transform=True,
        window_size: int = 10,
        stride: int = 1,
        n_jobs=1,
        random_state=42,
    ):

        self.n_estimators = n_estimators
        self.n_kernels = n_kernels
        self.normalise = normalise
        self.n_neighbors = n_neighbors
        self.n_jobs = n_jobs
        self.metric = metric
        self.power_transform = power_transform
        self.window_size = window_size
        self.stride = stride
        self.random_state = random_state

        # Fitted state; populated by ``_inner_fit``.
        self.rocket_transformer_: Optional[Rocket] = None
        self.list_baggers_: Optional[list[NearestNeighbors]] = None
        self.power_transformer_: Optional[PowerTransformer] = None

        super().__init__(axis=0)

    def _fit(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> "ROCKAD":
        """Validate the parameters against ``X``, window it and fit the model.

        Parameters
        ----------
        X : np.ndarray
            Training series with time along axis 0.
        y : np.ndarray, optional
            Unused by this method.

        Returns
        -------
        ROCKAD
            The fitted detector (``self``).
        """
        self._check_params(X)
        # X has time on axis 0 because the base class is initialised with axis=0.
        _X, _ = sliding_windows(
            X, window_size=self.window_size, stride=self.stride, axis=0
        )
        # _X holds one row per window.
        self._inner_fit(_X)

        return self

    def _check_params(self, X: np.ndarray) -> None:
        """Check window size, stride and neighbour count against ``X``.

        Parameters
        ----------
        X : np.ndarray
            Series whose first axis length is the number of time points.

        Raises
        ------
        ValueError
            If ``window_size`` is not in ``[1, X.shape[0]]``, if ``stride`` is
            not in ``[1, window_size]``, or if the resulting number of windows
            is smaller than ``n_neighbors``.
        """
        n_timepoints = X.shape[0]

        if self.window_size < 1 or self.window_size > n_timepoints:
            raise ValueError(
                "The window size must be at least 1 and at most the length of the "
                "time series."
            )

        if self.stride < 1 or self.stride > self.window_size:
            raise ValueError(
                "The stride must be at least 1 and at most the window size."
            )

        # window_size <= n_timepoints is guaranteed above, so floor division
        # gives the same count as truncating the float quotient.
        n_windows = (n_timepoints - self.window_size) // self.stride + 1
        if n_windows < self.n_neighbors:
            raise ValueError(
                f"Window count ({n_windows}) "
                f"has to be larger than n_neighbors ({self.n_neighbors}). "
                "Please choose a smaller n_neighbors value or increase window count "
                "by choosing a smaller window size or smaller stride."
            )

    def _inner_fit(self, X: np.ndarray) -> None:
        """Fit ROCKET, the optional power transformer and the k-NN ensemble.

        Parameters
        ----------
        X : np.ndarray
            Windowed training data, one window per row.
        """
        self.rocket_transformer_ = Rocket(
            n_kernels=self.n_kernels,
            normalise=self.normalise,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
        )
        # X: (n_windows, window_size)
        Xt = self.rocket_transformer_.fit_transform(X)
        # Xt: (n_cases, n_kernels*2)
        Xt = Xt.astype(np.float64)

        if self.power_transform:
            self.power_transformer_ = PowerTransformer()
            try:
                Xtp = self.power_transformer_.fit_transform(Xt)

            except Exception:
                # Deliberate best-effort: fall back to the untransformed
                # features and record that by clearing the transformer, so
                # that prediction skips the transform as well.
                warnings.warn(
                    "Power Transform failed and thus has been disabled. "
                    "Try increasing the window size.",
                    UserWarning,
                    stacklevel=2,
                )
                self.power_transformer_ = None
                Xtp = Xt
        else:
            Xtp = Xt

        self.list_baggers_ = []

        for idx_estimator in range(self.n_estimators):
            # NOTE(review): kd_tree is hard-coded while ``metric`` is user
            # supplied; kd_tree accepts only a subset of sklearn metrics -
            # confirm the combination is valid for non-default metrics.
            estimator = NearestNeighbors(
                n_neighbors=self.n_neighbors,
                n_jobs=self.n_jobs,
                metric=self.metric,
                algorithm="kd_tree",
            )

            # Give every estimator its own seed; a ``None`` random_state cannot
            # be offset, so it is passed through unchanged.
            seed = (
                None
                if self.random_state is None
                else self.random_state + idx_estimator
            )

            # Bootstrap Aggregation
            Xtp_scaled_sample = resample(
                Xtp,
                replace=True,
                n_samples=None,
                random_state=seed,
                stratify=None,
            )

            # Fit estimator and append to estimator list
            estimator.fit(Xtp_scaled_sample)
            self.list_baggers_.append(estimator)

    def _predict(self, X) -> np.ndarray:
        """Window ``X`` and return one normalised anomaly score per time point.

        Parameters
        ----------
        X : np.ndarray
            Series to score, with time along axis 0.

        Returns
        -------
        np.ndarray
            Min-max scaled point-wise anomaly scores.
        """
        _X, padding = sliding_windows(
            X, window_size=self.window_size, stride=self.stride, axis=0
        )

        point_anomaly_scores = self._inner_predict(_X, padding)

        return point_anomaly_scores

    def _fit_predict(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
        """Fit on ``X`` and score the same series, windowing it only once.

        Parameters
        ----------
        X : np.ndarray
            Series to fit on and score, with time along axis 0.
        y : np.ndarray, optional
            Unused by this method.

        Returns
        -------
        np.ndarray
            Min-max scaled point-wise anomaly scores.
        """
        self._check_params(X)
        _X, padding = sliding_windows(
            X, window_size=self.window_size, stride=self.stride, axis=0
        )

        self._inner_fit(_X)
        point_anomaly_scores = self._inner_predict(_X, padding)
        return point_anomaly_scores

    def _inner_predict(self, X: np.ndarray, padding: int) -> np.ndarray:
        """Score windows, map the scores back to points and min-max scale them.

        Parameters
        ----------
        X : np.ndarray
            Windowed data, one window per row.
        padding : int
            Padding value returned by ``sliding_windows`` for this data.

        Returns
        -------
        np.ndarray
            Point-wise scores scaled to ``[0, 1]``. When all point scores are
            identical, an array of zeros is returned.
        """
        anomaly_scores = self._predict_proba(X)

        point_anomaly_scores = reverse_windowing(
            anomaly_scores, self.window_size, np.nanmean, self.stride, padding
        )

        score_min = point_anomaly_scores.min()
        score_range = point_anomaly_scores.max() - score_min

        # Identical scores would make the scaling 0/0 (NaN everywhere); no
        # point stands out in that case, so report a score of 0 for all.
        if score_range == 0:
            return np.zeros_like(point_anomaly_scores)

        return (point_anomaly_scores - score_min) / score_range

    def _predict_proba(self, X):
        """
        Compute the raw anomaly score of each window.

        The score of a window is its mean distance to its ``n_neighbors``
        nearest neighbours in the ROCKET feature space, averaged over all
        bagged estimators. The values are unnormalised distances, not
        probabilities.

        Parameters
        ----------
        X : np.ndarray
            Windowed data, one window per row.

        Returns
        -------
        np.ndarray
            One score per window, shape ``(len(X),)``.
        """
        y_scores = np.zeros((len(X), self.n_estimators))
        # Transform into rocket feature space
        Xt = self.rocket_transformer_.transform(X)

        Xt = Xt.astype(np.float64)

        if self.power_transformer_ is not None:
            # Power Transform using yeo-johnson
            Xtp = self.power_transformer_.transform(Xt)

        else:
            Xtp = Xt

        for idx, bagger in enumerate(self.list_baggers_):
            # Get scores from each estimator
            distances, _ = bagger.kneighbors(Xtp)

            # Mean distance to the nearest neighbours, one value per window
            y_scores[:, idx] = distances.mean(axis=1)

        # Average the scores across estimators to get the final window scores
        return y_scores.mean(axis=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Tests for the ROCKAD anomaly detector.""" | ||
|
||
import numpy as np | ||
import pytest | ||
from sklearn.utils import check_random_state | ||
|
||
from aeon.anomaly_detection import ROCKAD | ||
|
||
|
||
def test_rockad_univariate():
    """Check shape, dtype and peak position of ROCKAD scores on a 1D series."""
    random_gen = check_random_state(seed=2)
    data = random_gen.normal(size=(100,))
    # Inject a level shift that the detector is expected to flag.
    data[50:58] -= 5

    detector_kwargs = {
        "n_estimators": 100,
        "n_kernels": 10,
        "n_neighbors": 9,
        "power_transform": True,
        "window_size": 20,
        "stride": 1,
    }
    detector = ROCKAD(**detector_kwargs)

    scores = detector.fit_predict(data, axis=0)

    assert scores.shape == (100,)
    assert scores.dtype == np.float64
    # The highest score must fall on the injected anomaly.
    peak = np.argmax(scores)
    assert peak >= 50 and peak <= 58
|
||
|
||
def test_rockad_multivariate():
    """Check shape, dtype and peak position of ROCKAD scores on a 3-channel series."""
    random_gen = check_random_state(seed=2)
    data = random_gen.normal(size=(100, 3))
    # Strong level shift on the first channel, which the peak must land on.
    data[50:58, 0] -= 5
    # Much smaller offset on the second channel.
    data[87:90, 1] += 0.1

    detector_kwargs = {
        "n_estimators": 1000,
        "n_kernels": 100,
        "n_neighbors": 20,
        "power_transform": True,
        "window_size": 10,
        "stride": 1,
    }
    detector = ROCKAD(**detector_kwargs)

    scores = detector.fit_predict(data, axis=0)

    assert scores.shape == (100,)
    assert scores.dtype == np.float64
    peak = np.argmax(scores)
    assert peak >= 50 and peak <= 58
|
||
|
||
def test_rockad_incorrect_input():
    """Check the errors and warning ROCKAD emits for unsuitable parameters."""
    random_gen = check_random_state(seed=2)
    data = random_gen.normal(size=(100,))

    # A window size below 1 is rejected.
    with pytest.raises(ValueError, match="The window size must be at least 1"):
        ROCKAD(window_size=0).fit_predict(data)

    # A stride below 1 is rejected.
    with pytest.raises(ValueError, match="The stride must be at least 1"):
        ROCKAD(stride=0).fit_predict(data)

    # A window as long as the series leaves fewer windows than neighbours.
    too_few_windows = r"Window count .* has to be larger than n_neighbors .*"
    with pytest.raises(ValueError, match=too_few_windows):
        ROCKAD(stride=1, window_size=100).fit_predict(data)

    # With this small window the power transform is expected to fail and
    # be disabled with a warning.
    transform_disabled = r"Power Transform failed and thus has been disabled."
    with pytest.warns(UserWarning, match=transform_disabled):
        ROCKAD(stride=1, window_size=5).fit_predict(data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ Detectors | |
MERLIN | ||
OneClassSVM | ||
PyODAdapter | ||
ROCKAD | ||
STOMP | ||
STRAY | ||
|
||
|