Skip to content

Commit

Permalink
Added api for the retraining of the bot and added test cases for the …
Browse files Browse the repository at this point in the history
…same.
  • Loading branch information
Mahesh committed Nov 9, 2023
1 parent 0e9357f commit a0cde53
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
2 changes: 1 addition & 1 deletion kairon/shared/data/model_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def handle_current_model_training(bot: Text, user: Text):
raise AppException("No Enqueued model training present for this bot.")
model_training_object = model_training.get()
if model_training_object.status == EVENT_STATUS.INPROGRESS.value:
raise AppException("Previous model training in progress. Please wait for sometime.")
raise AppException("Previous model training in progress.")
if model_training_object.status == EVENT_STATUS.ENQUEUED.value:
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.FAIL.value)

Expand Down
47 changes: 46 additions & 1 deletion tests/integration_test/services_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11900,8 +11900,53 @@ def test_add_training_example_case_insensitivity():

@responses.activate
@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True)
def test_retrain(mock_training_limit):
def test_retrain_with_no_enqueued_model_training(mock_training_limit):
mock_training_limit.return_value = False
event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}")
responses.add(
"POST", event_url, json={"success": True, "message": "Event triggered successfully!"}
)
response = client.post(
f"/api/bot/{pytest.bot}/retrain",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)
actual = response.json()
print(actual)
assert not actual["success"]
assert actual["error_code"] == 422
assert actual["data"] is None
assert actual["message"] == "No Enqueued model training present for this bot."


@responses.activate
@patch("kairon.shared.data.model_processor.ModelProcessor.handle_current_model_training", autospec=True)
@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True)
def test_retrain_with_model_training_in_progress(mock_training_limit, mock_handle_current_model_training):
mock_handle_current_model_training.side_effect = AppException("Previous model training in progress.")
mock_training_limit.return_value = False
event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}")
responses.add(
"POST", event_url, json={"success": True, "message": "Event triggered successfully!"}
)
response = client.post(
f"/api/bot/{pytest.bot}/train",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)
response = client.post(
f"/api/bot/{pytest.bot}/retrain",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)
actual = response.json()
print(actual)
assert not actual["success"]
assert actual["error_code"] == 422
assert actual["data"] is None
assert actual["message"] == "Previous model training in progress."


@responses.activate
@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True)
def test_retrain(mock_training_limit):
mock_training_limit.return_value = False
event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}")
responses.add(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_test/data_processor/data_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ def test_handle_current_model_training_with_in_progress_model_training(self):
bot = "test_bot"
user = "test_user"
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.INPROGRESS.value)
with pytest.raises(AppException, match="Previous model training in progress. Please wait for sometime."):
with pytest.raises(AppException, match="Previous model training in progress."):
ModelProcessor.handle_current_model_training(bot=bot, user=user)

@pytest.mark.asyncio
Expand Down

0 comments on commit a0cde53

Please sign in to comment.