diff --git a/train/src/main.py b/train/src/main.py index 9fffacd..f8c4283 100644 --- a/train/src/main.py +++ b/train/src/main.py @@ -1147,11 +1147,12 @@ def train_batch_watcher_disable(): threading.Thread(target=train_batch_watcher_func, daemon=True).start() - def stop_on_train_batch_end_if_needed(*args, **kwargs): + def stop_on_batch_end_if_needed(*args, **kwargs): if app.app_is_stopped(): raise app.StopAppError("This error is expected.") - model.add_callback("on_train_batch_end", stop_on_train_batch_end_if_needed) + model.add_callback("on_train_batch_end", stop_on_batch_end_if_needed) + model.add_callback("on_val_batch_end", stop_on_batch_end_if_needed) with app.run_with_stop_app_error_suppression(): model.train(