From df9eca04d2495ca5c538714eeb87c53fbd8b7a62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kaan=20=C3=87ayl=C4=B1?= <38523756+kaancayli@users.noreply.github.com> Date: Tue, 12 Nov 2024 00:35:14 +0100 Subject: [PATCH] chore: some refactoring --- app/llm/external/openai_chat.py | 287 ++++++++++-------- .../request_handler/basic_request_handler.py | 9 + 2 files changed, 168 insertions(+), 128 deletions(-) diff --git a/app/llm/external/openai_chat.py b/app/llm/external/openai_chat.py index 8f1d49ae..1b5e475a 100644 --- a/app/llm/external/openai_chat.py +++ b/app/llm/external/openai_chat.py @@ -20,8 +20,8 @@ from pydantic import Field from pydantic.v1 import BaseModel as LegacyBaseModel -from ...common.message_converters import map_str_to_role, map_role_to_str from app.domain.data.text_message_content_dto import TextMessageContentDTO +from ...common.message_converters import map_role_to_str, map_str_to_role from ...common.pyris_message import PyrisMessage, PyrisAIMessage from ...common.token_usage_dto import TokenUsageDTO from ...domain.data.image_message_content_dto import ImageMessageContentDTO @@ -32,118 +32,171 @@ from ...llm.external.model import ChatModel -def convert_to_open_ai_messages( +def convert_content_to_openai_format(content): + """Convert a single content item to OpenAI format.""" + content_type_mapping = { + ImageMessageContentDTO: lambda c: { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{c.base64}", + "detail": "high", + }, + }, + TextMessageContentDTO: lambda c: {"type": "text", "text": c.text_content}, + JsonMessageContentDTO: lambda c: { + "type": "json_object", + "json_object": c.json_content, + }, + } + + converter = content_type_mapping.get(type(content)) + return converter(content) if converter else None + + +def handle_tool_message(content): + """Handle tool-specific message conversion.""" + if isinstance(content, ToolMessageContentDTO): + return { + "role": "tool", + "content": content.tool_content, + "tool_call_id": content.tool_call_id, + } + return None + + +def create_openai_tool_calls(tool_calls): + """Convert tool calls to OpenAI format.""" + return [ + { + "id": tool.id, + "type": tool.type, + "function": { + "name": tool.function.name, + "arguments": json.dumps(tool.function.arguments), + }, + } + for tool in tool_calls + ] + + +def convert_to_openai_messages( messages: list[PyrisMessage], ) -> list[ChatCompletionMessageParam]: """ - Convert a list of PyrisMessage to a list of ChatCompletionMessageParam + Convert a list of PyrisMessage to a list of ChatCompletionMessageParam. + + Args: + messages: List of PyrisMessage objects to convert + + Returns: + List of messages in OpenAI's format """ openai_messages = [] + for message in messages: + if message.sender == "TOOL": + # Handle tool messages + for content in message.contents: + tool_message = handle_tool_message(content) + if tool_message: + openai_messages.append(tool_message) + continue + + # Handle regular messages openai_content = [] for content in message.contents: - if message.sender == "TOOL": - match content: - case ToolMessageContentDTO(): - openai_messages.append( - { - "role": "tool", - "content": content.tool_content, - "tool_call_id": content.tool_call_id, - } - ) - case _: - pass - else: - match content: - case ImageMessageContentDTO(): - openai_content.append( - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{content.base64}", - "detail": "high", - }, - } - ) - case TextMessageContentDTO(): - openai_content.append( - {"type": "text", "text": content.text_content} - ) - case JsonMessageContentDTO(): - openai_content.append( - { - "type": "json_object", - "json_object": content.json_content, - } - ) - case _: - pass - - if isinstance(message, PyrisAIMessage) and message.tool_calls: - openai_message = { - "role": map_role_to_str(message.sender), - "content": openai_content, - "tool_calls": [ - { - "id": tool.id, - "type": tool.type, - "function": { - "name": tool.function.name, - "arguments": json.dumps(tool.function.arguments), - }, - } - for tool in message.tool_calls - ], - } - else: - openai_message = { - "role": map_role_to_str(message.sender), - "content": openai_content, - } - openai_messages.append(openai_message) + formatted_content = convert_content_to_openai_format(content) + if formatted_content: + openai_content.append(formatted_content) + + # Create the message object + openai_message = { + "role": map_role_to_str(message.sender), + "content": openai_content, + } + + # Add tool calls if present + if isinstance(message, PyrisAIMessage) and message.tool_calls: + openai_message["tool_calls"] = create_openai_tool_calls(message.tool_calls) + + openai_messages.append(openai_message) + return openai_messages +def create_token_usage(usage: Optional[CompletionUsage], model: str) -> TokenUsageDTO: + """ + Create a TokenUsageDTO from CompletionUsage data. + + Args: + usage: Optional CompletionUsage containing token counts + model: The model name used for the completion + + Returns: + TokenUsageDTO with the token usage information + """ + return TokenUsageDTO( + model=model, + numInputTokens=getattr(usage, "prompt_tokens", 0), + numOutputTokens=getattr(usage, "completion_tokens", 0), + ) + + +def create_iris_tool_calls(message_tool_calls) -> list[ToolCallDTO]: + """ + Convert OpenAI tool calls to Iris format. + + Args: + message_tool_calls: List of tool calls from ChatCompletionMessage + + Returns: + List of ToolCallDTO objects + """ + return [ + ToolCallDTO( + id=tc.id, + type=tc.type, + function={ + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + ) + for tc in message_tool_calls + ] + + def convert_to_iris_message( message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str ) -> PyrisMessage: """ - Convert a ChatCompletionMessage to a PyrisMessage + Convert a ChatCompletionMessage to a PyrisMessage. + + Args: + message: The ChatCompletionMessage to convert + usage: Optional token usage information + model: The model name used for the completion + + Returns: + PyrisMessage or PyrisAIMessage depending on presence of tool calls """ - num_input_tokens = getattr(usage, "prompt_tokens", 0) - num_output_tokens = getattr(usage, "completion_tokens", 0) - tokens = TokenUsageDTO( - model=model, - numInputTokens=num_input_tokens, - numOutputTokens=num_output_tokens, - ) + token_usage = create_token_usage(usage, model) + current_time = datetime.now() if message.tool_calls: return PyrisAIMessage( - tool_calls=[ - ToolCallDTO( - id=tc.id, - type=tc.type, - function={ - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - ) - for tc in message.tool_calls - ], + tool_calls=create_iris_tool_calls(message.tool_calls), contents=[TextMessageContentDTO(textContent="")], - sendAt=datetime.now(), - token_usage=tokens, - ) - else: - return PyrisMessage( - sender=map_str_to_role(message.role), - contents=[TextMessageContentDTO(textContent=message.content)], - sendAt=datetime.now(), - token_usage=tokens, + sendAt=current_time, + token_usage=token_usage, ) + return PyrisMessage( + sender=map_str_to_role(message.role), + contents=[TextMessageContentDTO(textContent=message.content)], + sendAt=current_time, + token_usage=token_usage, + ) + class OpenAIChatModel(ChatModel): model: str @@ -166,44 +219,22 @@ def chat( for attempt in range(retries): try: + params = { + "model": self.model, + "messages": messages, + "temperature": arguments.temperature, + "max_tokens": arguments.max_tokens, + } + if arguments.response_format == "JSON": - if self.tools: - response = client.chat.completions.create( - model=self.model, - messages=messages, - temperature=arguments.temperature, - max_tokens=arguments.max_tokens, - response_format=ResponseFormatJSONObject( - type="json_object" - ), - tools=self.tools, - ) - else: - response = client.chat.completions.create( - model=self.model, - messages=messages, - temperature=arguments.temperature, - max_tokens=arguments.max_tokens, - response_format=ResponseFormatJSONObject( - type="json_object" - ), - ) - else: - if self.tools: - response = client.chat.completions.create( - model=self.model, - messages=messages, - temperature=arguments.temperature, - max_tokens=arguments.max_tokens, - tools=self.tools, - ) - else: - response = client.chat.completions.create( - model=self.model, - messages=messages, - temperature=arguments.temperature, - max_tokens=arguments.max_tokens, - ) + params["response_format"] = ResponseFormatJSONObject( + type="json_object" + ) + + if self.tools: + params["tools"] = self.tools + + response = client.chat.completions.create(**params) choice = response.choices[0] usage = response.usage model = response.model diff --git a/app/llm/request_handler/basic_request_handler.py b/app/llm/request_handler/basic_request_handler.py index 424633ea..dafb31a5 100644 --- a/app/llm/request_handler/basic_request_handler.py +++ b/app/llm/request_handler/basic_request_handler.py @@ -42,6 +42,15 @@ def bind_tools( self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], ) -> LanguageModel: + """ + Binds a sequence of tools to the language model. + + Args: + tools (Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]]): A sequence of tools to be bound. + + Returns: + LanguageModel: The language model with tools bound. + """ llm = self.llm_manager.get_llm_by_id(self.model_id) llm.bind_tools(tools) return llm