Refactor all the things

mbertrand committed Dec 23, 2024
1 parent b3e8ea7 commit b96c667

Showing 25 changed files with 1,473 additions and 217 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -55,5 +55,6 @@ RUN apt-get clean && apt-get purge

USER mitodl

EXPOSE 8888
EXPOSE 8001
ENV PORT 8001
File renamed without changes.
49 changes: 49 additions & 0 deletions ai_chatbots/api.py
@@ -0,0 +1,49 @@
"""AI-specific functions for ai_agents."""

from typing import Optional

from django.conf import settings
from llama_index.core.agent import AgentRunner
from llama_index.core.llms.llm import LLM

from ai_chatbots.constants import AgentClassEnum, LLMClassEnum
from ai_chatbots.proxies import AIProxy


def get_llm(model_name: Optional[str] = None, proxy: Optional[AIProxy] = None) -> LLM:
    """
    Get the LLM for the given model name,
    incorporating a proxy if one is passed.
    Args:
        model_name: The name of the model
        proxy: The proxy to use
    Returns:
        The LLM
    """
    if not model_name:
        model_name = settings.AI_MODEL
    try:
        llm_class = LLMClassEnum[settings.AI_PROVIDER].value
        return llm_class(
            model=model_name,
            **(proxy.get_api_kwargs() if proxy else {}),
            additional_kwargs=(proxy.get_additional_kwargs() if proxy else {}),
        )
    except KeyError as ke:
        msg = f"{settings.AI_PROVIDER} not supported"
        raise NotImplementedError(msg) from ke
    except Exception as ex:
        msg = f"Error instantiating LLM: {model_name}"
        raise ValueError(msg) from ex


def get_agent() -> type[AgentRunner]:
    """Get the appropriate chatbot agent class for the AI provider"""
    try:
        return AgentClassEnum[settings.AI_PROVIDER].value
    except KeyError as ke:
        msg = f"{settings.AI_PROVIDER} not supported"
        raise NotImplementedError(msg) from ke
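
For context, a minimal sketch of how these two helpers compose (assuming an OpenAI-style provider is configured in Django settings; the empty tool list and prompt below are placeholders, not part of this commit):

from ai_chatbots.api import get_agent, get_llm

llm = get_llm("gpt-4o")  # falls back to settings.AI_MODEL when no name is given
agent = get_agent().from_tools(
    tools=[],  # FunctionTool instances would go here
    llm=llm,
    system_prompt="You are a helpful assistant.",
)
print(agent.chat("hello").response)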
2 changes: 1 addition & 1 deletion ai_agents/apps.py → ai_chatbots/apps.py
@@ -6,4 +6,4 @@
class AiChatConfig(AppConfig):
    """AI Chat Appconfig"""

    name = "ai_agents"
    name = "ai_chatbots"
121 changes: 46 additions & 75 deletions ai_agents/agents.py → ai_chatbots/chatbots.py
@@ -10,24 +10,23 @@
from django.conf import settings
from django.core.cache import caches
from django.utils.module_loading import import_string
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.agent import AgentRunner
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.tools import FunctionTool, ToolMetadata
from llama_index.llms.openai import OpenAI
from openai import BadRequestError
from pydantic import Field

from ai_agents.constants import AIModelAPI, LearningResourceType, OfferedBy
from ai_agents.utils import enum_zip
from ai_chatbots.api import get_agent, get_llm
from ai_chatbots.constants import LearningResourceType, OfferedBy
from ai_chatbots.utils import enum_zip

log = logging.getLogger(__name__)


class BaseChatAgent(ABC):
class BaseChatbot(ABC):
    """
    Base service class for an AI chat agent
    Base AI chatbot class
    Llamaindex was chosen to implement this because it provides
    a far easier framework than native OpenAI or LiteLLM to
@@ -42,31 +41,31 @@ class BaseChatAgent(ABC):
    https://docs.litellm.ai/docs/completion/function_call
    """

INSTRUCTIONS = "Provide instructions for the AI assistant"
INSTRUCTIONS = "You are a friendly chatbot, answer the user's questions"

# For LiteLLM tracking purposes
JOB_ID = "BASECHAT_JOB"
TASK_NAME = "BASECHAT_TASK"
CACHE_PREFIX = "base_ai_"

    def __init__(
        self,
        user_id: str,
        *,
        name: str = "AI Chat Agent",
        name: str = "MIT Open Learning Chatbot",
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        instructions: Optional[str] = None,
    ):
        """Initialize the AI chat agent service"""
        """Initialize the AI chatbot"""
        self.user_id = user_id
        self.assistant_name = name
        self.ai = settings.AI_MODEL_API
        self.model = model or settings.AI_MODEL
        self.temperature = temperature or DEFAULT_TEMPERATURE
        self.instructions = instructions or self.INSTRUCTIONS
        if settings.AI_PROXY_CLASS and settings.AI_PROXY_URL:
            self.proxy = import_string(f"ai_agents.proxy.{settings.AI_PROXY_CLASS}")()
            self.proxy = import_string(f"ai_chatbots.proxy.{settings.AI_PROXY_CLASS}")(
                user_id=user_id, task_id=self.TASK_NAME
            )
        else:
            self.proxy = None
        self.agent = None
@@ -104,22 +103,14 @@ def save_chat_history(self) -> None:
            self.cache_key, json.dumps(chat_history), timeout=settings.AI_CACHE_TIMEOUT
        )

    @abstractmethod
    def create_agent(self) -> AgentRunner:
        """Create an AgentRunner for the relevant AI source"""
        if self.ai == AIModelAPI.openai.value:
            return self.create_openai_agent()
        else:
            error = f"AI source {self.ai} is not supported"
            raise NotImplementedError(error)

    def create_tools(self):
        """Create any tools required by the agent"""
        return []

    @abstractmethod
    def create_openai_agent(self) -> OpenAIAgent:
        """Create an OpenAI agent"""

    def clear_chat_history(self) -> None:
        """Clear the chat history from the cache"""
        self.agent.chat_history.clear()
@@ -170,13 +161,36 @@ def get_completion(self, message: str, *, debug: bool = settings.AI_DEBUG) -> st
        self.save_chat_history()


class RecommendationAgent(BaseChatAgent):
    """Service class for the AI search function agent"""
class FunctionCallingChatbot(BaseChatbot):
    """Function calling chatbot, using a FunctionCallingAgent"""

    TASK_NAME = "FUNCTION_CALL_TASK"

    JOB_ID = "SEARCH_JOB"
    TASK_NAME = "SEARCH_TASK"
    def create_agent(self) -> AgentRunner:
        """
        Create a function calling agent
        """
        llm = get_llm(self.model, self.proxy)
        self.agent = get_agent().from_tools(
            tools=self.create_tools(),
            llm=llm,
            verbose=True,
            system_prompt=self.instructions,
        )
        if self.save_history:
            self.get_or_create_chat_history_cache()
        return self.agent
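
To illustrate the extension point this refactor creates, here is a minimal hypothetical subclass (the EchoChatbot name and echo tool are invented for illustration, not part of this commit): a subclass only needs to supply instructions and tools.

from llama_index.core.tools import FunctionTool

class EchoChatbot(FunctionCallingChatbot):
    """Hypothetical subclass: repeats the user's message via a tool."""

    INSTRUCTIONS = "Echo the user's message back using the echo tool."
    TASK_NAME = "ECHO_TASK"

    def create_tools(self):
        def echo(text: str) -> str:
            """Return the input text unchanged."""
            return text

        return [FunctionTool.from_defaults(fn=echo)]

A caller would then do something like bot = EchoChatbot("user-123") followed by bot.create_agent(), mirroring how ResourceRecommendationBot below invokes super().create_agent() in its __init__.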


class ResourceRecommendationBot(FunctionCallingChatbot):
    """
    Chatbot that searches for learning resources in the MIT Learn catalog,
    then recommends the best results to the user based on their query.
    """

    INSTRUCTIONS = f"""You are an assistant helping users find courses from a catalog
    TASK_NAME = "RECOMMENDATION_TASK"

    INSTRUCTIONS = """You are an assistant helping users find courses from a catalog
of learning resources. Users can ask about specific topics, levels, or recommendations
based on their interests or goals.
@@ -205,8 +219,7 @@ class RecommendationAgent(BaseChatAgent):
as the value for this parameter.
offered_by: If a user asks for resources "offered by" or "from" an institution,
you should include this parameter based on the following
dictionary: {OfferedBy.as_dict()} DO NOT USE THE offered_by FILTER OTHERWISE.
you should include this parameter. DO NOT USE THE offered_by FILTER OTHERWISE.
certificate: true if the user is interested in resources that offer certificates, false
if the user does not want resources with a certificate offered. Do not use this filter
@@ -248,21 +261,6 @@ class RecommendationAgent(BaseChatAgent):
Expected Output: Maybe ask whether the user wants to learn how to program, or just use
AI in their discipline - does this person want to study machine learning? More info
needed. Then perform a relevant search and send back the best results.
And here are some recommended search parameters to apply for sample user prompts:
User: "I am interested in learning advanced AI techniques"
Search parameters: {{"q": "AI techniques"}}
User: "I am curious about AI applications for business"
Search parameters: {{"q": "AI business"}}
User: "I want free basic courses about climate change from OpenCourseware"
Search parameters: {{"q": "climate change", "free": true, "resource_type": ["course"],
"offered_by": "ocw"}}
User: "I want to learn some advanced mathematics"
Search parameters: {{"q": "mathematics"}}
"""

class SearchToolSchema(pydantic.BaseModel):
@@ -272,7 +270,7 @@ class SearchToolSchema(pydantic.BaseModel):
    q: The search query string
    resource_type: Filter by type of resource (course, program, etc)
    free: Filter for free resources only
    certificate: Filter for resources offering certificates
    certification: Filter for resources offering certificates
    offered_by: Filter by institution offering the resource
    """

@@ -291,7 +289,7 @@ class SearchToolSchema(pydantic.BaseModel):
        default=None,
        description="Whether the resource is free to access, true|false",
    )
    certificate: Optional[bool] = Field(
    certification: Optional[bool] = Field(
        default=None,
        description=(
            "Whether the resource offers a certificate upon completion, true|false"
@@ -309,7 +307,7 @@ class SearchToolSchema(pydantic.BaseModel):
"q": "machine learning",
"resource_type": ["course"],
"free": True,
"certificate": False,
"certification": False,
"offered_by": "MIT",
}
]
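
As a quick illustration of what this pydantic schema validates (the values are illustrative, and this assumes SearchToolSchema is accessible at module scope and pydantic v2, consistent with the json_schema_extra usage above):

params = SearchToolSchema(
    q="machine learning",
    resource_type=["course"],
    free=True,
    certification=False,
    offered_by="MIT",
)
print(params.model_dump(exclude_none=True))  # pydantic v2 serialization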
@@ -325,7 +323,7 @@ def __init__(
        temperature: Optional[float] = None,
        instructions: Optional[str] = None,
    ):
        """Initialize the AI search agent service"""
        """Initialize the chatbot"""
        super().__init__(
            user_id,
            name=name,
@@ -335,7 +333,7 @@ def __init__(
        )
        self.search_parameters = []
        self.search_results = []
        self.create_agent()
        super().create_agent()

    def search_courses(self, q: str, **kwargs) -> str:
        """
@@ -392,33 +390,6 @@ def search_courses(self, q: str, **kwargs) -> str:
            log.exception("Error querying MIT API")
            return json.dumps({"error": str(e)})

    def create_openai_agent(self) -> OpenAIAgent:
        """
        Create an OpenAI-specific llamaindex agent for function calling
        Using `OpenAI` instead of a more universal `LiteLLM` because
        the `LiteLLM` class as implemented by llamaindex does not
        support function calling. ie:
        agent = FunctionCallingAgentWorker.from_tools(....
        > AssertionError: llm must be an instance of FunctionCallingLLM
        """
        llm = OpenAI(
            model=self.model,
            **(self.proxy.get_api_kwargs() if self.proxy else {}),
            additional_kwargs=(
                self.proxy.get_additional_kwargs(self) if self.proxy else {}
            ),
        )
        self.agent = OpenAIAgent.from_tools(
            tools=self.create_tools(),
            llm=llm,
            verbose=True,
            system_prompt=self.instructions,
        )
        if self.save_history:
            self.get_or_create_chat_history_cache()
        return self.agent

    def create_tools(self):
        """Create tools required by the agent"""
        return [self.create_search_tool()]
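
create_search_tool itself is collapsed in this diff. A plausible sketch based on the FunctionTool and ToolMetadata imports shown above (the tool name, description, and the path to SearchToolSchema are assumptions, not confirmed by the commit):

    def create_search_tool(self) -> FunctionTool:
        """Hypothetical: wrap search_courses as a llamaindex function tool."""
        metadata = ToolMetadata(
            name="search_courses",
            description="Search the MIT Learn catalog for learning resources",
            fn_schema=self.SearchToolSchema,
        )
        return FunctionTool.from_defaults(
            fn=self.search_courses, tool_metadata=metadata
        )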