Skip to content

Commit

Permalink
chore: some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
kaancayli committed Nov 11, 2024
1 parent 29d9252 commit df9eca0
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 128 deletions.
287 changes: 159 additions & 128 deletions app/llm/external/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from pydantic import Field
from pydantic.v1 import BaseModel as LegacyBaseModel

from ...common.message_converters import map_str_to_role, map_role_to_str
from app.domain.data.text_message_content_dto import TextMessageContentDTO
from ...common.message_converters import map_role_to_str, map_str_to_role
from ...common.pyris_message import PyrisMessage, PyrisAIMessage
from ...common.token_usage_dto import TokenUsageDTO
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
Expand All @@ -32,118 +32,171 @@
from ...llm.external.model import ChatModel


def convert_to_open_ai_messages(
def convert_content_to_openai_format(content):
"""Convert a single content item to OpenAI format."""
content_type_mapping = {
ImageMessageContentDTO: lambda c: {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{c.base64}",
"detail": "high",
},
},
TextMessageContentDTO: lambda c: {"type": "text", "text": c.text_content},
JsonMessageContentDTO: lambda c: {
"type": "json_object",
"json_object": c.json_content,
},
}

converter = content_type_mapping.get(type(content))
return converter(content) if converter else None


def handle_tool_message(content):
"""Handle tool-specific message conversion."""
if isinstance(content, ToolMessageContentDTO):
return {
"role": "tool",
"content": content.tool_content,
"tool_call_id": content.tool_call_id,
}
return None


def create_openai_tool_calls(tool_calls):
"""Convert tool calls to OpenAI format."""
return [
{
"id": tool.id,
"type": tool.type,
"function": {
"name": tool.function.name,
"arguments": json.dumps(tool.function.arguments),
},
}
for tool in tool_calls
]


def convert_to_openai_messages(
messages: list[PyrisMessage],
) -> list[ChatCompletionMessageParam]:
"""
Convert a list of PyrisMessage to a list of ChatCompletionMessageParam
Convert a list of PyrisMessage to a list of ChatCompletionMessageParam.
Args:
messages: List of PyrisMessage objects to convert
Returns:
List of messages in OpenAI's format
"""
openai_messages = []

for message in messages:
if message.sender == "TOOL":
# Handle tool messages
for content in message.contents:
tool_message = handle_tool_message(content)
if tool_message:
openai_messages.append(tool_message)
continue

# Handle regular messages
openai_content = []
for content in message.contents:
if message.sender == "TOOL":
match content:
case ToolMessageContentDTO():
openai_messages.append(
{
"role": "tool",
"content": content.tool_content,
"tool_call_id": content.tool_call_id,
}
)
case _:
pass
else:
match content:
case ImageMessageContentDTO():
openai_content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{content.base64}",
"detail": "high",
},
}
)
case TextMessageContentDTO():
openai_content.append(
{"type": "text", "text": content.text_content}
)
case JsonMessageContentDTO():
openai_content.append(
{
"type": "json_object",
"json_object": content.json_content,
}
)
case _:
pass

if isinstance(message, PyrisAIMessage) and message.tool_calls:
openai_message = {
"role": map_role_to_str(message.sender),
"content": openai_content,
"tool_calls": [
{
"id": tool.id,
"type": tool.type,
"function": {
"name": tool.function.name,
"arguments": json.dumps(tool.function.arguments),
},
}
for tool in message.tool_calls
],
}
else:
openai_message = {
"role": map_role_to_str(message.sender),
"content": openai_content,
}
openai_messages.append(openai_message)
formatted_content = convert_content_to_openai_format(content)
if formatted_content:
openai_content.append(formatted_content)

# Create the message object
openai_message = {
"role": map_role_to_str(message.sender),
"content": openai_content,
}

# Add tool calls if present
if isinstance(message, PyrisAIMessage) and message.tool_calls:
openai_message["tool_calls"] = create_openai_tool_calls(message.tool_calls)

openai_messages.append(openai_message)

return openai_messages


def create_token_usage(usage: Optional[CompletionUsage], model: str) -> TokenUsageDTO:
"""
Create a TokenUsageDTO from CompletionUsage data.
Args:
usage: Optional CompletionUsage containing token counts
model: The model name used for the completion
Returns:
TokenUsageDTO with the token usage information
"""
return TokenUsageDTO(
model=model,
numInputTokens=getattr(usage, "prompt_tokens", 0),
numOutputTokens=getattr(usage, "completion_tokens", 0),
)


def create_iris_tool_calls(message_tool_calls) -> list[ToolCallDTO]:
"""
Convert OpenAI tool calls to Iris format.
Args:
message_tool_calls: List of tool calls from ChatCompletionMessage
Returns:
List of ToolCallDTO objects
"""
return [
ToolCallDTO(
id=tc.id,
type=tc.type,
function={
"name": tc.function.name,
"arguments": tc.function.arguments,
},
)
for tc in message_tool_calls
]


def convert_to_iris_message(
message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str
) -> PyrisMessage:
"""
Convert a ChatCompletionMessage to a PyrisMessage
Convert a ChatCompletionMessage to a PyrisMessage.
Args:
message: The ChatCompletionMessage to convert
usage: Optional token usage information
model: The model name used for the completion
Returns:
PyrisMessage or PyrisAIMessage depending on presence of tool calls
"""
num_input_tokens = getattr(usage, "prompt_tokens", 0)
num_output_tokens = getattr(usage, "completion_tokens", 0)
tokens = TokenUsageDTO(
model=model,
numInputTokens=num_input_tokens,
numOutputTokens=num_output_tokens,
)
token_usage = create_token_usage(usage, model)
current_time = datetime.now()

if message.tool_calls:
return PyrisAIMessage(
tool_calls=[
ToolCallDTO(
id=tc.id,
type=tc.type,
function={
"name": tc.function.name,
"arguments": tc.function.arguments,
},
)
for tc in message.tool_calls
],
tool_calls=create_iris_tool_calls(message.tool_calls),
contents=[TextMessageContentDTO(textContent="")],
sendAt=datetime.now(),
token_usage=tokens,
)
else:
return PyrisMessage(
sender=map_str_to_role(message.role),
contents=[TextMessageContentDTO(textContent=message.content)],
sendAt=datetime.now(),
token_usage=tokens,
sendAt=current_time,
token_usage=token_usage,
)

return PyrisMessage(
sender=map_str_to_role(message.role),
contents=[TextMessageContentDTO(textContent=message.content)],
sendAt=current_time,
token_usage=token_usage,
)


class OpenAIChatModel(ChatModel):
model: str
Expand All @@ -166,44 +219,22 @@ def chat(

for attempt in range(retries):
try:
params = {
"model": self.model,
"messages": messages,
"temperature": arguments.temperature,
"max_tokens": arguments.max_tokens,
}

if arguments.response_format == "JSON":
if self.tools:
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=arguments.temperature,
max_tokens=arguments.max_tokens,
response_format=ResponseFormatJSONObject(
type="json_object"
),
tools=self.tools,
)
else:
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=arguments.temperature,
max_tokens=arguments.max_tokens,
response_format=ResponseFormatJSONObject(
type="json_object"
),
)
else:
if self.tools:
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=arguments.temperature,
max_tokens=arguments.max_tokens,
tools=self.tools,
)
else:
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=arguments.temperature,
max_tokens=arguments.max_tokens,
)
params["response_format"] = ResponseFormatJSONObject(
type="json_object"
)

if self.tools:
params["tools"] = self.tools

response = client.chat.completions.create(**params)
choice = response.choices[0]
usage = response.usage
model = response.model
Expand Down
9 changes: 9 additions & 0 deletions app/llm/request_handler/basic_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
) -> LanguageModel:
"""
Binds a sequence of tools to the language model.
Args:
tools (Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]]): A sequence of tools to be bound.
Returns:
LanguageModel: The language model with tools bound.
"""
llm = self.llm_manager.get_llm_by_id(self.model_id)
llm.bind_tools(tools)
return llm

0 comments on commit df9eca0

Please sign in to comment.