Skip to content

Commit

Permalink
Added api for the aborting the event and added test cases for the same.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mahesh committed Nov 10, 2023
1 parent 28e2df3 commit 521dd7b
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 66 deletions.
17 changes: 8 additions & 9 deletions kairon/api/app/routers/bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from kairon.shared.actions.data_objects import ActionServerLogs
from kairon.shared.auth import Authentication
from kairon.shared.constants import TESTER_ACCESS, DESIGNER_ACCESS, CHAT_ACCESS, UserActivityType, ADMIN_ACCESS, \
VIEW_ACCESS
VIEW_ACCESS, EventClass
from kairon.shared.data.assets_processor import AssetsProcessor
from kairon.shared.data.audit.data_objects import AuditLogData
from kairon.shared.data.constant import EVENT_STATUS, ENDPOINT_TYPE, TOKEN_TYPE, ModelTestType, \
Expand Down Expand Up @@ -511,18 +511,17 @@ async def train(
return {"message": "Model training started."}


@router.post("/retrain", response_model=Response)
async def retrain(
@router.post("/abort/{event_type}", response_model=Response)
async def abort_event(
event_type: EventClass = Path(description="Event type", example=[e.value for e in EventClass]),
current_user: User = Security(Authentication.get_current_user_and_bot, scopes=DESIGNER_ACCESS),
):
"""
Retrains the chatbot
Aborts the event
"""
ModelProcessor.handle_current_model_training(current_user.get_bot(), current_user.get_user())
event = ModelTrainingEvent(current_user.get_bot(), current_user.get_user())
event.validate()
event.enqueue()
return {"message": "Model training started."}
ModelProcessor.abort_current_event(current_user.get_bot(), event_type)

return {"message": f"{event_type} aborted."}


@router.get("/model/reload", response_model=Response)
Expand Down
1 change: 1 addition & 0 deletions kairon/shared/data/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class EVENT_STATUS(str, Enum):
COMPLETED = "Completed"
DONE = "Done"
FAIL = "Fail"
ABORTED = "Aborted"


class ModelTestingLogType(str, Enum):
Expand Down
1 change: 0 additions & 1 deletion kairon/shared/data/data_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,6 @@ class ModelTraining(Auditlog):
model_path = StringField(default=None)
exception = StringField(default=None)
model_config = DictField()
task_info = DictField()

meta = {"indexes": [{"fields": ["bot", ("bot", "status", "-start_timestamp")]}]}

Expand Down
45 changes: 30 additions & 15 deletions kairon/shared/data/model_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
from mongoengine import Q
from mongoengine.errors import DoesNotExist
from kairon.exceptions import AppException
from kairon.shared.utils import Utility
from .constant import EVENT_STATUS
from .data_objects import ModelTraining, BotSettings
from .data_objects import ModelTraining, BotSettings, ConversationsHistoryDeleteLogs, TrainingDataGenerator
from ..chat.broadcast.data_objects import MessageBroadcastLogs
from ..constants import EventClass
from ..importer.data_objects import ValidationLogs
from ..multilingual.data_objects import BotReplicationLogs
from ..test.data_objects import ModelTestingLogs


class ModelProcessor:
Expand Down Expand Up @@ -55,25 +59,36 @@ def set_training_status(
doc.save()

@staticmethod
def handle_current_model_training(bot: Text, user: Text):
def abort_current_event(bot: Text, event_type: EventClass):
"""
checks if there is any bot training in progress or enqueued
sets event status to aborted if there is any event in progress or enqueued
:param bot: bot id
:param user: user id
:param event_type: type of the event
:return: None
:raises: AppException
"""
model_training = ModelTraining.objects(bot=bot).filter(
Q(status__ne=EVENT_STATUS.DONE.value) & Q(status__ne=EVENT_STATUS.FAIL.value))

if not model_training:
raise AppException("No Enqueued model training present for this bot.")
model_training_object = model_training.get()
if model_training_object.status == EVENT_STATUS.INPROGRESS.value:
raise AppException("Previous model training in progress.")
if model_training_object.status == EVENT_STATUS.ENQUEUED.value:
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.FAIL.value)
events_dict = {
EventClass.model_training: ModelTraining,
EventClass.model_testing: ModelTestingLogs,
EventClass.delete_history: ConversationsHistoryDeleteLogs,
EventClass.data_importer: ValidationLogs,
EventClass.multilingual: BotReplicationLogs,
EventClass.data_generator: TrainingDataGenerator,
EventClass.faq_importer: ValidationLogs,
EventClass.message_broadcast: MessageBroadcastLogs
}
event_data_object = events_dict.get(event_type)
if event_data_object:
try:
event_object = event_data_object.objects.get(
bot=bot,
status__in=[EVENT_STATUS.INPROGRESS.value, EVENT_STATUS.ENQUEUED.value]
)
event_object.status = EVENT_STATUS.ABORTED.value
event_object.save()
except DoesNotExist:
raise AppException(f"No Enqueued {event_type} present for this bot.")

@staticmethod
def is_training_inprogress(bot: Text, raise_exception=True):
Expand Down
37 changes: 9 additions & 28 deletions tests/integration_test/services_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11899,54 +11899,36 @@ def test_add_training_example_case_insensitivity():


@responses.activate
@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True)
def test_retrain_with_no_enqueued_model_training(mock_training_limit):
mock_training_limit.return_value = False
event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}")
responses.add(
"POST", event_url, json={"success": True, "message": "Event triggered successfully!"}
)
def test_abort_event_with_no_enqueued_model_testing_events():
response = client.post(
f"/api/bot/{pytest.bot}/retrain",
f"/api/bot/{pytest.bot}/abort/model_testing",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)
actual = response.json()
print(actual)
assert not actual["success"]
assert actual["error_code"] == 422
assert actual["data"] is None
assert actual["message"] == "No Enqueued model training present for this bot."
assert actual["message"] == "No Enqueued model_testing present for this bot."


@responses.activate
@patch("kairon.shared.data.model_processor.ModelProcessor.handle_current_model_training", autospec=True)
@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True)
def test_retrain_with_model_training_in_progress(mock_training_limit, mock_handle_current_model_training):
mock_handle_current_model_training.side_effect = AppException("Previous model training in progress.")
mock_training_limit.return_value = False
event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}")
responses.add(
"POST", event_url, json={"success": True, "message": "Event triggered successfully!"}
)
def test_abort_event_with_no_enqueued_model_training_events():
response = client.post(
f"/api/bot/{pytest.bot}/train",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)
response = client.post(
f"/api/bot/{pytest.bot}/retrain",
f"/api/bot/{pytest.bot}/abort/model_training",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)
actual = response.json()
print(actual)
assert not actual["success"]
assert actual["error_code"] == 422
assert actual["data"] is None
assert actual["message"] == "Previous model training in progress."
assert actual["message"] == "No Enqueued model_training present for this bot."


@responses.activate
@patch("kairon.shared.data.model_processor.ModelProcessor.is_daily_training_limit_exceeded", autospec=True)
def test_retrain(mock_training_limit):
def test_abort_event(mock_training_limit):
mock_training_limit.return_value = False
event_url = urljoin(Utility.environment['events']['server_url'], f"/api/events/execute/{EventClass.model_training}")
responses.add(
Expand All @@ -11959,15 +11941,14 @@ def test_retrain(mock_training_limit):
)

response = client.post(
f"/api/bot/{pytest.bot}/retrain",
f"/api/bot/{pytest.bot}/abort/model_training",
headers={"Authorization": pytest.token_type + " " + pytest.access_token},
)
actual = response.json()
assert actual["success"]
assert actual["error_code"] == 0
assert actual["data"] is None
assert actual["message"] == "Model training started."
complete_end_to_end_event_execution(pytest.bot, "[email protected]", EventClass.model_training)
assert actual["message"] == "model_training aborted."


def test_add_utterances_case_insensitivity():
Expand Down
41 changes: 28 additions & 13 deletions tests/unit_test/data_processor/data_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,26 +962,41 @@ def test_edit_bot_settings(self):
assert updated_settings.analytics.to_mongo().to_dict() == {'fallback_intent': 'utter_please_rephrase'}
assert updated_settings.llm_settings.to_mongo().to_dict() == {'enable_faq': False, 'provider': 'openai'}

def test_handle_current_model_training(self):
def test_abort_current_event_with_no_model_training_event(self):
bot = "test_bot"
user = "test_user"
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.ENQUEUED.value)
ModelProcessor.handle_current_model_training(bot=bot, user=user)
model_training_object = ModelTraining.objects(bot=bot).get()
assert model_training_object.status == EVENT_STATUS.FAIL.value
with pytest.raises(AppException, match="No Enqueued model_training present for this bot."):
ModelProcessor.abort_current_event(bot=bot, event_type="model_training")

def test_handle_current_model_training_with_no_enqueued_model_training(self):
def test_abort_current_event_with_no_model_testing_event(self):
bot = "test_bot"
with pytest.raises(AppException, match="No Enqueued model_testing present for this bot."):
ModelProcessor.abort_current_event(bot=bot, event_type="model_testing")

def test_abort_current_event_with_no_delete_history_event(self):
bot = "test_bot"
with pytest.raises(AppException, match="No Enqueued delete_history present for this bot."):
ModelProcessor.abort_current_event(bot=bot, event_type="delete_history")

def test_abort_current_event_with_no_data_importer_event(self):
bot = "test_bot"
with pytest.raises(AppException, match="No Enqueued data_importer present for this bot."):
ModelProcessor.abort_current_event(bot=bot, event_type="data_importer")

def test_abort_current_event_with_model_training(self):
bot = "test_bot"
user = "test_user"
with pytest.raises(AppException, match="No Enqueued model training present for this bot."):
ModelProcessor.handle_current_model_training(bot=bot, user=user)
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.ENQUEUED.value)
ModelProcessor.abort_current_event(bot=bot, event_type="model_training")
model_training_object = ModelTraining.objects(bot=bot).get()
assert model_training_object.status == EVENT_STATUS.ABORTED.value

def test_handle_current_model_training_with_in_progress_model_training(self):
def test_abort_current_event_with_data_generator(self):
bot = "test_bot"
user = "test_user"
ModelProcessor.set_training_status(bot=bot, user=user, status=EVENT_STATUS.INPROGRESS.value)
with pytest.raises(AppException, match="Previous model training in progress."):
ModelProcessor.handle_current_model_training(bot=bot, user=user)
TrainingDataGenerationProcessor.set_status(bot=bot, user=user, status=EVENT_STATUS.ENQUEUED.value)
ModelProcessor.abort_current_event(bot=bot, event_type="data_generator")
data_generator_object = TrainingDataGenerator.objects(bot=bot).get()
assert data_generator_object.status == EVENT_STATUS.ABORTED.value

@pytest.mark.asyncio
async def test_save_from_path_yml(self):
Expand Down

0 comments on commit 521dd7b

Please sign in to comment.