diff --git a/kairon/shared/data/model_processor.py b/kairon/shared/data/model_processor.py index 5a4898e69..3604d9aa3 100644 --- a/kairon/shared/data/model_processor.py +++ b/kairon/shared/data/model_processor.py @@ -71,7 +71,7 @@ def handle_current_model_training(bot: Text, user: Text): raise AppException("No Enqueued model training present for this bot.") model_training_object = model_training.get() if model_training_object.status == EVENT_STATUS.INPROGRESS.value: - raise AppException("Previous model training in progress. Please wait for sometime.") + raise AppException("Previous model training in progress.") if model_training_object.status == EVENT_STATUS.ENQUEUED.value: ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.FAIL.value) diff --git a/tests/integration_test/services_test.py b/tests/integration_test/services_test.py index 06a61c29b..413bc0965 100644 --- a/tests/integration_test/services_test.py +++ b/tests/integration_test/services_test.py @@ -11900,8 +11900,53 @@ def test_add_training_example_case_insensitivity(): @responses.activate @patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True) -def test_retrain(mock_training_limit): +def test_retrain_with_no_enqueued_model_training(mock_training_limit): + mock_training_limit.return_value = False + event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}") + responses.add( + "POST", event_url, json={"success": True, "message": "Event triggered successfully!"} + ) + response = client.post( + f"/api/bot/{pytest.bot}/retrain", + headers={"Authorization": pytest.token_type + " " + pytest.access_token}, + ) + actual = response.json() + print(actual) + assert not actual["success"] + assert actual["error_code"] == 422 + assert actual["data"] is None + assert actual["message"] == "No Enqueued model training present for this bot." + + +@responses.activate +@patch("kairon.shared.data.model_processor.ModelProcessor.handle_current_model_training", autospec=True) +@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True) +def test_retrain_with_model_training_in_progress(mock_training_limit, mock_handle_current_model_training): + mock_handle_current_model_training.side_effect = AppException("Previous model training in progress.") + mock_training_limit.return_value = False + event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}") + responses.add( + "POST", event_url, json={"success": True, "message": "Event triggered successfully!"} + ) + response = client.post( + f"/api/bot/{pytest.bot}/train", + headers={"Authorization": pytest.token_type + " " + pytest.access_token}, + ) + response = client.post( + f"/api/bot/{pytest.bot}/retrain", + headers={"Authorization": pytest.token_type + " " + pytest.access_token}, + ) + actual = response.json() + print(actual) + assert not actual["success"] + assert actual["error_code"] == 422 + assert actual["data"] is None + assert actual["message"] == "Previous model training in progress." + +@responses.activate +@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True) +def test_retrain(mock_training_limit): mock_training_limit.return_value = False event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}") responses.add( diff --git a/tests/unit_test/data_processor/data_processor_test.py b/tests/unit_test/data_processor/data_processor_test.py index 520df81f7..dbff80a82 100644 --- a/tests/unit_test/data_processor/data_processor_test.py +++ b/tests/unit_test/data_processor/data_processor_test.py @@ -980,7 +980,7 @@ def test_handle_current_model_training_with_in_progress_model_training(self): bot = "test_bot" user = "test_user" ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.INPROGRESS.value) - with pytest.raises(AppException, match="Previous model training in progress. Please wait for sometime."): + with pytest.raises(AppException, match="Previous model training in progress."): ModelProcessor.handle_current_model_training(bot=bot, user=user) @pytest.mark.asyncio