Skip to content

Commit

Permalink
added retrain config
Browse files Browse the repository at this point in the history
  • Loading branch information
hahahannes committed Jun 18, 2024
1 parent a07a4cb commit d3056ce
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 25 deletions.
19 changes: 10 additions & 9 deletions algo/curve_anomaly/cont_det/online_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def __init__(
mlflow_url,
operator_id,
pipeline_id,
retrain_level,
retrain_interval
train_level,
train_interval,
retrain
):
super().__init__(data_path, init_median, first_data_time)
self.ml_trainer_url = ml_trainer_url
Expand All @@ -34,8 +35,9 @@ def __init__(
self.operator_id = operator_id
self.pipeline_id = pipeline_id
self.anomalies = []
self.retrain_level = retrain_level
self.retrain_interval = retrain_interval
self.train_level = train_level
self.train_interval = train_interval
self.retrain = retrain

def check(self, value, timestamp):
if self.first_data_time == None:
Expand Down Expand Up @@ -79,14 +81,13 @@ def check(self, value, timestamp):
return False, ''

def training_shall_start(self, timestamp):
# when there is enough data to train -> this will also trigger a retrain
util.logger.debug(f"Current Time: {timestamp} - Last Train Time: {self.last_training_time}")
if timestamp - self.last_training_time < pd.Timedelta(self.retrain_interval, self.retrain_level):
# Training shall start when there is enough initial data or when retraining is enabled
util.logger.debug(f"Current Time: {timestamp} - Last Train Time: {self.last_training_time} < {self.train_interval}{self.train_level}")
if timestamp - self.last_training_time < pd.Timedelta(self.train_interval, self.train_level):
util.logger.debug("Wait with training until enough data is collected")
return False

# TODO: remove this guard; it only exists to avoid spawning endless jobs
if self.job_id:
if self.job_id and not self.retrain:
return

return True
Expand Down
12 changes: 7 additions & 5 deletions algo/curve_anomaly/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ def create_curve_detector(
device_id,
operator_id,
pipeline_id,
retrain_level,
retrain_interval
):
train_level,
train_interval,
retrain
):

data_path = os.path.join(data_path, "curve_explorer")
if device_type == "cont_device":
Expand All @@ -31,8 +32,9 @@ def create_curve_detector(
mlflow_url,
operator_id,
pipeline_id,
retrain_level,
retrain_interval
train_level,
train_interval,
retrain
)
else:
return OfflineTrainContCurveDetector(data_path, init_median, first_data_time)
Expand Down
15 changes: 9 additions & 6 deletions algo/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def __init__(
curve_detector_training_mode,
operator_id,
pipeline_id,
retrain_level,
retrain_interval
train_level,
train_interval,
retrain
):
self.active_detectors = []
self.device_id = device_id
Expand All @@ -45,8 +46,9 @@ def __init__(
self.data_path = data_path
self.operator_id = operator_id
self.pipeline_id = pipeline_id
self.retrain_level = retrain_level
self.retrain_interval = retrain_interval
self.train_level = train_level
self.train_interval = train_interval
self.retrain = retrain

if check_data_schema:
util.logger.info(f"{LOG_PREFIX}: Data Schema Detector is active")
Expand Down Expand Up @@ -88,8 +90,9 @@ def update_device_type(self, device_type):
self.device_id,
self.operator_id,
self.pipeline_id,
self.retrain_level,
self.retrain_interval
self.train_level,
self.train_interval,
self.retrain
)

self.active_detectors.append(self.Curve_Explorer)
Expand Down
12 changes: 7 additions & 5 deletions algo/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ class CustomConfig(Config):
check_consumption: bool = False
init_phase_length: float = 2
init_phase_level: str = "d"
retrain_interval: float = 14
retrain_level: str = "d"
train_interval: float = 14
train_level: str = "d"
retrain: bool = False
ml_trainer_url: str = "http://ml-trainer-svc.trainer:5000"
mlflow_url: str = "http://mlflow-svc.mlflow:5000"
curve_detector_training_mode: str = "offline"
Expand Down Expand Up @@ -79,7 +80,7 @@ def init(self, *args, **kwargs):
if not os.path.exists(self.config.data_path):
os.mkdir(self.config.data_path)

self.produce = lambda x: print(x) # TODO REMOVE!!!!
# self.produce = lambda x: print(x)  # Uncomment for local testing so that Kafka topics are not polluted when port-forwarding to the cluster

self.init_phase_duration = pd.Timedelta(self.config.init_phase_length, self.config.init_phase_level)
self.operator_start_time = pd.Timestamp(setup_operator_starttime(self.config.data_path)).tz_localize(None)
Expand Down Expand Up @@ -124,8 +125,9 @@ def get_device_detectors(self, input_ids):
self.config.curve_detector_training_mode,
self.get_operator_id(),
self.get_pipeline_id(),
self.config.retrain_level,
self.config.retrain_interval
self.config.train_level,
self.config.train_interval,
self.config.retrain
)
self.device_detectors[input_ids] = device_detector
return device_detector
Expand Down

0 comments on commit d3056ce

Please sign in to comment.