Skip to content

Commit

Permalink
Adjust file selector pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
kaancayli committed Feb 28, 2024
1 parent ac8dcbe commit cda7905
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 37 deletions.
21 changes: 20 additions & 1 deletion app/common/message_converters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
from datetime import datetime

from langchain_core.messages import BaseMessage

from domain import IrisMessage, IrisMessageRole
from domain.iris_message import IrisMessage, IrisMessageRole
from domain.data.message_dto import MessageDTO, MessageContentDTO, IrisMessageSender


def convert_iris_message_to_message_dto(iris_message: IrisMessage) -> MessageDTO:
match iris_message.role:
case "user":
sender = IrisMessageSender.USER
case "assistant":
sender = IrisMessageSender.LLM
case _:
raise ValueError(f"Unknown message role: {iris_message.role}")

return MessageDTO(
sent_at=datetime.now(),
sender=sender,
contents=[MessageContentDTO(textContent=iris_message.text)],
)


def convert_iris_message_to_langchain_message(iris_message: IrisMessage) -> BaseMessage:
Expand Down
3 changes: 3 additions & 0 deletions app/domain/data/build_log_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@
class BuildLogEntryDTO(BaseModel):
    """One entry of a CI build log: when it happened and what was logged."""

    timestamp: datetime
    message: str

    def __str__(self):
        """Render the entry as "<timestamp>: <message>"."""
        return "{}: {}".format(self.timestamp, self.message)
27 changes: 25 additions & 2 deletions app/domain/data/message_dto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import datetime
from enum import Enum
from typing import List
from typing import List, Literal

from domain.iris_message import IrisMessage
from message_content_dto import MessageContentDTO

from pydantic import BaseModel, Field
Expand All @@ -13,5 +15,26 @@ class IrisMessageSender(str, Enum):

class MessageDTO(BaseModel):
sent_at: datetime = Field(alias="sentAt")
sender: IrisMessageSender
sender: Literal[IrisMessageSender.USER, IrisMessageSender.LLM]
contents: List[MessageContentDTO]

def __str__(self):
match self.sender:
case IrisMessageSender.USER:
sender = "user"
case IrisMessageSender.LLM:
sender = "ai"
case _:
raise ValueError(f"Unknown message sender: {self.sender}")
return f"{sender}: {self.contents[0].textContent}"

def convert_to_iris_message(self):
match self.sender:
case IrisMessageSender.USER:
sender = "user"
case IrisMessageSender.LLM:
sender = "assistant"
case _:
raise ValueError(f"Unknown message sender: {self.sender}")

return IrisMessage(text=self.contents[0].textContent, role=sender)
8 changes: 5 additions & 3 deletions app/domain/data/programming_exercise_dto.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

from pydantic import BaseModel, Field
from datetime import datetime
from enum import Enum
Expand All @@ -20,9 +22,9 @@ class ProgrammingExerciseDTO(BaseModel):
id: int
name: str
programming_language: ProgrammingLanguage = Field(alias="programmingLanguage")
template_repository: dict[str, str] = Field(alias="templateRepository")
solution_repository: dict[str, str] = Field(alias="solutionRepository")
test_repository: dict[str, str] = Field(alias="testRepository")
template_repository: Dict[str, str] = Field(alias="templateRepository")
solution_repository: Dict[str, str] = Field(alias="solutionRepository")
test_repository: Dict[str, str] = Field(alias="testRepository")
problem_statement: str = Field(alias="problemStatement")
start_date: datetime = Field(alias="startDate")
end_date: datetime = Field(alias="endDate")
4 changes: 2 additions & 2 deletions app/domain/data/submission_dto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict

from pydantic import BaseModel, Field

Expand All @@ -10,7 +10,7 @@
class SubmissionDTO(BaseModel):
id: int
date: datetime
repository: dict[str, str]
repository: Dict[str, str]
is_practice: bool = Field(alias="isPractice")
build_failed: bool = Field(alias="buildFailed")
build_log_entries: List[BuildLogEntryDTO] = Field(alias="buildLogEntries")
Expand Down
20 changes: 20 additions & 0 deletions app/domain/iris_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from enum import Enum
from typing import Literal

from pydantic import BaseModel, Field


class IrisMessageRole(str, Enum):
    """Role of a chat message; str-valued so members compare equal to their value."""

    USER = "USER"
    ASSISTANT = "ASSISTANT"
    SYSTEM = "SYSTEM"


class IrisMessage(BaseModel):
    """A single chat message inside the pipeline: its text and its role."""

    text: str
    role: Literal[
        IrisMessageRole.USER, IrisMessageRole.ASSISTANT, IrisMessageRole.SYSTEM
    ]

    def __str__(self):
        """Render as "<role>: <text>" with the role lower-cased."""
        return "{}: {}".format(self.role.lower(), self.text)
10 changes: 9 additions & 1 deletion app/domain/status/stage_dto.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

from pydantic import BaseModel

from domain.status.stage_state_dto import StageStateDTO
Expand All @@ -6,5 +8,11 @@
class StageDTO(BaseModel):
    """Status of one stage of a pipeline run, reported back to the client.

    ``state`` is constrained to the known StageStateDTO members so that
    validation rejects unknown states early.
    """

    name: str
    weight: int
    # NOTE(review): the merge had left the old untyped `state: StageStateDTO`
    # declaration alongside this Literal one; only the Literal form is kept.
    state: Literal[
        StageStateDTO.NOT_STARTED,
        StageStateDTO.IN_PROGRESS,
        StageStateDTO.DONE,
        StageStateDTO.SKIPPED,
        StageStateDTO.ERROR,
    ]
    message: str
29 changes: 16 additions & 13 deletions app/pipeline/chat/file_selector_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import logging
import os
from typing import List
from typing import List, Dict

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable
from pydantic import BaseModel

from domain.submission import BuildLogEntry
from domain.data.build_log_entry import BuildLogEntryDTO
from llm import BasicRequestHandler
from llm.langchain import IrisLangchainChatModel
from pipeline import Pipeline
from pipeline.chat.output_models.output_models.selected_file_model import SelectedFile
from web.status.status_update import StatusCallback

logger = logging.getLogger(__name__)


class FileSelectionDTO(BaseModel):
files: List[str]
build_logs: List[BuildLogEntry]
files: Dict[str, str]
build_logs: List[BuildLogEntryDTO]

def __str__(self):
return (
Expand All @@ -32,11 +33,13 @@ class FileSelectorPipeline(Pipeline):

llm: IrisLangchainChatModel
pipeline: Runnable
callback: StatusCallback

def __init__(self):
def __init__(self, callback: StatusCallback):
super().__init__(implementation_id="file_selector_pipeline_reference_impl")
request_handler = BasicRequestHandler("gpt35")
self.llm = IrisLangchainChatModel(request_handler)
self.callback = callback
# Load prompt from file
dirname = os.path.dirname(__file__)
with open(
Expand All @@ -48,26 +51,26 @@ def __init__(self):
# Create the prompt
prompt = PromptTemplate(
template=prompt_str,
input_variables=["repository", "build_log"],
input_variables=["file_names", "build_logs"],
partial_variables={"format_instructions": parser.get_format_instructions()},
)
logger.debug(parser.get_format_instructions())
# Create the pipeline
self.pipeline = prompt | self.llm | parser

def __call__(self, dto: FileSelectionDTO, **kwargs) -> str:
    """
    Runs the file selector pipeline.

    :param dto: File map (path -> content) plus the build log entries.
    :return: Content of the selected file, or "" if no file was selected.
    """
    logger.debug("Running file selector pipeline...")
    # The prompt only needs the file paths, not their contents.
    file_names = list(dto.files.keys())
    build_logs = dto.build_logs
    response = self.pipeline.invoke(
        {
            "file_names": file_names,
            "build_logs": build_logs,
        }
    )
    # Use .get with a "" default: the LLM may hallucinate a path that is
    # not a key of dto.files, and indexing directly would raise KeyError.
    if not response.selected_file:
        return ""
    return dto.files.get(response.selected_file, "")
38 changes: 28 additions & 10 deletions app/pipeline/chat/tutor_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

from domain import TutorChatPipelineExecutionDTO
from domain.data.message_dto import MessageDTO
from domain.iris_message import IrisMessage
from web.status.status_update import TutorChatStatusCallback
from .file_selector_pipeline import FileSelectorPipeline, FileSelectionDTO
from ...llm import BasicRequestHandler
from ...llm.langchain import IrisLangchainChatModel

Expand All @@ -23,17 +25,23 @@
logger = logging.getLogger(__name__)


class IrisMessageRole:
    # NOTE(review): empty placeholder that shadows the real enum in
    # domain.iris_message — presumably auto-generated by the IDE; importing
    # the real IrisMessageRole instead looks intended. TODO confirm & remove.
    pass


class TutorChatPipeline(Pipeline):
"""Tutor chat pipeline that answers exercises related questions from students."""

llm: IrisLangchainChatModel
pipeline: Runnable
callback: TutorChatStatusCallback

def __init__(self, callback: TutorChatStatusCallback):
super().__init__(implementation_id="tutor_chat_pipeline_reference_impl")
super().__init__(implementation_id="tutor_chat_pipeline")
# Set the langchain chat model
request_handler = BasicRequestHandler("gpt35")
self.llm = IrisLangchainChatModel(request_handler)
self.callback = callback
# Load the prompt from a file
dirname = os.path.dirname(__file__)
with open(
Expand All @@ -51,16 +59,26 @@ def __init__(self, callback: TutorChatStatusCallback):
)
# Create the pipeline
summary_pipeline = SummaryPipeline()
# Create file selector pipeline
file_selector_pipeline = FileSelectorPipeline()
self.pipeline = (
{
"question": itemgetter("question"),
"history": itemgetter("history"),
"exercise_title": itemgetter("exercise_title"),
"summary": itemgetter("problem_statement")
"summary": itemgetter("problem_statement"),
"file_content": itemgetter("file_map")
| RunnableLambda(
lambda file_map: file_selector_pipeline(
dto=FileSelectionDTO(files=file_map.keys(), build_logs=[])
),
callback=None,
)
| RunnableLambda(
lambda stmt: summary_pipeline(query=stmt), callback=None
lambda selected_file: (
itemgetter("file_map")[selected_file] if selected_file else ""
),
),
"file_content": itemgetter("file_content"),
}
| prompt
| self.llm
Expand All @@ -77,25 +95,25 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs):
"""
Runs the pipeline
:param query: The query
:return: IrisMessage
"""
logger.debug("Running tutor chat pipeline...")
logger.debug(f"DTO: {dto}")
history: List[MessageDTO] = dto.chat_history[:-1]
query: MessageDTO = dto.chat_history[-1]
query: IrisMessage = dto.chat_history[-1].convert_to_iris_message()
problem_statement: str = dto.exercise.problem_statement
exercise_title: str = dto.exercise.title
message = query.contents[0].textContent
exercise_title: str = dto.exercise.name
message = query.text
file_map = dto.latest_submission.repository
if not message:
raise ValueError("IrisMessage must not be empty")
response = self.pipeline.invoke(
{
"question": message,
"history": [message.__str__() for message in history],
"problem_statement": problem_statement,
"file_content": "", # TODO add file selector pipeline and get file content
"file_map": file_map,
"exercise_title": exercise_title,
}
)
logger.debug(f"Response from tutor chat pipeline: {response}")
return MessageDTO(role=IrisMessageRole.ASSISTANT, text=response)
# TODO: Convert response to status update
4 changes: 2 additions & 2 deletions app/pipeline/prompts/file_selector_prompt.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
Select a file from the file list below that is mentioned in the build logs. If no file in the list is mentioned in the build logs, answer with an empty string.

Here are the paths of all files:
{repository}
{file_names}

Build logs:
{build_log}
{build_logs}

{format_instructions}

Expand Down
6 changes: 3 additions & 3 deletions app/pipeline/prompts/iris_tutor_chat_prompt.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ A: I am Iris, the AI programming tutor integrated into Artemis, the online learn

Consider the following exercise context:
- Title: {exercise_title}
- Problem Statement: {summary}
- Exercise skeleton code in markdown format:
```java
- Problem Statement: {problem_statement}
- Exercise skeleton code in Markdown format:
```[{programming_language}]
{file_content}
```

Expand Down

0 comments on commit cda7905

Please sign in to comment.