Skip to content

Commit

Permalink
Implement river drift detector
Browse files Browse the repository at this point in the history
  • Loading branch information
ti1uan committed Feb 17, 2024
1 parent 2df6469 commit 313b138
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
35 changes: 35 additions & 0 deletions src/drift_detector/river_drift_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from drift_detector.base_drift_detector import BaseDriftDetector
from river import drift
from numpy import array, mean
from typing import Callable


class RiverDriftDetector(BaseDriftDetector):
def __init__(
self,
drift_detect_algo: str = 'ADWIN',
agg_func: Callable[[array], float] = lambda x: mean(x)
) -> None:
super().__init__()
if drift_detect_algo == 'ADWIN':
self.drift_detector = drift.ADWIN()
else:
raise ValueError(f"Support for algorithm {drift_detect_algo} not implemented yet")

if not callable(agg_func):
raise TypeError("Aggregation function must be a callable.")
self.agg_func = agg_func

def is_drifted(self, feat_vec: array) -> bool:
"""
Check if the given feature vector indicates drift.
Parameters:
feat_vec (array): The feature vector to be checked for drift.
Returns:
bool: True if the feature vector indicates drift, False otherwise.
"""
val = self.agg_func(feat_vec)
self.drift_detector.update(val)
return self.drift_detector.drift_detected
Empty file removed tests/drift_detector/.gitkeep
Empty file.
27 changes: 27 additions & 0 deletions tests/drift_detector/test_river_drift_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from unittest.mock import patch
from numpy import mean, array
from drift_detector.river_drift_detector import RiverDriftDetector

def test_initialization():
with patch('river.drift.ADWIN') as MockADWIN:
RiverDriftDetector()
MockADWIN.assert_called_once()

def test_agg_func_usage():
test_data = array([1, 2, 3, 4, 5])
custom_agg_func = lambda x: sum(x) / len(x) # Same as mean
detector = RiverDriftDetector(agg_func=custom_agg_func)
assert detector.agg_func(test_data) == mean(test_data), "Aggregation function doesn't work as expected."

@patch('river.drift.ADWIN')
def test_is_drifted(MockADWIN):
mock_adwin_instance = MockADWIN.return_value
mock_adwin_instance.drift_detected = False
detector = RiverDriftDetector()

# No drift
assert not detector.is_drifted(array([1, 2, 3])), "Shouldn't detect drift."

# Drift exists
mock_adwin_instance.drift_detected = True
assert detector.is_drifted(array([4, 5, 6])), "Should detect drift."

0 comments on commit 313b138

Please sign in to comment.