From 8fd25322b5a23530eb9798fecd66ac9b8e0ee10c Mon Sep 17 00:00:00 2001 From: almaz Date: Mon, 1 Apr 2024 13:11:08 +0200 Subject: [PATCH] test early stop --- train/src/main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/train/src/main.py b/train/src/main.py index 810160b..1f6de20 100644 --- a/train/src/main.py +++ b/train/src/main.py @@ -1302,9 +1302,15 @@ def train_batch_watcher_disable(): threading.Thread(target=train_batch_watcher_func, daemon=True).start() - def stop_on_batch_end_if_needed(*args, **kwargs): - if app.is_stopped(): + def stop_on_batch_end_if_needed(trainer_validator, *args, **kwargs): + try: + app_is_stopped = app.is_stopped() + except: + app_is_stopped = True + if app_is_stopped: + trainer_validator.stop = True raise app.StopException("This error is expected.") + model.add_callback("on_train_batch_end", stop_on_batch_end_if_needed) model.add_callback("on_val_batch_end", stop_on_batch_end_if_needed)