Skip to content

Commit

Permalink
Exercise Chat: Implement native function calling agent (#154)
Browse files Browse the repository at this point in the history
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
kaancayli and coderabbitai[bot] authored Nov 26, 2024
1 parent 2cffe1c commit d7f5bf0
Show file tree
Hide file tree
Showing 36 changed files with 1,550 additions and 218 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

.DS_Store
1 change: 1 addition & 0 deletions app/common/PipelineEnum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class PipelineEnum(str, Enum):
IRIS_CODE_FEEDBACK = "IRIS_CODE_FEEDBACK"
IRIS_CHAT_COURSE_MESSAGE = "IRIS_CHAT_COURSE_MESSAGE"
IRIS_CHAT_EXERCISE_MESSAGE = "IRIS_CHAT_EXERCISE_MESSAGE"
IRIS_CHAT_EXERCISE_AGENT_MESSAGE = "IRIS_CHAT_EXERCISE_AGENT_MESSAGE"
IRIS_INTERACTION_SUGGESTION = "IRIS_INTERACTION_SUGGESTION"
IRIS_CHAT_LECTURE_MESSAGE = "IRIS_CHAT_LECTURE_MESSAGE"
IRIS_COMPETENCY_GENERATION = "IRIS_COMPETENCY_GENERATION"
Expand Down
4 changes: 0 additions & 4 deletions app/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
from app.common.singleton import Singleton
from app.common.message_converters import (
convert_iris_message_to_langchain_message,
convert_langchain_message_to_iris_message,
)
128 changes: 113 additions & 15 deletions app/common/message_converters.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
import json
from datetime import datetime
from typing import Literal
from typing import Literal, List

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.messages import (
BaseMessage,
HumanMessage,
AIMessage,
SystemMessage,
ToolMessage,
ToolCall,
)

from app.common.pyris_message import (
PyrisMessage,
PyrisAIMessage,
IrisMessageRole,
PyrisToolMessage,
)
from app.domain.data.text_message_content_dto import TextMessageContentDTO
from app.common.pyris_message import PyrisMessage, IrisMessageRole
from app.domain.data.tool_call_dto import ToolCallDTO, FunctionDTO
from app.domain.data.tool_message_content_dto import ToolMessageContentDTO


def convert_iris_message_to_langchain_message(
iris_message: PyrisMessage,
) -> BaseMessage:
if len(iris_message.contents) == 0:
if iris_message is None or len(iris_message.contents) == 0:
raise ValueError("IrisMessage contents must not be empty")
message = iris_message.contents[0]
# Check if the message is of type TextMessageContentDTO
Expand All @@ -20,41 +35,122 @@ def convert_iris_message_to_langchain_message(
case IrisMessageRole.USER:
return HumanMessage(content=message.text_content)
case IrisMessageRole.ASSISTANT:
if isinstance(iris_message, PyrisAIMessage):
tool_calls = [
ToolCall(
name=tc.function.name,
args=tc.function.arguments,
id=tc.id,
)
for tc in iris_message.tool_calls
]
return AIMessage(content=message.text_content, tool_calls=tool_calls)
return AIMessage(content=message.text_content)
case IrisMessageRole.SYSTEM:
return SystemMessage(content=message.text_content)
case _:
raise ValueError(f"Unknown message role: {iris_message.sender}")


def convert_iris_message_to_langchain_human_message(
iris_message: PyrisMessage,
) -> HumanMessage:
if len(iris_message.contents) == 0:
raise ValueError("IrisMessage contents must not be empty")
message = iris_message.contents[0]
# Check if the message is of type TextMessageContentDTO
if not isinstance(message, TextMessageContentDTO):
raise ValueError("Message must be of type TextMessageContentDTO")
return HumanMessage(content=message.text_content)


def extract_text_from_iris_message(iris_message: PyrisMessage) -> str:
if len(iris_message.contents) == 0:
raise ValueError("IrisMessage contents must not be empty")
message = iris_message.contents[0]
# Check if the message is of type TextMessageContentDTO
if not isinstance(message, TextMessageContentDTO):
raise ValueError("Message must be of type TextMessageContentDTO")
return message.text_content


def convert_langchain_tool_calls_to_iris_tool_calls(
tool_calls: List[ToolCall],
) -> List[ToolCallDTO]:
return [
ToolCallDTO(
function=FunctionDTO(
name=tc["name"],
arguments=json.dumps(tc["args"]),
),
id=tc["id"],
)
for tc in tool_calls
]


def convert_langchain_message_to_iris_message(
base_message: BaseMessage,
) -> PyrisMessage:
match base_message.type:
case "human":
role = IrisMessageRole.USER
case "ai":
role = IrisMessageRole.ASSISTANT
case "system":
role = IrisMessageRole.SYSTEM
case _:
raise ValueError(f"Unknown message type: {base_message.type}")
contents = [TextMessageContentDTO(textContent=base_message.content)]
type_to_role = {
"human": IrisMessageRole.USER,
"ai": IrisMessageRole.ASSISTANT,
"system": IrisMessageRole.SYSTEM,
"tool": IrisMessageRole.TOOL,
}

role = type_to_role.get(base_message.type)
if role is None:
raise ValueError(f"Unknown message type: {base_message.type}")

if isinstance(base_message, (HumanMessage, SystemMessage)):
contents = [TextMessageContentDTO(textContent=base_message.content)]
elif isinstance(base_message, AIMessage):
if base_message.tool_calls:
contents = [TextMessageContentDTO(textContent=base_message.content)]
tool_calls = convert_langchain_tool_calls_to_iris_tool_calls(
base_message.tool_calls
)
return PyrisAIMessage(
contents=contents,
tool_calls=tool_calls,
send_at=datetime.now(),
)
else:
contents = [TextMessageContentDTO(textContent=base_message.content)]
elif isinstance(base_message, ToolMessage):
contents = [
ToolMessageContentDTO(
toolContent=base_message.content,
toolName=base_message.additional_kwargs["name"],
toolCallId=base_message.tool_call_id,
)
]
return PyrisToolMessage(
contents=contents,
send_at=datetime.now(),
)
else:
raise ValueError(f"Unknown message type: {type(base_message)}")
return PyrisMessage(
contents=contents,
sender=role,
send_at=datetime.now(),
)


def map_role_to_str(role: IrisMessageRole) -> Literal["user", "assistant", "system"]:
def map_role_to_str(
role: IrisMessageRole,
) -> Literal["user", "assistant", "system", "tool"]:
match role:
case IrisMessageRole.USER:
return "user"
case IrisMessageRole.ASSISTANT:
return "assistant"
case IrisMessageRole.SYSTEM:
return "system"
case IrisMessageRole.TOOL:
return "tool"
case _:
raise ValueError(f"Unknown message role: {role}")

Expand All @@ -67,5 +163,7 @@ def map_str_to_role(role: str) -> IrisMessageRole:
return IrisMessageRole.ASSISTANT
case "system":
return IrisMessageRole.SYSTEM
case "tool":
return IrisMessageRole.TOOL
case _:
raise ValueError(f"Unknown message role: {role}")
20 changes: 18 additions & 2 deletions app/common/pyris_message.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from datetime import datetime
from enum import Enum
from typing import List
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, Field

from app.domain.data.message_content_dto import MessageContentDTO
from app.common.token_usage_dto import TokenUsageDTO
from app.domain.data.tool_call_dto import ToolCallDTO
from app.domain.data.tool_message_content_dto import ToolMessageContentDTO


class IrisMessageRole(str, Enum):
USER = "USER"
ASSISTANT = "LLM"
SYSTEM = "SYSTEM"
TOOL = "TOOL"


class PyrisMessage(BaseModel):
Expand All @@ -21,7 +24,20 @@ class PyrisMessage(BaseModel):

sent_at: datetime | None = Field(alias="sentAt", default=None)
sender: IrisMessageRole
contents: List[MessageContentDTO] = []

contents: List[MessageContentDTO] = Field(default=[])

def __str__(self):
return f"{self.sender.lower()}: {self.contents}"


class PyrisAIMessage(PyrisMessage):
model_config = ConfigDict(populate_by_name=True)
sender: IrisMessageRole = IrisMessageRole.ASSISTANT
tool_calls: Optional[List[ToolCallDTO]] = Field(alias="toolCalls")


class PyrisToolMessage(PyrisMessage):
model_config = ConfigDict(populate_by_name=True)
sender: IrisMessageRole = IrisMessageRole.TOOL
contents: List[ToolMessageContentDTO] = Field(default=[])
2 changes: 1 addition & 1 deletion app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_settings(cls):
try:
with open(file_path, "r") as file:
settings_file = yaml.safe_load(file)
return cls.parse_obj(settings_file)
return cls.model_validate(settings_file)
except FileNotFoundError as e:
raise FileNotFoundError(
f"Configuration file not found at {file_path}."
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Optional
from typing import Optional, Any

from pydantic import Field

from ..chat_pipeline_execution_dto import ChatPipelineExecutionDTO
from ...data.extended_course_dto import ExtendedCourseDTO
from ...data.metrics.competency_jol_dto import CompetencyJolDTO
from ...data.metrics.student_metrics_dto import StudentMetricsDTO
from ...event.pyris_event_dto import PyrisEventDTO


class CourseChatPipelineExecutionDTO(ChatPipelineExecutionDTO):
course: ExtendedCourseDTO
metrics: Optional[StudentMetricsDTO]
competency_jol: Optional[CompetencyJolDTO] = Field(None, alias="competencyJol")
event_payload: Optional[PyrisEventDTO[Any]] = Field(None, alias="eventPayload")
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import Optional
from typing import Optional, Any

from pydantic import Field

from app.domain.chat.chat_pipeline_execution_dto import ChatPipelineExecutionDTO
from app.domain.data.course_dto import CourseDTO
from app.domain.data.programming_exercise_dto import ProgrammingExerciseDTO
from app.domain.data.programming_submission_dto import ProgrammingSubmissionDTO
from app.domain.event.pyris_event_dto import PyrisEventDTO


class ExerciseChatPipelineExecutionDTO(ChatPipelineExecutionDTO):
submission: Optional[ProgrammingSubmissionDTO] = None
exercise: ProgrammingExerciseDTO
course: CourseDTO
event_payload: Optional[PyrisEventDTO[Any]] = Field(None, alias="eventPayload")
3 changes: 2 additions & 1 deletion app/domain/data/competency_dto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Optional
from typing import Optional, List

from pydantic import BaseModel, Field
from pydantic.v1 import validator
Expand All @@ -22,6 +22,7 @@ class CompetencyDTO(BaseModel):
taxonomy: Optional[CompetencyTaxonomy] = None
soft_due_date: Optional[datetime] = Field(default=None, alias="softDueDate")
optional: Optional[bool] = None
exercise_list: Optional[List[int]] = Field(default=[], alias="exerciseList")


class Competency(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions app/domain/data/exercise_with_submissions_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class IncludedInOverallScore(str, Enum):

class ExerciseWithSubmissionsDTO(BaseModel):
id: int = Field(alias="id")
url: Optional[str] = Field(alias="url", default=None)
title: str = Field(alias="title")
type: ExerciseType = Field(alias="type")
mode: ExerciseMode = Field(alias="mode")
Expand Down
6 changes: 5 additions & 1 deletion app/domain/data/message_content_dto.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Union

from .tool_message_content_dto import ToolMessageContentDTO
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
from ...domain.data.json_message_content_dto import JsonMessageContentDTO
from ...domain.data.text_message_content_dto import TextMessageContentDTO

MessageContentDTO = Union[
TextMessageContentDTO, ImageMessageContentDTO, JsonMessageContentDTO
TextMessageContentDTO,
ImageMessageContentDTO,
JsonMessageContentDTO,
ToolMessageContentDTO,
]
6 changes: 6 additions & 0 deletions app/domain/data/programming_exercise_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,9 @@ class ProgrammingExerciseDTO(BaseModel):
problem_statement: str = Field(alias="problemStatement", default=None)
start_date: Optional[datetime] = Field(alias="startDate", default=None)
end_date: Optional[datetime] = Field(alias="endDate", default=None)
max_points: Optional[float] = Field(alias="maxPoints", default=None)
recent_changes: Optional[str] = Field(
alias="recentChanges",
default=None,
description="Git diff of the recent changes",
)
16 changes: 16 additions & 0 deletions app/domain/data/tool_call_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Literal, Any

from pydantic import BaseModel, ConfigDict, Field, Json


class FunctionDTO(BaseModel):
name: str = Field(..., alias="name")
arguments: Json[Any] = Field(..., alias="arguments")


class ToolCallDTO(BaseModel):

model_config = ConfigDict(populate_by_name=True)
id: str = Field(alias="id")
type: Literal["function"] = "function"
function: FunctionDTO = Field(alias="function")
11 changes: 11 additions & 0 deletions app/domain/data/tool_message_content_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class ToolMessageContentDTO(BaseModel):

model_config = ConfigDict(populate_by_name=True)
name: Optional[str] = Field(alias="toolName", default="")
tool_content: str = Field(alias="toolContent")
tool_call_id: str = Field(alias="toolCallId")
10 changes: 10 additions & 0 deletions app/domain/event/pyris_event_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import TypeVar, Generic, Optional

from pydantic import Field, BaseModel

T = TypeVar("T")


class PyrisEventDTO(BaseModel, Generic[T]):
event_type: Optional[str] = Field(default=None, alias="eventType")
event: Optional[T] = Field(default=None, alias="event")
Loading

0 comments on commit d7f5bf0

Please sign in to comment.