Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Included RSAST as a alternative to SAST (2.0) #1383

Merged
merged 39 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
fc2f824
included rsast
nirojasva Mar 26, 2024
deeddbf
updated transformer and classifier
nirojasva Mar 29, 2024
2c0a41d
included transformer and classifier
nirojasva Apr 1, 2024
ddd830a
updated rsast tranformer
nirojasva Apr 1, 2024
73d8eac
deleted example
nirojasva Apr 7, 2024
cf6582d
updated init
nirojasva Apr 7, 2024
f5eab2d
included LearningShapeletClassifier
nirojasva Apr 7, 2024
85a8dea
Merge branch 'main' of https://github.com/nirojasva/aeon
nirojasva Apr 7, 2024
2ec5ab3
updated format comments
nirojasva Apr 8, 2024
c08038b
corrected spaces
nirojasva Apr 8, 2024
8728d92
corrected identation
nirojasva Apr 8, 2024
1dc7903
updated identation
nirojasva Apr 8, 2024
d9dfda5
updated identation
nirojasva Apr 8, 2024
0c21b5f
corrected identation
nirojasva Apr 8, 2024
54dafa8
updated identation
nirojasva Apr 8, 2024
a1d1ece
updated identation
nirojasva Apr 8, 2024
7bc3df1
updated identation
nirojasva Apr 8, 2024
f86c465
excluded max acf and max pacf
nirojasva Apr 8, 2024
2d278ec
updated identation
nirojasva Apr 8, 2024
246b0ad
updated identation
nirojasva Apr 8, 2024
ca59b2c
updated identation
nirojasva Apr 8, 2024
52a1d33
updated identation
nirojasva Apr 8, 2024
2f7e3b4
updated identation
nirojasva Apr 8, 2024
395ece6
updated identation
nirojasva Apr 8, 2024
846472f
updated identation
nirojasva Apr 8, 2024
013af53
update packages
nirojasva Apr 8, 2024
b7ad0a7
included tag in transformer
nirojasva Apr 13, 2024
1182a3a
moved the import libraries
nirojasva Apr 13, 2024
e1227fd
included brackets
nirojasva Apr 13, 2024
85bd62c
moved libraries to fit function
nirojasva Apr 13, 2024
5996a92
deleted spaces
nirojasva Apr 13, 2024
c5539f6
updated spaces
nirojasva Apr 13, 2024
b525d88
updated identation
nirojasva Apr 13, 2024
fa80c26
included library statmodel in rsast classifier
nirojasva Apr 14, 2024
a7fe63c
applied in Classifier: doctest: +SKIP
nirojasva Apr 16, 2024
2f5ad63
skip in # doctest: +SKIP
nirojasva Apr 16, 2024
859c4fe
using pre-commit
nirojasva Apr 17, 2024
60e68e2
using pre-commit
nirojasva Apr 17, 2024
f4167b3
updated changes requested for PR
nirojasva Apr 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aeon/classification/shapelet_based/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
"ShapeletTransformClassifier",
"RDSTClassifier",
"SASTClassifier",
"RSASTClassifier",
"LearningShapeletClassifier",
]

from aeon.classification.shapelet_based._ls import LearningShapeletClassifier
from aeon.classification.shapelet_based._mrsqm import MrSQMClassifier
from aeon.classification.shapelet_based._rdst import RDSTClassifier
from aeon.classification.shapelet_based._rsast_classifier import RSASTClassifier
from aeon.classification.shapelet_based._sast_classifier import SASTClassifier
from aeon.classification.shapelet_based._stc import ShapeletTransformClassifier
158 changes: 158 additions & 0 deletions aeon/classification/shapelet_based/_rsast_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Random Scalable and Accurate Subsequence Transform (RSAST).

Pipeline classifier using the RSAST transformer and an sklearn classifier.
"""

__maintainer__ = ["nirojasva"]
__all__ = ["RSASTClassifier"]

import numpy as np
from sklearn.linear_model import RidgeClassifierCV
from sklearn.pipeline import make_pipeline

from aeon.base._base import _clone_estimator
from aeon.classification import BaseClassifier
from aeon.transformations.collection.shapelet_based import RSAST


class RSASTClassifier(BaseClassifier):
"""RSASTClassifier.

Classification pipeline using
Random Scalable and Accurate Subsequence Transform (RSAST) [1]_ transformer
and an sklearn classifier.

Parameters
----------
n_random_points: int default = 10 the number of initial random points to extract
len_method: string default="both" the type of statistical tool used to get the
length of shapelets. "both"=ACF&PACF, "ACF"=ACF, "PACF"=PACF,
"None"=Extract randomly any length from the TS
nb_inst_per_class : int default = 10
the number of reference time series to select per class
seed : int, default = None
the seed of the random generator
classifier : sklearn compatible classifier, default = None
if None, a RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)) is used.
n_jobs : int, default -1
Number of threads to use for the transform.


Reference
---------
.. [1] Varela, N. R., Mbouopda, M. F., & Nguifo, E. M. (2023). RSAST: Sampling
Shapelets for Time Series Classification.
https://hal.science/hal-04311309/

Examples
--------
>>> from aeon.classification.shapelet_based import RSASTClassifier
>>> from aeon.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(split="train")
>>> X_test, y_test = load_unit_test(split="test")
>>> clf = RSASTClassifier() # doctest: +SKIP
>>> clf.fit(X_train, y_train) # doctest: +SKIP
RSASTClassifier(...)
>>> y_pred = clf.predict(X_test) # doctest: +SKIP
"""

_tags = {
"capability:multithreading": True,
"capability:multivariate": False,
"algorithm_type": "shapelet",
"python_dependencies": "statsmodels",
}

def __init__(
self,
n_random_points=10,
len_method="both",
nb_inst_per_class=10,
seed=None,
classifier=None,
n_jobs=-1,
):
super().__init__()
self.n_random_points = n_random_points
self.len_method = len_method
self.nb_inst_per_class = nb_inst_per_class
self.n_jobs = n_jobs
self.seed = seed
self.classifier = classifier

def _fit(self, X, y):
"""Fit RSASTClassifier to the training data.

Parameters
----------
X: np.ndarray shape (n_cases, n_channels, n_timepoints)
The training input samples.
y: array-like or list
The class values for X.

Return
------
self : RSASTClassifier
This pipeline classifier

"""
self._transformer = RSAST(
self.n_random_points,
self.len_method,
self.nb_inst_per_class,
self.seed,
self.n_jobs,
)

self._classifier = _clone_estimator(
(
RidgeClassifierCV(alphas=np.logspace(-3, 3, 10))
if self.classifier is None
else self.classifier
),
self.seed,
)

self._pipeline = make_pipeline(self._transformer, self._classifier)

self._pipeline.fit(X, y)

return self

def _predict(self, X):
"""Predict labels for the input.

Parameters
----------
X: np.ndarray shape (n_cases, n_channels, n_timepoints)
The training input samples.

Return
------
array-like or list
Predicted class labels.
"""
return self._pipeline.predict(X)

def _predict_proba(self, X):
"""Predict labels probabilities for the input.

Parameters
----------
X: np.ndarray shape (n_cases, n_channels, n_timepoints)
The training input samples.

Return
------
dists : np.ndarray shape (n_cases, n_timepoints)
Predicted class probabilities.
"""
m = getattr(self._classifier, "predict_proba", None)
if callable(m):
dists = self._pipeline.predict_proba(X)
else:
dists = np.zeros((X.shape[0], self.n_classes_))
preds = self._pipeline.predict(X)
for i in range(0, X.shape[0]):
dists[i, np.where(self.classes_ == preds[i])] = 1
return dists
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Shapelet based transformers."""

__all__ = ["RandomShapeletTransform", "RandomDilatedShapeletTransform", "SAST"]
__all__ = ["RandomShapeletTransform", "RandomDilatedShapeletTransform", "SAST", "RSAST"]

from aeon.transformations.collection.shapelet_based._dilated_shapelet_transform import (
RandomDilatedShapeletTransform,
)
from aeon.transformations.collection.shapelet_based._rsast import RSAST
from aeon.transformations.collection.shapelet_based._sast import SAST
from aeon.transformations.collection.shapelet_based._shapelet_transform import (
RandomShapeletTransform,
Expand Down
Loading