Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reload Model Changes #1115

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions kairon/chat/agent_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from rasa.core.agent import Agent
from rasa.core.lock_store import LockStore

from kairon.shared.chat.cache.in_memory_agent import AgentCache
from kairon.exceptions import AppException
from kairon.shared.chat.cache.in_memory_agent import AgentCache
from kairon.shared.chat.cache.in_memory_agent import InMemoryAgentCache
from kairon.shared.data.processor import MongoProcessor
from .agent.agent import KaironAgent
from kairon.shared.chat.cache.in_memory_agent import InMemoryAgentCache
from ..shared.account.activity_log import UserActivityLogger
from ..shared.constants import UserActivityType
from ..shared.utils import Utility


Expand All @@ -36,11 +38,12 @@ def get_agent(bot: Text) -> Agent:
return AgentProcessor.cache_provider.get(bot)

@staticmethod
def reload(bot: Text):
def reload(bot: Text, email: str = None):
"""
reload bot agent

:param bot: bot id
:param email: email
:return: None
"""
try:
Expand All @@ -57,7 +60,13 @@ def reload(bot: Text):
lock_store=lock_store_endpoint)
agent.model_ver = model_path.split("/")[-1]
AgentProcessor.cache_provider.set(bot, agent, is_billed=bot_settings.is_billed)
UserActivityLogger.add_log(a_type=UserActivityType.reload_model_completion.value, email=email,
bot=bot, message=['Model reload completed!'],
data={"username": email, "process_id": os.getpid()})
except Exception as e:
UserActivityLogger.add_log(a_type=UserActivityType.reload_model_failure.value, email=email,
bot=bot, message=['Model reload failed!'],
data={"username": email, "process_id": os.getpid()})
logging.exception(e)
raise AppException("Bot has not been trained yet!")

Expand Down
6 changes: 3 additions & 3 deletions kairon/chat/routers/web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from kairon.shared.data.processor import MongoProcessor
from kairon.shared.models import User


router = APIRouter()


Expand Down Expand Up @@ -77,5 +76,6 @@ async def reload_model(
"""
Retrieves chat client config of a bot.
"""
background_tasks.add_task(ChatUtils.reload, bot)
return {"message": "Reloading Model!"}
if not ChatUtils.is_reload_model_in_progress(bot):
background_tasks.add_task(ChatUtils.reload, bot, current_user.email)
return {"message": "Reloading Model!"}
41 changes: 39 additions & 2 deletions kairon/chat/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import datetime
import json
import os
from typing import Text, Dict

from loguru import logger
from mongoengine import DoesNotExist
from pymongo.collection import Collection
from pymongo.errors import ServerSelectionTimeoutError
from rasa.core.channels import UserMessage
from rasa.core.tracker_store import TrackerStore

from .agent_processor import AgentProcessor
from .. import Utility
from ..exceptions import AppException
from ..live_agent.factory import LiveAgentFactory
from ..shared.account.activity_log import UserActivityLogger
from ..shared.actions.utils import ActionUtility
from ..shared.constants import UserActivityType
from ..shared.data.audit.data_objects import AuditLogData
from ..shared.data.constant import AuditlogActions
from ..shared.live_agent.processor import LiveAgentsProcessor
from ..shared.metering.constants import MetricType
from ..shared.metering.metering_processor import MeteringProcessor
Expand All @@ -29,8 +36,38 @@ async def chat(data: Text, account: int, bot: Text, user: Text, is_integration_u
return chat_response

@staticmethod
def reload(bot: Text):
AgentProcessor.reload(bot)
def is_reload_model_in_progress(bot: str, raise_exception=True):
"""
Checks if model reloading is in progress.
@param bot: bot id
@param raise_exception: Raise exception if event is in progress.
@return: boolean flag.
"""
in_progress = False
try:
latest_log = AuditLogData.objects(attributes__key='bot', attributes__value=bot, action=AuditlogActions.ACTIVITY.value).filter(
entity=UserActivityType.reload_model_enqueued.value
).order_by('-timestamp').first()
if latest_log:
if raise_exception:
raise AppException("Model reload enqueued. Check logs.")
in_progress = True
except DoesNotExist as e:
logger.error(e)
return in_progress

@staticmethod
def reload(bot: Text, email: str):
exc = None
try:
AgentProcessor.reload(bot, email)
except Exception as e:
logger.exception(e)
exc = str(e)
finally:
UserActivityLogger.add_log(a_type=UserActivityType.reload_model_enqueued.value,
email=email, bot=bot, message=['Model reload enqueued!'],
data={"username": email, "process_id": os.getpid(), "exception": exc})

@staticmethod
def __attach_agent_handoff_metadata(account: int, bot: Text, sender_id: Text, bot_predictions, tracker):
Expand Down
3 changes: 3 additions & 0 deletions kairon/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class UserActivityType(str, Enum):
login_refresh_token = "login_refresh_token"
invalid_login = 'invalid_login'
template_creation = 'template_creation'
reload_model_enqueued = 'reload_model_enqueued'
reload_model_failure = 'reload_model_failure'
reload_model_completion = 'reload_model_completion'


class EventClass(str, Enum):
Expand Down
3 changes: 2 additions & 1 deletion kairon/shared/data/audit/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from kairon import Utility
from kairon.exceptions import AppException
from kairon.shared.account.data_objects import Bot
from kairon.shared.data.audit.data_objects import AuditLogData
from kairon.shared.data.constant import AuditlogActions
from kairon.shared.data.data_objects import EventConfig
Expand All @@ -29,7 +30,7 @@ def log(entity, account: int = None, bot: Text = None, email: Text = None, data:

action = kwargs.get("action")
attribute = AuditDataProcessor.get_attributes({"bot": bot, "account": account, "email": email})
user = email if email else AccountProcessor.get_account(account)['user']
user = email if email else AccountProcessor.get_account(account)['user'] if account else AccountProcessor.get_bot(bot)['user']
audit_log = AuditLogData(attributes=attribute,
user=user,
action=action,
Expand Down
55 changes: 53 additions & 2 deletions tests/integration_test/chat_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import time
from datetime import datetime
from unittest import mock
from urllib.parse import urlencode, quote_plus

Expand All @@ -27,7 +28,8 @@
from kairon.shared.auth import Authentication
from kairon.shared.chat.processor import ChatDataProcessor
from kairon.shared.constants import UserActivityType
from kairon.shared.data.constant import INTEGRATION_STATUS
from kairon.shared.data.audit.data_objects import AuditLogData
from kairon.shared.data.constant import INTEGRATION_STATUS, AuditlogActions
from kairon.shared.data.constant import TOKEN_TYPE
from kairon.shared.data.data_objects import BotSettings
from kairon.shared.data.processor import MongoProcessor
Expand Down Expand Up @@ -797,7 +799,9 @@ def test_chat_limited_access_prevent_chat():


@patch('kairon.shared.utils.Utility.get_local_mongo_store')
def test_reload(mock_store):
@patch('kairon.chat.utils.ChatUtils.is_reload_model_in_progress')
def test_reload(mock_reload_model_in_progress, mock_store):
mock_reload_model_in_progress.return_value = False
mock_store.return_value = None
response = client.get(
f"/api/bot/{bot}/reload",
Expand Down Expand Up @@ -839,6 +843,53 @@ def test_reload_exception():
assert actual["message"] == "Not authenticated"


@patch('kairon.shared.utils.Utility.get_local_mongo_store')
def test_reload_event_already_in_progress(mock_store):
mock_store.return_value = None
AuditLogData(
attributes=[{"key": "bot", "value": bot}], user="test", timestamp=datetime.utcnow(),
action=AuditlogActions.ACTIVITY.value,
entity=UserActivityType.reload_model_enqueued.value,
data={'message': 'Model reload enqueued!'}
).save()
response = client.get(
f"/api/bot/{bot}/reload",
headers={
"Authorization": token_type + " " + token
},
)
actual = response.json()
print(actual)
assert not actual["success"]
assert actual["error_code"] == 422
assert actual["data"] is None
assert actual["message"] == 'Model reload enqueued. Check logs.'


@patch('kairon.shared.utils.Utility.get_local_mongo_store')
@patch('kairon.chat.utils.ChatUtils.is_reload_model_in_progress')
def test_reload_event_exception_in_reload(mock_reload_model_in_progress, mock_store):
mock_reload_model_in_progress.return_value = False
mock_store.return_value = None
with patch("kairon.chat.utils.ChatUtils.reload") as mock_reload:
mock_reload.side_effect = Exception("Simulated exception during model reload")
responses.add(
responses.GET,
f"/api/bot/{bot}/reload",
status=200,
json={'success': True, 'error_code': 0, "data": None, 'message': "Reloading Model!"}
)
try:
response = client.get(
f"/api/bot/{bot}/reload",
headers={
"Authorization": token_type + " " + token
},
)
except Exception as e:
assert e


@patch('kairon.chat.handlers.channels.slack.SlackHandler.is_request_from_slack_authentic')
@patch('kairon.shared.utils.Utility.get_local_mongo_store')
def test_slack_auth_bot_challenge(mock_store, mock_slack):
Expand Down
Loading
Loading