diff --git a/app/common/message_converters.py b/app/common/message_converters.py index 6835ec11..7146f5f3 100644 --- a/app/common/message_converters.py +++ b/app/common/message_converters.py @@ -1,6 +1,25 @@ +from datetime import datetime + from langchain_core.messages import BaseMessage -from domain import IrisMessage, IrisMessageRole +from domain.iris_message import IrisMessage, IrisMessageRole +from domain.data.message_dto import MessageDTO, MessageContentDTO, IrisMessageSender + + +def convert_iris_message_to_message_dto(iris_message: IrisMessage) -> MessageDTO: + match iris_message.role: + case IrisMessageRole.USER: + sender = IrisMessageSender.USER + case IrisMessageRole.ASSISTANT: + sender = IrisMessageSender.LLM + case _: + raise ValueError(f"Unknown message role: {iris_message.role}") + + return MessageDTO( + sent_at=datetime.now(), + sender=sender, + contents=[MessageContentDTO(textContent=iris_message.text)], + ) def convert_iris_message_to_langchain_message(iris_message: IrisMessage) -> BaseMessage: diff --git a/app/domain/data/build_log_entry.py b/app/domain/data/build_log_entry.py index 623e7bed..b73c942c 100644 --- a/app/domain/data/build_log_entry.py +++ b/app/domain/data/build_log_entry.py @@ -5,3 +5,6 @@ class BuildLogEntryDTO(BaseModel): timestamp: datetime message: str + + def __str__(self): + return f"{self.timestamp}: {self.message}" diff --git a/app/domain/data/message_dto.py b/app/domain/data/message_dto.py index dc3da7ad..d6892e49 100644 --- a/app/domain/data/message_dto.py +++ b/app/domain/data/message_dto.py @@ -1,6 +1,8 @@ from datetime import datetime from enum import Enum -from typing import List +from typing import List, Literal + +from domain.iris_message import IrisMessage, IrisMessageRole from message_content_dto import MessageContentDTO from pydantic import BaseModel, Field @@ -13,5 +15,26 @@ class IrisMessageSender(str, Enum): class MessageDTO(BaseModel): sent_at: datetime = Field(alias="sentAt") - sender: IrisMessageSender + sender: Literal[IrisMessageSender.USER, 
IrisMessageSender.LLM] contents: List[MessageContentDTO] + + def __str__(self): + match self.sender: + case IrisMessageSender.USER: + sender = "user" + case IrisMessageSender.LLM: + sender = "ai" + case _: + raise ValueError(f"Unknown message sender: {self.sender}") + return f"{sender}: {self.contents[0].textContent}" + + def convert_to_iris_message(self): + match self.sender: + case IrisMessageSender.USER: + sender = IrisMessageRole.USER + case IrisMessageSender.LLM: + sender = IrisMessageRole.ASSISTANT + case _: + raise ValueError(f"Unknown message sender: {self.sender}") + + return IrisMessage(text=self.contents[0].textContent, role=sender) diff --git a/app/domain/data/programming_exercise_dto.py b/app/domain/data/programming_exercise_dto.py index 64856e63..eac216da 100644 --- a/app/domain/data/programming_exercise_dto.py +++ b/app/domain/data/programming_exercise_dto.py @@ -1,3 +1,5 @@ +from typing import Dict + from pydantic import BaseModel, Field from datetime import datetime from enum import Enum @@ -20,9 +22,9 @@ class ProgrammingExerciseDTO(BaseModel): id: int name: str programming_language: ProgrammingLanguage = Field(alias="programmingLanguage") - template_repository: dict[str, str] = Field(alias="templateRepository") - solution_repository: dict[str, str] = Field(alias="solutionRepository") - test_repository: dict[str, str] = Field(alias="testRepository") + template_repository: Dict[str, str] = Field(alias="templateRepository") + solution_repository: Dict[str, str] = Field(alias="solutionRepository") + test_repository: Dict[str, str] = Field(alias="testRepository") problem_statement: str = Field(alias="problemStatement") start_date: datetime = Field(alias="startDate") end_date: datetime = Field(alias="endDate") diff --git a/app/domain/data/submission_dto.py b/app/domain/data/submission_dto.py index dd9b562c..d275daab 100644 --- a/app/domain/data/submission_dto.py +++ b/app/domain/data/submission_dto.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict from 
pydantic import BaseModel, Field @@ -10,7 +10,7 @@ class SubmissionDTO(BaseModel): id: int date: datetime - repository: dict[str, str] + repository: Dict[str, str] is_practice: bool = Field(alias="isPractice") build_failed: bool = Field(alias="buildFailed") build_log_entries: List[BuildLogEntryDTO] = Field(alias="buildLogEntries") diff --git a/app/domain/iris_message.py b/app/domain/iris_message.py new file mode 100644 index 00000000..8994b068 --- /dev/null +++ b/app/domain/iris_message.py @@ -0,0 +1,20 @@ +from enum import Enum +from typing import Literal + +from pydantic import BaseModel, Field + + +class IrisMessageRole(str, Enum): + USER = "USER" + ASSISTANT = "ASSISTANT" + SYSTEM = "SYSTEM" + + +class IrisMessage(BaseModel): + text: str + role: Literal[ + IrisMessageRole.USER, IrisMessageRole.ASSISTANT, IrisMessageRole.SYSTEM + ] + + def __str__(self): + return f"{self.role.lower()}: {self.text}" diff --git a/app/domain/status/stage_dto.py b/app/domain/status/stage_dto.py index c6a5b8a1..65c237f4 100644 --- a/app/domain/status/stage_dto.py +++ b/app/domain/status/stage_dto.py @@ -1,3 +1,5 @@ +from typing import Literal + from pydantic import BaseModel from domain.status.stage_state_dto import StageStateDTO @@ -6,5 +8,11 @@ class StageDTO(BaseModel): name: str weight: int - state: StageStateDTO + state: Literal[ + StageStateDTO.NOT_STARTED, + StageStateDTO.IN_PROGRESS, + StageStateDTO.DONE, + StageStateDTO.SKIPPED, + StageStateDTO.ERROR, + ] message: str diff --git a/app/pipeline/chat/file_selector_pipeline.py b/app/pipeline/chat/file_selector_pipeline.py index 078e8548..59927d40 100644 --- a/app/pipeline/chat/file_selector_pipeline.py +++ b/app/pipeline/chat/file_selector_pipeline.py @@ -1,24 +1,25 @@ import logging import os -from typing import List +from typing import List, Dict from langchain.output_parsers import PydanticOutputParser from langchain_core.prompts import PromptTemplate from langchain_core.runnables import Runnable from pydantic import 
BaseModel -from domain.submission import BuildLogEntry +from domain.data.build_log_entry import BuildLogEntryDTO from llm import BasicRequestHandler from llm.langchain import IrisLangchainChatModel from pipeline import Pipeline from pipeline.chat.output_models.output_models.selected_file_model import SelectedFile +from web.status.status_update import StatusCallback logger = logging.getLogger(__name__) class FileSelectionDTO(BaseModel): - files: List[str] - build_logs: List[BuildLogEntry] + files: Dict[str, str] + build_logs: List[BuildLogEntryDTO] def __str__(self): return ( @@ -32,11 +33,13 @@ class FileSelectorPipeline(Pipeline): llm: IrisLangchainChatModel pipeline: Runnable + callback: StatusCallback - def __init__(self): + def __init__(self, callback: StatusCallback): super().__init__(implementation_id="file_selector_pipeline_reference_impl") request_handler = BasicRequestHandler("gpt35") self.llm = IrisLangchainChatModel(request_handler) + self.callback = callback # Load prompt from file dirname = os.path.dirname(__file__) with open( @@ -48,26 +51,26 @@ def __init__(self): # Create the prompt prompt = PromptTemplate( template=prompt_str, - input_variables=["repository", "build_log"], + input_variables=["file_names", "build_logs"], partial_variables={"format_instructions": parser.get_format_instructions()}, ) logger.debug(parser.get_format_instructions()) # Create the pipeline self.pipeline = prompt | self.llm | parser - def __call__(self, dto: FileSelectionDTO, **kwargs) -> SelectedFile: + def __call__(self, dto: FileSelectionDTO, **kwargs) -> str: """ Runs the pipeline :param query: The query - :return: IrisMessage + :return: Selected file content """ logger.debug("Running file selector pipeline...") - repository = dto.files - build_log = dto.build_logs + file_names = list(dto.files.keys()) + build_logs = dto.build_logs response = self.pipeline.invoke( { - "repository": repository, - "build_log": build_log, + "file_names": file_names, + "build_logs": 
build_logs, } ) - return response + return dto.files[response.selected_file] if response.selected_file else "" diff --git a/app/pipeline/chat/tutor_chat_pipeline.py b/app/pipeline/chat/tutor_chat_pipeline.py index 7b0f75e7..a2b7f19e 100644 --- a/app/pipeline/chat/tutor_chat_pipeline.py +++ b/app/pipeline/chat/tutor_chat_pipeline.py @@ -13,7 +13,9 @@ from domain import TutorChatPipelineExecutionDTO from domain.data.message_dto import MessageDTO +from domain.iris_message import IrisMessage from web.status.status_update import TutorChatStatusCallback +from .file_selector_pipeline import FileSelectorPipeline, FileSelectionDTO from ...llm import BasicRequestHandler from ...llm.langchain import IrisLangchainChatModel @@ -23,17 +25,19 @@ logger = logging.getLogger(__name__) class TutorChatPipeline(Pipeline): """Tutor chat pipeline that answers exercises related questions from students.""" llm: IrisLangchainChatModel pipeline: Runnable + callback: TutorChatStatusCallback def __init__(self, callback: TutorChatStatusCallback): - super().__init__(implementation_id="tutor_chat_pipeline_reference_impl") + super().__init__(implementation_id="tutor_chat_pipeline") # Set the langchain chat model request_handler = BasicRequestHandler("gpt35") self.llm = IrisLangchainChatModel(request_handler) + self.callback = callback # Load the prompt from a file dirname = os.path.dirname(__file__) with open( @@ -51,16 +59,26 @@ def __init__(self, callback: TutorChatStatusCallback): ) # Create the pipeline summary_pipeline = SummaryPipeline() + # Create file selector pipeline + file_selector_pipeline = FileSelectorPipeline(callback=callback) self.pipeline = ( { "question": itemgetter("question"), "history": itemgetter("history"), "exercise_title": itemgetter("exercise_title"), - "summary": itemgetter("problem_statement") + "problem_statement": itemgetter("problem_statement"), + "file_content": itemgetter("file_map") + | RunnableLambda( + lambda file_map: file_selector_pipeline( + 
dto=FileSelectionDTO(files=file_map, build_logs=[]) + ), + callback=None, + ) | RunnableLambda( - lambda stmt: summary_pipeline(query=stmt), callback=None + lambda selected_file: ( + selected_file if selected_file else "" + ), ), - "file_content": itemgetter("file_content"), } | prompt | self.llm @@ -77,15 +95,15 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs): """ Runs the pipeline :param query: The query - :return: IrisMessage """ logger.debug("Running tutor chat pipeline...") logger.debug(f"DTO: {dto}") history: List[MessageDTO] = dto.chat_history[:-1] - query: MessageDTO = dto.chat_history[-1] + query: IrisMessage = dto.chat_history[-1].convert_to_iris_message() problem_statement: str = dto.exercise.problem_statement - exercise_title: str = dto.exercise.title - message = query.contents[0].textContent + exercise_title: str = dto.exercise.name + message = query.text + file_map = dto.latest_submission.repository if not message: raise ValueError("IrisMessage must not be empty") response = self.pipeline.invoke( @@ -93,9 +111,9 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs): "question": message, "history": [message.__str__() for message in history], "problem_statement": problem_statement, - "file_content": "", # TODO add file selector pipeline and get file content + "file_map": file_map, "exercise_title": exercise_title, } ) logger.debug(f"Response from tutor chat pipeline: {response}") - return MessageDTO(role=IrisMessageRole.ASSISTANT, text=response) + # TODO: Convert response to status update diff --git a/app/pipeline/prompts/file_selector_prompt.txt b/app/pipeline/prompts/file_selector_prompt.txt index 6ea4260a..254244c7 100644 --- a/app/pipeline/prompts/file_selector_prompt.txt +++ b/app/pipeline/prompts/file_selector_prompt.txt @@ -1,10 +1,10 @@ Select a file from the files list below that is mentioned in the build logs. 
If no file is in the list is mentioned in the build logs answer with empty string. Here are the paths of all files: -{repository} +{file_names} Build logs: -{build_log} +{build_logs} {format_instructions} diff --git a/app/pipeline/prompts/iris_tutor_chat_prompt.txt b/app/pipeline/prompts/iris_tutor_chat_prompt.txt index 4188951d..0c16e7f4 100644 --- a/app/pipeline/prompts/iris_tutor_chat_prompt.txt +++ b/app/pipeline/prompts/iris_tutor_chat_prompt.txt @@ -37,9 +37,9 @@ A: I am Iris, the AI programming tutor integrated into Artemis, the online learn Consider the following exercise context: - Title: {exercise_title} - - Problem Statement: {summary} - - Exercise skeleton code in markdown format: - ```java + - Problem Statement: {problem_statement} + - Exercise skeleton code in Markdown format: + ```[{programming_language}] {file_content} ```