Skip to content

Commit

Permalink
Merge pull request #305 from IFCA-Advanced-Computing/feature-energy-d…
Browse files Browse the repository at this point in the history
…istance

Add Energy distance data drift method
  • Loading branch information
jaime-cespedes-sisniega authored Feb 13, 2024
2 parents 6183197 + 19c87e7 commit 622949f
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 3 deletions.
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ The currently implemented detectors are listed in the following table.
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1007/978-3-540-75488-6_27">Nishida and Yamauchi (2007)</a></td>
</tr>
<tr>
<td rowspan="16" style="text-align: center; border: 1px solid grey; padding: 8px;">Data drift</td>
<td rowspan="14" style="text-align: center; border: 1px solid grey; padding: 8px;">Batch</td>
<td rowspan="9" style="text-align: center; border: 1px solid grey; padding: 8px;">Distance based</td>
<td rowspan="17" style="text-align: center; border: 1px solid grey; padding: 8px;">Data drift</td>
<td rowspan="15" style="text-align: center; border: 1px solid grey; padding: 8px;">Batch</td>
<td rowspan="10" style="text-align: center; border: 1px solid grey; padding: 8px;">Distance based</td>
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>
<td style="text-align: center; border: 1px solid grey; padding: 8px;">Anderson-Darling test</td>
Expand All @@ -355,6 +355,12 @@ The currently implemented detectors are listed in the following table.
<td style="text-align: center; border: 1px solid grey; padding: 8px;">Earth Mover's distance</td>
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1023/A:1026543900054">Rubner et al. (2000)</a></td>
</tr>
<tr>
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>
<td style="text-align: center; border: 1px solid grey; padding: 8px;">Energy distance</td>
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1016/j.jspi.2013.03.018">Székely et al. (2013)</a></td>
</tr>
<tr>
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_reference/detectors/data_drift/batch.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ The {mod}`frouros.detectors.data_drift.batch` module contains batch data drift d
BhattacharyyaDistance
EMD
EnergyDistance
HellingerDistance
HINormalizedComplement
JS
Expand Down
2 changes: 2 additions & 0 deletions frouros/detectors/data_drift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ChiSquareTest,
CVMTest,
EMD,
EnergyDistance,
HellingerDistance,
HINormalizedComplement,
JS,
Expand All @@ -25,6 +26,7 @@
"ChiSquareTest",
"CVMTest",
"EMD",
"EnergyDistance",
"HellingerDistance",
"HINormalizedComplement",
"IncrementalKSTest",
Expand Down
2 changes: 2 additions & 0 deletions frouros/detectors/data_drift/batch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .distance_based import (
BhattacharyyaDistance,
EMD,
EnergyDistance,
HellingerDistance,
HINormalizedComplement,
JS,
Expand All @@ -25,6 +26,7 @@
"ChiSquareTest",
"CVMTest",
"EMD",
"EnergyDistance",
"HellingerDistance",
"HINormalizedComplement",
"JS",
Expand Down
2 changes: 2 additions & 0 deletions frouros/detectors/data_drift/batch/distance_based/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .bhattacharyya_distance import BhattacharyyaDistance
from .emd import EMD
from .energy_distance import EnergyDistance
from .hellinger_distance import HellingerDistance
from .hi_normalized_complement import HINormalizedComplement
from .js import JS
Expand All @@ -12,6 +13,7 @@
__all__ = [
"BhattacharyyaDistance",
"EMD",
"EnergyDistance",
"HellingerDistance",
"HINormalizedComplement",
"JS",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Energy Distance module."""

from typing import Optional, Union

import numpy as np # type: ignore
from scipy.stats import energy_distance # type: ignore

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.detectors.data_drift.base import UnivariateData
from frouros.detectors.data_drift.batch.distance_based.base import (
BaseDistanceBased,
DistanceResult,
)


class EnergyDistance(BaseDistanceBased):
    """EnergyDistance [szekely2013energy]_ detector.

    Distance-based batch detector registered with univariate data that
    delegates the computation to :func:`scipy.stats.energy_distance`.

    :param callbacks: callbacks, defaults to None
    :type callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]]
    :param kwargs: additional keyword arguments to pass to scipy.stats.energy_distance
    :type kwargs: Dict[str, Any]

    :References:

    .. [szekely2013energy] Székely, Gábor J., and Maria L. Rizzo.
        "Energy statistics: A class of statistics based on distances."
        Journal of statistical planning and inference 143.8 (2013): 1249-1272.

    :Example:

    >>> from frouros.detectors.data_drift import EnergyDistance
    >>> import numpy as np
    >>> np.random.seed(seed=31)
    >>> X = np.random.normal(loc=0, scale=1, size=100)
    >>> Y = np.random.normal(loc=1, scale=1, size=100)
    >>> detector = EnergyDistance()
    >>> _ = detector.fit(X=X)
    >>> detector.compare(X=Y)[0]
    DistanceResult(distance=0.8359206395514527)
    """  # noqa: E501

    def __init__(  # noqa: D107
        self,
        callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
        **kwargs,
    ) -> None:
        # The same keyword arguments are handed to the base class and also
        # kept on the instance, where _distance_measure reads them back.
        super().__init__(
            statistical_type=UnivariateData(),
            statistical_method=self._energy_distance,
            statistical_kwargs=kwargs,
            callbacks=callbacks,
        )
        self.kwargs = kwargs

    def _distance_measure(
        self,
        X_ref: np.ndarray,  # noqa: N803
        X: np.ndarray,  # noqa: N803
        **kwargs,
    ) -> DistanceResult:
        """Wrap the energy distance between ``X_ref`` and ``X`` in a result.

        :param X_ref: reference samples
        :type X_ref: numpy.ndarray
        :param X: samples compared against the reference
        :type X: numpy.ndarray
        :param kwargs: accepted but not used; the keyword arguments stored
            at construction time (``self.kwargs``) are forwarded instead
        :return: result holding the computed distance
        :rtype: DistanceResult
        """
        return DistanceResult(
            distance=self._energy_distance(X=X_ref, Y=X, **self.kwargs),
        )

    @staticmethod
    def _energy_distance(
        X: np.ndarray,  # noqa: N803
        Y: np.ndarray,
        **kwargs,
    ) -> float:
        """Compute the energy distance between two samples.

        :param X: first sample, flattened before the computation
        :type X: numpy.ndarray
        :param Y: second sample, flattened before the computation
        :type Y: numpy.ndarray
        :param kwargs: extra keyword arguments forwarded to
            scipy.stats.energy_distance
        :return: value returned by scipy.stats.energy_distance
        :rtype: float
        """
        # Both inputs are collapsed to one dimension before being passed on
        # as u_values / v_values.
        first_sample = X.flatten()
        second_sample = Y.flatten()
        return energy_distance(
            u_values=first_sample,
            v_values=second_sample,
            **kwargs,
        )
2 changes: 2 additions & 0 deletions frouros/tests/integration/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
BhattacharyyaDistance,
CVMTest,
EMD,
EnergyDistance,
HellingerDistance,
HINormalizedComplement,
JS,
Expand All @@ -48,6 +49,7 @@
[
(BhattacharyyaDistance, 0.55516059, 0.0),
(EMD, 3.85346006, 0.0),
(EnergyDistance, 2.11059982, 0.0),
(HellingerDistance, 0.74509099, 0.0),
(HINormalizedComplement, 0.78, 0.0),
(JS, 0.67010107, 0.0),
Expand Down
2 changes: 2 additions & 0 deletions frouros/tests/integration/test_data_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from frouros.detectors.data_drift.batch import (
BhattacharyyaDistance,
EMD,
EnergyDistance,
HellingerDistance,
HINormalizedComplement,
PSI,
Expand Down Expand Up @@ -64,6 +65,7 @@ def test_batch_distance_based_categorical(
"detector, expected_distance",
[
(EMD(), 3.85346006),
(EnergyDistance(), 2.11059982),
(JS(), 0.67010107),
(KL(), np.inf),
(HINormalizedComplement(), 0.78),
Expand Down

0 comments on commit 622949f

Please sign in to comment.