Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline: Use capabilities instead of hardcoded LLMs #95

Merged
merged 8 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ..common.singleton import Singleton
from ..common.message_converters import (
from app.common.singleton import Singleton
from app.common.message_converters import (
convert_iris_message_to_langchain_message,
convert_langchain_message_to_iris_message,
)
57 changes: 50 additions & 7 deletions app/common/message_converters.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
from datetime import datetime

from langchain_core.messages import BaseMessage
from ..domain.iris_message import IrisMessage, IrisMessageRole

from app.domain.data.text_message_content_dto import TextMessageContentDTO
from app.domain.pyris_message import PyrisMessage, IrisMessageRole


def convert_iris_message_to_langchain_message(iris_message: IrisMessage) -> BaseMessage:
match iris_message.role:
def convert_iris_message_to_langchain_message(
iris_message: PyrisMessage,
) -> BaseMessage:
match iris_message.sender:
case IrisMessageRole.USER:
role = "human"
case IrisMessageRole.ASSISTANT:
role = "ai"
case IrisMessageRole.SYSTEM:
role = "system"
case _:
raise ValueError(f"Unknown message role: {iris_message.role}")
return BaseMessage(content=iris_message.text, type=role)
raise ValueError(f"Unknown message role: {iris_message.sender}")
if len(iris_message.contents) == 0:
raise ValueError("IrisMessage contents must not be empty")
message = iris_message.contents[0]
# Check if the message is of type TextMessageContentDTO
if not isinstance(message, TextMessageContentDTO):
raise ValueError("Message must be of type TextMessageContentDTO")
return BaseMessage(content=message.text_content, type=role)


def convert_langchain_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
def convert_langchain_message_to_iris_message(
base_message: BaseMessage,
) -> PyrisMessage:
match base_message.type:
case "human":
role = IrisMessageRole.USER
Expand All @@ -25,4 +39,33 @@ def convert_langchain_message_to_iris_message(base_message: BaseMessage) -> Iris
role = IrisMessageRole.SYSTEM
case _:
raise ValueError(f"Unknown message type: {base_message.type}")
return IrisMessage(text=base_message.content, role=role)
contents = [TextMessageContentDTO(textContent=base_message.content)]
return PyrisMessage(
contents=contents,
sender=role,
send_at=datetime.now(),
)


def map_role_to_str(role: IrisMessageRole) -> str:
match role:
case IrisMessageRole.USER:
return "user"
case IrisMessageRole.ASSISTANT:
return "assistant"
case IrisMessageRole.SYSTEM:
return "system"
case _:
raise ValueError(f"Unknown message role: {role}")


def map_str_to_role(role: str) -> IrisMessageRole:
match role:
case "user":
return IrisMessageRole.USER
case "assistant":
return IrisMessageRole.ASSISTANT
case "system":
return IrisMessageRole.SYSTEM
case _:
raise ValueError(f"Unknown message role: {role}")
4 changes: 2 additions & 2 deletions app/domain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .error_response_dto import IrisErrorResponseDTO
from .pipeline_execution_dto import PipelineExecutionDTO
from .pipeline_execution_settings_dto import PipelineExecutionSettingsDTO
from ..domain.tutor_chat.tutor_chat_pipeline_execution_dto import (
from app.domain.tutor_chat.tutor_chat_pipeline_execution_dto import (
TutorChatPipelineExecutionDTO,
)
from .iris_message import IrisMessage, IrisMessageRole
from .pyris_message import PyrisMessage, IrisMessageRole
2 changes: 1 addition & 1 deletion app/domain/data/feedback_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class FeedbackDTO(BaseModel):
text: Optional[str] = None
test_case_name: str = Field(alias="testCaseName")
test_case_name: Optional[str] = Field(alias="testCaseName", default=None)
credits: float

def __str__(self):
Expand Down
51 changes: 0 additions & 51 deletions app/domain/data/message_dto.py

This file was deleted.

2 changes: 1 addition & 1 deletion app/domain/data/programming_exercise_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ProgrammingLanguage(str, Enum):
class ProgrammingExerciseDTO(BaseModel):
id: int
name: str
programming_language: ProgrammingLanguage = Field(alias="programmingLanguage")
programming_language: Optional[str] = Field(alias="programmingLanguage")
template_repository: Dict[str, str] = Field(alias="templateRepository")
solution_repository: Dict[str, str] = Field(alias="solutionRepository")
test_repository: Dict[str, str] = Field(alias="testRepository")
Expand Down
17 changes: 0 additions & 17 deletions app/domain/iris_message.py

This file was deleted.

4 changes: 2 additions & 2 deletions app/domain/pipeline_execution_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from pydantic import BaseModel, Field

from ..domain.pipeline_execution_settings_dto import PipelineExecutionSettingsDTO
from ..domain.status.stage_dto import StageDTO
from app.domain.pipeline_execution_settings_dto import PipelineExecutionSettingsDTO
from app.domain.status.stage_dto import StageDTO


class PipelineExecutionDTO(BaseModel):
Expand Down
22 changes: 22 additions & 0 deletions app/domain/pyris_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from datetime import datetime
from enum import Enum
from typing import List

from pydantic import BaseModel, Field

from app.domain.data.message_content_dto import MessageContentDTO


class IrisMessageRole(str, Enum):
USER = "USER"
ASSISTANT = "LLM"
SYSTEM = "SYSTEM"


class PyrisMessage(BaseModel):
sent_at: datetime | None = Field(alias="sentAt", default=None)
sender: IrisMessageRole
contents: List[MessageContentDTO] = []

def __str__(self):
return f"{self.sender.lower()}: {self.contents}"
4 changes: 2 additions & 2 deletions app/domain/tutor_chat/tutor_chat_pipeline_execution_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from pydantic import Field

from ...domain.pyris_message import PyrisMessage
from ...domain import PipelineExecutionDTO
from ...domain.data.course_dto import CourseDTO
from ...domain.data.message_dto import MessageDTO
from ...domain.data.programming_exercise_dto import ProgrammingExerciseDTO
from ...domain.data.user_dto import UserDTO
from ...domain.data.submission_dto import SubmissionDTO
Expand All @@ -14,5 +14,5 @@ class TutorChatPipelineExecutionDTO(PipelineExecutionDTO):
submission: Optional[SubmissionDTO] = None
exercise: ProgrammingExerciseDTO
course: CourseDTO
chat_history: List[MessageDTO] = Field(alias="chatHistory", default=[])
chat_history: List[PyrisMessage] = Field(alias="chatHistory", default=[])
user: Optional[UserDTO] = None
6 changes: 3 additions & 3 deletions app/llm/external/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from pydantic import BaseModel

from ...domain import IrisMessage
from ...domain import PyrisMessage
from ...llm import CompletionArguments
from ...llm.capability import CapabilityList

Expand Down Expand Up @@ -39,8 +39,8 @@ def __subclasshook__(cls, subclass) -> bool:

@abstractmethod
def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
self, messages: list[PyrisMessage], arguments: CompletionArguments
) -> PyrisMessage:
"""Create a completion from the chat messages"""
raise NotImplementedError(
f"The LLM {self.__str__()} does not support chat completion"
Expand Down
26 changes: 19 additions & 7 deletions app/llm/external/ollama.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
from datetime import datetime
from typing import Literal, Any

from ollama import Client, Message

from ...domain import IrisMessage, IrisMessageRole
from ...common.message_converters import map_role_to_str, map_str_to_role
from ...domain.data.text_message_content_dto import TextMessageContentDTO
from ...domain import PyrisMessage
from ...llm import CompletionArguments
from ...llm.external.model import ChatModel, CompletionModel, EmbeddingModel


def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]:
return [
Message(role=message.role.value, content=message.text) for message in messages
Message(
role=map_role_to_str(message.sender),
content=message.contents[0].text_content,
)
for message in messages
]


def convert_to_iris_message(message: Message) -> IrisMessage:
return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"])
def convert_to_iris_message(message: Message) -> PyrisMessage:
contents = [TextMessageContentDTO(text_content=message["content"])]
return PyrisMessage(
sender=map_str_to_role(message["role"]),
contents=contents,
send_at=datetime.now(),
)


class OllamaModel(
Expand All @@ -35,8 +47,8 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str:
return response["response"]

def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
self, messages: list[PyrisMessage], arguments: CompletionArguments
) -> PyrisMessage:
response = self._client.chat(
model=self.model, messages=convert_to_ollama_messages(messages)
)
Expand Down
27 changes: 18 additions & 9 deletions app/llm/external/openai_chat.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@
from datetime import datetime
from typing import Literal, Any

from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage

from ...domain import IrisMessage, IrisMessageRole
from ...common.message_converters import map_role_to_str, map_str_to_role
from app.domain.data.text_message_content_dto import TextMessageContentDTO
from ...domain import PyrisMessage
from ...llm import CompletionArguments
from ...llm.external.model import ChatModel


def convert_to_open_ai_messages(
messages: list[IrisMessage],
messages: list[PyrisMessage],
) -> list[ChatCompletionMessageParam]:
return [
{"role": message.role.value, "content": message.text} for message in messages
{
"role": map_role_to_str(message.sender),
"content": message.contents[0].text_content,
}
for message in messages
]


def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
# Get IrisMessageRole from the string message.role
message_role = IrisMessageRole(message.role)
return IrisMessage(role=message_role, text=message.content)
def convert_to_iris_message(message: ChatCompletionMessage) -> PyrisMessage:
return PyrisMessage(
sender=map_str_to_role(message.role),
contents=[TextMessageContentDTO(textContent=message.content)],
send_at=datetime.now(),
)


class OpenAIChatModel(ChatModel):
Expand All @@ -29,8 +38,8 @@ class OpenAIChatModel(ChatModel):
_client: OpenAI

def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
self, messages: list[PyrisMessage], arguments: CompletionArguments
) -> PyrisMessage:
response = self._client.chat.completions.create(
model=self.model,
messages=convert_to_open_ai_messages(messages),
Expand Down
6 changes: 3 additions & 3 deletions app/llm/request_handler/basic_request_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from app.domain import IrisMessage
from app.domain import PyrisMessage
from app.llm.request_handler import RequestHandler
from app.llm.completion_arguments import CompletionArguments
from app.llm.llm_manager import LlmManager
Expand All @@ -17,8 +17,8 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str:
return llm.complete(prompt, arguments)

def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
self, messages: list[PyrisMessage], arguments: CompletionArguments
) -> PyrisMessage:
llm = self.llm_manager.get_llm_by_id(self.model_id)
return llm.chat(messages, arguments)

Expand Down
6 changes: 3 additions & 3 deletions app/llm/request_handler/capability_request_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum

from app.domain import IrisMessage
from app.domain import PyrisMessage
from app.llm.capability import RequirementList
from app.llm.external.model import (
ChatModel,
Expand Down Expand Up @@ -41,8 +41,8 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str:
return llm.complete(prompt, arguments)

def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
self, messages: list[PyrisMessage], arguments: CompletionArguments
) -> PyrisMessage:
llm = self._select_model(ChatModel)
return llm.chat(messages, arguments)

Expand Down
Loading
Loading