Skip to content

Commit

Permalink
Added api for the retraining of the bot and added test cases for the …
Browse files Browse the repository at this point in the history
…same.
  • Loading branch information
Mahesh committed Nov 8, 2023
1 parent 6a60039 commit 0e9357f
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 0 deletions.
14 changes: 14 additions & 0 deletions kairon/api/app/routers/bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,20 @@ async def train(
return {"message": "Model training started."}


@router.post("/retrain", response_model=Response)
async def retrain(
current_user: User = Security(Authentication.get_current_user_and_bot, scopes=DESIGNER_ACCESS),
):
"""
Retrains the chatbot
"""
ModelProcessor.handle_current_model_training(current_user.get_bot(), current_user.get_user())
event = ModelTrainingEvent(current_user.get_bot(), current_user.get_user())
event.validate()
event.enqueue()
return {"message": "Model training started."}


@router.get("/model/reload", response_model=Response)
async def reload_model(
current_user: User = Security(Authentication.get_current_user_and_bot, scopes=TESTER_ACCESS),
Expand Down
1 change: 1 addition & 0 deletions kairon/shared/data/data_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,7 @@ class ModelTraining(Auditlog):
model_path = StringField(default=None)
exception = StringField(default=None)
model_config = DictField()
task_info = DictField()

meta = {"indexes": [{"fields": ["bot", ("bot", "status", "-start_timestamp")]}]}

Expand Down
21 changes: 21 additions & 0 deletions kairon/shared/data/model_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ def set_training_status(
doc.exception = exception
doc.save()

@staticmethod
def handle_current_model_training(bot: Text, user: Text):
"""
checks if there is any bot training in progress or enqueued
:param bot: bot id
:param user: user id
:return: None
:raises: AppException
"""
model_training = ModelTraining.objects(bot=bot).filter(
Q(status__ne=EVENT_STATUS.DONE.value) & Q(status__ne=EVENT_STATUS.FAIL.value))

if not model_training:
raise AppException("No Enqueued model training present for this bot.")
model_training_object = model_training.get()
if model_training_object.status == EVENT_STATUS.INPROGRESS.value:
raise AppException("Previous model training in progress. Please wait for sometime.")
if model_training_object.status == EVENT_STATUS.ENQUEUED.value:
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.FAIL.value)

@staticmethod
def is_training_inprogress(bot: Text, raise_exception=True):
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/integration_test/services_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11898,6 +11898,33 @@ def test_add_training_example_case_insensitivity():
assert actual["error_code"] == 0


@responses.activate
@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True)
def test_retrain(mock_training_limit):

mock_training_limit.return_value = False
event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}")
responses.add(
"POST", event_url, json={"success": True, "message": "Event triggered successfully!"}
)

client.post(
f"/api/bot/{pytest.bot}/train",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)

response = client.post(
f"/api/bot/{pytest.bot}/retrain",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)
actual = response.json()
assert actual["success"]
assert actual["error_code"] == 0
assert actual["data"] is None
assert actual["message"] == "Model training started."
complete_end_to_end_event_execution(pytest.bot, "[email protected]", EventClass.model_training)


def test_add_utterances_case_insensitivity():
response = client.post(
f"/api/bot/{pytest.bot}/utterance",
Expand Down
21 changes: 21 additions & 0 deletions tests/unit_test/data_processor/data_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,27 @@ def test_edit_bot_settings(self):
assert updated_settings.analytics.to_mongo().to_dict() == {'fallback_intent': 'utter_please_rephrase'}
assert updated_settings.llm_settings.to_mongo().to_dict() == {'enable_faq': False, 'provider': 'openai'}

def test_handle_current_model_training(self):
bot = "test_bot"
user = "test_user"
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.ENQUEUED.value)
ModelProcessor.handle_current_model_training(bot=bot, user=user)
model_training_object = ModelTraining.objects(bot=bot).get()
assert model_training_object.status == EVENT_STATUS.FAIL.value

def test_handle_current_model_training_with_no_enqueued_model_training(self):
bot = "test_bot"
user = "test_user"
with pytest.raises(AppException, match="No Enqueued model training present for this bot."):
ModelProcessor.handle_current_model_training(bot=bot, user=user)

def test_handle_current_model_training_with_in_progress_model_training(self):
bot = "test_bot"
user = "test_user"
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.INPROGRESS.value)
with pytest.raises(AppException, match="Previous model training in progress. Please wait for sometime."):
ModelProcessor.handle_current_model_training(bot=bot, user=user)

@pytest.mark.asyncio
async def test_save_from_path_yml(self):
processor = MongoProcessor()
Expand Down

0 comments on commit 0e9357f

Please sign in to comment.