[ENH] New experimental module: imbalance in collection transformers #2498

Draft: wants to merge 9 commits into base: main
1 change: 1 addition & 0 deletions README.md
@@ -26,6 +26,7 @@ does not apply:
- `segmentation`
- `similarity_search`
- `visualisation`
- `transformations.collection.imbalance`

| Overview | |
|----------|---|
6 changes: 6 additions & 0 deletions aeon/transformations/collection/imbalance/__init__.py
@@ -0,0 +1,6 @@
"""Supervised transformers to rebalance colelctions of time series."""

__all__ = ["ADASYN", "SMOTE"]

from aeon.transformations.collection.imbalance._adasyn import ADASYN
from aeon.transformations.collection.imbalance._smote import SMOTE
87 changes: 87 additions & 0 deletions aeon/transformations/collection/imbalance/_adasyn.py
@@ -0,0 +1,87 @@
"""ADASYN over sampling algorithm."""

import numpy as np
from sklearn.utils import check_random_state

from aeon.transformations.collection.imbalance._smote import SMOTE

__maintainer__ = ["TonyBagnall"]
__all__ = ["ADASYN"]


class ADASYN(SMOTE):
"""
Over-sampling using Adaptive Synthetic Sampling (ADASYN).

Adaptation of imblearn.over_sampling.ADASYN
original authors:
# Guillaume Lemaitre <[email protected]>
# Christos Aridas
# License: MIT

This transformer extends SMOTE, but it generates a different number of
samples for each case, depending on an estimate of the local distribution
of the class to be oversampled.
"""

def __init__(self, random_state=None, k_neighbors=5):
super().__init__(random_state=random_state, k_neighbors=k_neighbors)

def _transform(self, X, y=None):
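# remove the channel dimension to be compatible with sklearn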
X = np.squeeze(X, axis=1)
random_state = check_random_state(self.random_state)
X_resampled = [X.copy()]
y_resampled = [y.copy()]

# for each minority class, generate the required number of new samples
for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
continue
target_class_indices = np.flatnonzero(y == class_sample)
X_class = X[target_class_indices]

self.nn_.fit(X)
nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
# The ratio is computed in a one-vs-rest manner. Using the majority
# class in the multi-class case would lead to slightly different
# results at the cost of introducing a new parameter.
n_neighbors = self.nn_.n_neighbors - 1
ratio_nn = np.sum(y[nns] != class_sample, axis=1) / n_neighbors
if not np.sum(ratio_nn):
raise RuntimeError(
"Not any neigbours belong to the majority"
" class. This case will induce a NaN case"
" with a division by zero. ADASYN is not"
" suited for this specific dataset."
" Use SMOTE instead."
)
ratio_nn /= np.sum(ratio_nn)
n_samples_generate = np.rint(ratio_nn * n_samples).astype(int)
# rounding may change the total number of samples to generate
n_samples = np.sum(n_samples_generate)
if not n_samples:
raise ValueError(
"No samples will be generated with the provided ratio settings."
)

# the nearest neighbors need to be fitted only on the current class
# to find the class NN to generate new samples
self.nn_.fit(X_class)
nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]

enumerated_class_indices = np.arange(len(target_class_indices))
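# repeat each case index once per synthetic series it must generate,
# then draw a random neighbour column for each new series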
rows = np.repeat(enumerated_class_indices, n_samples_generate)
cols = random_state.choice(n_neighbors, size=n_samples)
diffs = X_class[nns[rows, cols]] - X_class[rows]
steps = random_state.uniform(size=(n_samples, 1))
X_new = X_class[rows] + steps * diffs

X_new = X_new.astype(X.dtype)
y_new = np.full(n_samples, fill_value=class_sample, dtype=y.dtype)
X_resampled.append(X_new)
y_resampled.append(y_new)
X_resampled = np.vstack(X_resampled)
y_resampled = np.hstack(y_resampled)

X_resampled = X_resampled[:, np.newaxis, :]
return X_resampled, y_resampled
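As a usage illustration for review, here is a minimal sketch of the ported ADASYN transformer; the collection shape and the 90/10 class split below are made up for the example, not taken from the PR:

import numpy as np

from aeon.transformations.collection.imbalance import ADASYN

X = np.random.rand(100, 1, 50)  # (n_cases, n_channels, n_timepoints)
y = np.array([0] * 90 + [1] * 10)  # 90/10 imbalanced labels

ada = ADASYN(random_state=0, k_neighbors=5)
X_res, y_res = ada.fit_transform(X, y)
# ADASYN rounds the per-case ratios, so the minority count is only
# approximately topped up to the majority count of 90
print(X_res.shape, np.unique(y_res, return_counts=True))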
222 changes: 222 additions & 0 deletions aeon/transformations/collection/imbalance/_smote.py
@@ -0,0 +1,222 @@
"""SMOTE over sampling algorithm.

See more in imblearn.over_sampling.SMOTE
original authors:
# Guillaume Lemaitre <[email protected]>
# Fernando Nogueira
# Christos Aridas
# Dzianis Dudnik
# License: MIT
"""

from collections import OrderedDict

import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_random_state

from aeon.transformations.collection import BaseCollectionTransformer

__maintainer__ = ["TonyBagnall"]
__all__ = ["SMOTE"]


class SMOTE(BaseCollectionTransformer):
"""
Over-sampling using the Synthetic Minority Over-sampling TEchnique (SMOTE)[1]_.

An adaptation of the imbalanced-learn implementation of SMOTE in
imblearn.over_sampling.SMOTE. The sampling strategy is to oversample all
classes except the majority, as recorded directly in the
sampling_strategy_ attribute set in _fit.

Parameters
----------
k_neighbors : int, default=5
The number of nearest neighbors used to define the neighborhood of samples
from which the synthetic time series are generated. A
`~sklearn.neighbors.NearestNeighbors` instance is fitted for this purpose.
random_state : int, RandomState instance or None, default=None
If `int`, random_state is the seed used by the random number generator;
If `RandomState` instance, random_state is the random number generator;
If `None`, the random number generator is the `RandomState` instance used
by `np.random`.

See Also
--------
ADASYN

References
----------
.. [1] Chawla et al. SMOTE: synthetic minority over-sampling technique, Journal
of Artificial Intelligence Research 16(1): 321–357, 2002.
https://dl.acm.org/doi/10.5555/1622407.1622416
"""

_tags = {
"capability:multivariate": False,
"capability:unequal_length": False,
"requires_y": True,
}

def __init__(self, k_neighbors=5, random_state=None):
self.random_state = random_state
self.k_neighbors = k_neighbors
super().__init__()

def _fit(self, X, y=None):
# request one extra neighbour, since kneighbors returns each sample
# as its own nearest neighbour
self.nn_ = NearestNeighbors(n_neighbors=self.k_neighbors + 1)

# generate sampling target by targeting all classes except the majority
unique, counts = np.unique(y, return_counts=True)
target_stats = dict(zip(unique, counts))
n_sample_majority = max(target_stats.values())
class_majority = max(target_stats, key=target_stats.get)
sampling_strategy = {
key: n_sample_majority - value
for (key, value) in target_stats.items()
if key != class_majority
}
self.sampling_strategy_ = OrderedDict(sorted(sampling_strategy.items()))
return self

def _transform(self, X, y=None):
# remove the channel dimension to be compatible with sklearn
X = np.squeeze(X, axis=1)
X_resampled = [X.copy()]
y_resampled = [y.copy()]

# for each minority class, generate the required number of new samples
for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
continue
target_class_indices = np.flatnonzero(y == class_sample)
X_class = X[target_class_indices]

self.nn_.fit(X_class)
nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
X_new, y_new = self._make_samples(
X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0
)
X_resampled.append(X_new)
y_resampled.append(y_new)
X_resampled = np.vstack(X_resampled)
y_resampled = np.hstack(y_resampled)
X_resampled = X_resampled[:, np.newaxis, :]
return X_resampled, y_resampled

def _make_samples(
self, X, y_dtype, y_type, nn_data, nn_num, n_samples, step_size=1.0, y=None
):
"""Make artificial samples constructed based on nearest neighbours.

Parameters
----------
X : np.ndarray
Shape (n_cases, n_timepoints), time series from which the new series will
be created.

y_dtype : dtype
The data type of the targets.

y_type : str or int
The minority class label, used to build the target vector for the
synthetic samples with the correct length and value.

nn_data : ndarray of shape (n_samples_all, n_features)
Data set carrying all the neighbours to be used

nn_num : ndarray of shape (n_samples_all, k_nearest_neighbours)
The nearest neighbours of each sample in `nn_data`.

n_samples : int
The number of samples to generate.

step_size : float, default=1.0
The step size to create samples.

y : ndarray of shape (n_samples_all,), default=None
The true target associated with `nn_data`. Used by Borderline SMOTE-2 to
weight the distances in the sample generation process.

Returns
-------
X_new : {ndarray, sparse matrix} of shape (n_samples_new, n_features)
Synthetically generated samples.

y_new : ndarray of shape (n_samples_new,)
Target values for synthetic samples.
"""
random_state = check_random_state(self.random_state)
samples_indices = random_state.randint(low=0, high=nn_num.size, size=n_samples)

# np.newaxis for backwards compatibility with random_state
steps = step_size * random_state.uniform(size=n_samples)[:, np.newaxis]
rows = np.floor_divide(samples_indices, nn_num.shape[1])
cols = np.mod(samples_indices, nn_num.shape[1])
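# decode each flat index into (row, col): row selects the base sample,
# col selects which of its nearest neighbours to interpolate towards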

X_new = self._generate_samples(X, nn_data, nn_num, rows, cols, steps, y_type, y)
y_new = np.full(n_samples, fill_value=y_type, dtype=y_dtype)
return X_new, y_new

def _generate_samples(
self, X, nn_data, nn_num, rows, cols, steps, y_type=None, y=None
):
r"""Generate a synthetic sample.

The rule for the generation is:

.. math::
\mathbf{s_{s}} = \mathbf{s_{i}} + \mathcal{u}(0, 1) \times
(\mathbf{s_{nn}} - \mathbf{s_{i}}) \,

where \mathbf{s_{s}} is the new synthetic sample, \mathbf{s_{i}} is
the current sample, \mathbf{s_{nn}} is a randomly selected neighbour of
\mathbf{s_{i}} and \mathcal{u}(0, 1) is a random number in the
interval [0, 1).
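For example, with \mathbf{s_{i}} = [0, 0], \mathbf{s_{nn}} = [1, 2] and
\mathcal{u}(0, 1) = 0.5, the synthetic sample is [0.5, 1].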

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Points from which the new samples will be created.

nn_data : ndarray of shape (n_samples_all, n_features)
Data set carrying all the neighbours to be used.

nn_num : ndarray of shape (n_samples_all, k_nearest_neighbours)
The nearest neighbours of each sample in `nn_data`.

rows : ndarray of shape (n_samples,), dtype=int
Indices pointing at feature vector in X which will be used
as a base for creating new samples.

cols : ndarray of shape (n_samples,), dtype=int
Indices pointing at which nearest neighbor of base feature vector
will be used when creating new samples.

steps : ndarray of shape (n_samples,), dtype=float
Step sizes for new samples.

y_type : str, int or None, default=None
Class label of the current target classes for which we want to generate
samples.

y : ndarray of shape (n_samples_all,), default=None
The true target associated with `nn_data`. Used by Borderline SMOTE-2 to
weight the distances in the sample generation process.

Returns
-------
X_new : {ndarray, sparse matrix} of shape (n_samples, n_features)
Synthetically generated samples.
"""
diffs = nn_data[nn_num[rows, cols]] - X[rows]
if y is not None: # only entering for BorderlineSMOTE-2
random_state = check_random_state(self.random_state)
mask_pair_samples = y[nn_num[rows, cols]] != y_type
diffs[mask_pair_samples] *= random_state.uniform(
low=0.0, high=0.5, size=(mask_pair_samples.sum(), 1)
)
X_new = X[rows] + steps * diffs
return X_new.astype(X.dtype)
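A short sketch of the end-to-end behaviour for review: with class counts {0: 90, 1: 10}, _fit sets sampling_strategy_ to {1: 80} (every class except the majority is topped up to the majority count), so transform appends exactly 80 synthetic minority series. The shapes below are illustrative:

import numpy as np

from aeon.transformations.collection.imbalance import SMOTE

X = np.random.rand(100, 1, 50)  # (n_cases, n_channels, n_timepoints)
y = np.array([0] * 90 + [1] * 10)

sm = SMOTE(k_neighbors=5, random_state=0)
X_res, y_res = sm.fit_transform(X, y)
print(np.unique(y_res, return_counts=True))
# expected: (array([0, 1]), array([90, 90]))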
1 change: 1 addition & 0 deletions aeon/transformations/collection/imbalance/tests/__init__.py
@@ -0,0 +1 @@
"""Test resampling transformers."""
61 changes: 61 additions & 0 deletions aeon/transformations/collection/imbalance/tests/test_adasyn.py
@@ -0,0 +1,61 @@
"""Test ADASYN oversampler ported from imblearn."""

import numpy as np
import pytest

from aeon.testing.data_generation import make_example_3d_numpy
from aeon.transformations.collection.imbalance import ADASYN
from aeon.utils.validation._dependencies import _check_soft_dependencies


def test_adasyn():
"""Test the ADASYN class.

This function creates a 3D numpy array, applies
ADASYN using the ADASYN class, and asserts that the
transformed data has a balanced number of samples.
ADASYN is a variant of SMOTE that focuses on generating synthetic
samples near the decision boundary. It may therefore generate more or
fewer samples than SMOTE, which is why we only check that the class
counts are nearly balanced.
"""
n_samples = 100  # total number of samples
majority_num = 90  # number of majority class samples
minority_num = n_samples - majority_num  # number of minority class samples

X = np.random.rand(n_samples, 1, 10)
y = np.array([0] * majority_num + [1] * minority_num)

transformer = ADASYN()
transformer.fit(X, y)
res_X, res_y = transformer.transform(X, y)
_, res_count = np.unique(res_y, return_counts=True)

assert np.abs(len(res_X) - 2 * majority_num) < minority_num
assert np.abs(len(res_y) - 2 * majority_num) < minority_num
assert res_count[0] == majority_num
assert np.abs(res_count[0] - res_count[1]) < minority_num


@pytest.mark.skipif(
not _check_soft_dependencies(
"imbalanced-learn",
package_import_alias={"imbalanced-learn": "imblearn"},
severity="none",
),
reason="skip test if required soft dependency imbalanced-learn not available",
)
def test_equivalence_imbalance():
"""Test ported ADASYN code produces the same as imblearn version."""
from imblearn.over_sampling import ADASYN as imbADASYN

X, y = make_example_3d_numpy(n_cases=20, n_channels=1)
y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
X = X.squeeze()
s1 = imbADASYN(random_state=49)
X2, y2 = s1.fit_resample(X, y)
s2 = ADASYN(random_state=49)
X3, y3 = s2.fit_transform(X, y)
X3 = X3.squeeze()
assert np.array_equal(y2, y3)
assert np.allclose(X2, X3, atol=1e-4)
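To exercise these tests locally, something like the following should work from the repository root; the equivalence test is skipped unless imbalanced-learn is installed:

pytest aeon/transformations/collection/imbalance/tests/test_adasyn.py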