From e35b407f516e62d041d3f53735b3e789f345733b Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Tue, 29 Oct 2024 18:59:26 +1000 Subject: [PATCH 01/12] initial assistant client draft --- .../models/_openai/_assistants_client.py | 327 ++++++++++++++++++ 1 file changed, 327 insertions(+) create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/_openai/_assistants_client.py diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_openai/_assistants_client.py b/python/packages/autogen-ext/src/autogen_ext/models/_openai/_assistants_client.py new file mode 100644 index 00000000000..4afac803f1f --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/_openai/_assistants_client.py @@ -0,0 +1,327 @@ +# packages/autogen-ext/src/autogen_ext/models/_openai/_openai_assistant_client.py + +import asyncio +import logging +import os +from typing import Any, AsyncIterable, Callable, Dict, List, Mapping, Optional, Sequence, Union + +import aiofiles +from autogen_core.application.logging import EVENT_LOGGER_NAME +from autogen_core.base import CancellationToken +from autogen_core.components import Image +from autogen_core.components.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + LLMMessage, + ModelCapabilities, + RequestUsage, + SystemMessage, + UserMessage, +) +from autogen_core.components.tools import Tool, ToolSchema +from openai import AsyncAssistantEventHandler, AsyncClient +from openai.types.beta.assistant import Assistant +from openai.types.beta.thread import ToolResources +from openai.types.beta.threads import ( + ImageFileContent, + MessageContent, + TextContent, +) +from openai.types.beta.threads import Message as OAIMessage +from openai.types.beta.threads.runs import ( + Run as OAIRun, +) +from openai.types.beta.threads.runs import ( + RunEvent, + RunEventText, +) +from pydantic import BaseModel + +logger = logging.getLogger(EVENT_LOGGER_NAME) + + +# Helper functions to convert between autogen and OpenAI message types +def llm_message_to_oai_content(message: LLMMessage) -> List[MessageContent]: + contents = [] + if isinstance(message, UserMessage): + if isinstance(message.content, str): + contents.append(TextContent(type="text", text={"value": message.content})) + elif isinstance(message.content, list): + for part in message.content: + if isinstance(part, str): + contents.append(TextContent(type="text", text={"value": part})) + elif isinstance(part, Image): + # Convert Image to ImageFileContent + contents.append( + ImageFileContent( + type="image_file", + image_file={"file_id": part.data_uri}, # Assuming data_uri contains file_id + ) + ) + else: + raise ValueError(f"Unsupported content type: {type(part)}") + else: + raise ValueError(f"Unsupported content type: {type(message.content)}") + elif isinstance(message, SystemMessage): + # System messages can be added as assistant messages with role 'system' + contents.append(TextContent(type="text", text={"value": message.content})) + elif isinstance(message, AssistantMessage): + if isinstance(message.content, str): + contents.append(TextContent(type="text", text={"value": message.content})) + elif isinstance(message.content, list): + # Handle function calls if needed + contents.append(TextContent(type="text", text={"value": str(message.content)})) + else: + raise ValueError(f"Unsupported assistant content type: {type(message.content)}") + else: + raise ValueError(f"Unsupported message type: {type(message)}") + return contents + + +class BaseOpenAIAssistantClient(ChatCompletionClient): + 
def __init__( + self, + client: AsyncClient, + assistant_id: str, + thread_id: Optional[str] = None, + model_capabilities: Optional[ModelCapabilities] = None, + ): + self._client = client + self._assistant_id = assistant_id + self._thread_id = thread_id # If None, a new thread will be created + self._model_capabilities = model_capabilities or ModelCapabilities( + function_calling=True, + vision=False, + json_output=True, + ) + self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + self._assistant: Optional[Assistant] = None + + async def _ensure_thread(self) -> str: + if self._thread_id is None: + thread = await self._client.beta.threads.create() + self._thread_id = thread.id + return self._thread_id + + async def reset_thread(self, cancellation_token: Optional[CancellationToken] = None): + """Reset the thread by deleting all messages in the thread.""" + thread_id = await self._ensure_thread() + all_msgs: List[str] = [] + while True: + if not all_msgs: + list_future = asyncio.ensure_future(self._client.beta.threads.messages.list(thread_id=thread_id)) + else: + list_future = asyncio.ensure_future( + self._client.beta.threads.messages.list(thread_id=thread_id, after=all_msgs[-1]) + ) + if cancellation_token: + cancellation_token.link_future(list_future) + msgs = await list_future + all_msgs.extend(msg.id for msg in msgs.data) + if not msgs.has_next_page(): + break + for msg_id in all_msgs: + delete_future = asyncio.ensure_future( + self._client.beta.threads.messages.delete(thread_id=thread_id, message_id=msg_id) + ) + if cancellation_token: + cancellation_token.link_future(delete_future) + await delete_future + + async def upload_file_to_code_interpreter( + self, file_path: str, cancellation_token: Optional[CancellationToken] = None + ): + """Upload a file to the code interpreter and update the thread.""" + thread_id = await self._ensure_thread() + # Get the file content + async with aiofiles.open(file_path, mode="rb") as f: + read_future = asyncio.ensure_future(f.read()) + if cancellation_token: + cancellation_token.link_future(read_future) + file_content = await read_future + file_name = os.path.basename(file_path) + # Upload the file + file_future = asyncio.ensure_future( + self._client.files.create(file=(file_name, file_content), purpose="assistants") + ) + if cancellation_token: + cancellation_token.link_future(file_future) + file = await file_future + # Get existing file ids from tool resources + retrieve_future = asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=thread_id)) + if cancellation_token: + cancellation_token.link_future(retrieve_future) + thread = await retrieve_future + tool_resources: ToolResources = thread.tool_resources if thread.tool_resources else ToolResources() + if tool_resources.code_interpreter and tool_resources.code_interpreter.file_ids: + file_ids = tool_resources.code_interpreter.file_ids + [file.id] + else: + file_ids = [file.id] + # Update thread with new file + update_future = asyncio.ensure_future( + self._client.beta.threads.update( + thread_id=thread_id, + tool_resources={ + "code_interpreter": {"file_ids": file_ids}, + }, + ) + ) + if cancellation_token: + cancellation_token.link_future(update_future) + await update_future + + async def upload_file_to_vector_store( + self, file_path: str, vector_store_id: str, cancellation_token: Optional[CancellationToken] = None + ): + """Upload a file to the vector store.""" + # Get the file content + async 
with aiofiles.open(file_path, mode="rb") as f: + read_future = asyncio.ensure_future(f.read()) + if cancellation_token: + cancellation_token.link_future(read_future) + file_content = await read_future + file_name = os.path.basename(file_path) + # Upload the file + upload_future = asyncio.ensure_future( + self._client.beta.vector_stores.file_batches.upload_and_poll( + vector_store_id=vector_store_id, + files=[(file_name, file_content)], + ) + ) + if cancellation_token: + cancellation_token.link_future(upload_future) + await upload_future + + async def create( + self, + messages: Sequence[LLMMessage], + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + ) -> CreateResult: + thread_id = await self._ensure_thread() + # Send messages to the thread + for message in messages: + contents = llm_message_to_oai_content(message) + await self._client.beta.threads.messages.create( + thread_id=thread_id, + content=contents, + role=message.__class__.__name__.lower(), # 'user', 'assistant', etc. + metadata={"sender": message.source} if hasattr(message, "source") else {}, + ) + # Run the assistant + run_future = asyncio.ensure_future( + self._client.beta.threads.create_and_run( + assistant_id=self._assistant_id, + thread={"id": thread_id}, + ) + ) + if cancellation_token is not None: + cancellation_token.link_future(run_future) + _: OAIRun = await run_future + # Get the last message + messages_result = await self._client.beta.threads.messages.list(thread_id=thread_id, order="desc", limit=1) + last_message = messages_result.data[0] + # Extract content + content = "" + for part in last_message.content: + if part.type == "text": + content += part.text.value + # Handle other content types if necessary + # Create usage data (Note: OpenAI Assistant API might not provide token usage directly) + usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + result = CreateResult( + finish_reason="stop", + content=content, + usage=usage, + cached=False, + ) + return result + + async def create_stream( + self, + messages: Sequence[LLMMessage], + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + assistant_event_handler_factory: Optional[Callable[[], AsyncAssistantEventHandler]] = None, + ) -> AsyncIterable[Union[str, CreateResult]]: + thread_id = await self._ensure_thread() + # Send messages to the thread + for message in messages: + contents = llm_message_to_oai_content(message) + await self._client.beta.threads.messages.create( + thread_id=thread_id, + content=contents, + role=message.__class__.__name__.lower(), + metadata={"sender": message.source} if hasattr(message, "source") else {}, + ) + # Run the assistant with streaming + if assistant_event_handler_factory: + event_handler = assistant_event_handler_factory() + else: + event_handler = AsyncAssistantEventHandler() # default handler + stream_manager = self._client.beta.threads.create_and_run_stream( + assistant_id=self._assistant_id, + thread={"id": thread_id}, + event_handler=event_handler, + ) + stream = stream_manager.stream() + if cancellation_token is not None: + cancellation_token.link_future(asyncio.ensure_future(stream_manager.wait_until_done())) + content = "" + async for event in stream: + if isinstance(event, RunEventText): + content += event.text.value + yield event.text.value + # Handle 
other event types if necessary + # After the stream is done, create the final result + usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + result = CreateResult( + finish_reason="stop", + content=content, + usage=usage, + cached=False, + ) + yield result + + def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + # Implement token counting logic if possible + raise NotImplementedError("Token counting is not supported for OpenAI Assistant API Client") + + def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + # Implement remaining tokens logic if possible + raise NotImplementedError("Remaining tokens are not supported for OpenAI Assistant API Client") + + def actual_usage(self) -> RequestUsage: + return self._actual_usage + + def total_usage(self) -> RequestUsage: + return self._total_usage + + @property + def capabilities(self) -> ModelCapabilities: + return self._model_capabilities + + +class OpenAIAssistantClient(BaseOpenAIAssistantClient): + def __init__( + self, + client: AsyncClient, + assistant_id: str, + thread_id: Optional[str] = None, + model_capabilities: Optional[ModelCapabilities] = None, + ): + super().__init__(client, assistant_id, thread_id, model_capabilities) + + @classmethod + def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient: + client = AsyncClient(**config.get("client_kwargs", {})) + assistant_id = config["assistant_id"] + thread_id = config.get("thread_id") + model_capabilities = config.get("model_capabilities") + return cls(client, assistant_id, thread_id, model_capabilities) From 29ddbc1ac98c20b85a4a87160a8b884a732c222a Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Tue, 29 Oct 2024 19:07:06 +1000 Subject: [PATCH 02/12] expose assistants client --- python/packages/autogen-ext/src/autogen_ext/models/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/__init__.py index e7b2b76ae36..993343ed657 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/__init__.py @@ -1,3 +1,4 @@ +from ._openai._assistants_client import OpenAIAssistantClient from ._openai._openai_client import ( AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient, @@ -6,4 +7,5 @@ __all__ = [ "AzureOpenAIChatCompletionClient", "OpenAIChatCompletionClient", + "OpenAIAssistantClient", ] From c75cb3162ad62d8da8b76d1ea9c357569d004078 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Mon, 11 Nov 2024 16:51:43 +1000 Subject: [PATCH 03/12] initial openai assistant agentchat draft --- .../src/autogen_agentchat/agents/__init__.py | 3 + .../agents/_openai_assistant_agent.py | 212 ++++++++++++ .../src/autogen_ext/models/__init__.py | 2 - .../models/_openai/_assistants_client.py | 327 ------------------ 4 files changed, 215 insertions(+), 329 deletions(-) create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py delete mode 100644 python/packages/autogen-ext/src/autogen_ext/models/_openai/_assistants_client.py diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index 2f32588604e..e7744719c5f 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ 
b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -2,6 +2,7 @@ from ._base_chat_agent import BaseChatAgent from ._code_executor_agent import CodeExecutorAgent from ._coding_assistant_agent import CodingAssistantAgent +from ._openai_assistant_agent import OpenAIAssistantChatAgent, OpenAIAssistantEventHandler from ._tool_use_assistant_agent import ToolUseAssistantAgent __all__ = [ @@ -9,5 +10,7 @@ "AssistantAgent", "CodeExecutorAgent", "CodingAssistantAgent", + "OpenAIAssistantChatAgent", + "OpenAIAssistantEventHandler", "ToolUseAssistantAgent", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py new file mode 100644 index 00000000000..d8bad2cd69c --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py @@ -0,0 +1,212 @@ +import asyncio +import os +from typing import Callable, List, Optional, Sequence + +import aiofiles +from autogen_core.base import CancellationToken +from autogen_agentchat.messages import ChatMessage, TextMessage +from openai import AsyncAssistantEventHandler, AsyncClient +from openai.types.beta.thread import ToolResources, ToolResourcesFileSearch +from openai.types.beta.threads import Message, Text, TextDelta +from openai.types.beta.threads.runs import RunStep, RunStepDelta +from typing_extensions import override + +from ._base_chat_agent import BaseChatAgent + + +class OpenAIAssistantChatAgent(BaseChatAgent): + """An agent implementation that uses the OpenAI Assistant API to generate responses.""" + + def __init__( + self, + name: str, + description: str, + client: AsyncClient, + model: str, + instructions: str, + tools: Optional[List[dict]] = None, + ) -> None: + super().__init__(name, description) + if tools is None: + tools = [] + self._client = client + self._assistant = None + self._thread = None + self._model = model + self._instructions = instructions + self._tools = tools + + async def _ensure_initialized(self): + """Ensure assistant and thread are created.""" + if self._assistant is None: + self._assistant = await self._client.beta.assistants.create( + model=self._model, + description=self.description, + instructions=self._instructions, + tools=self._tools, + ) + + if self._thread is None: + self._thread = await self._client.beta.threads.create() + + @property + def _assistant_id(self) -> str: + if self._assistant is None: + raise ValueError("Assistant not initialized") + return self._assistant.id + + @property + def _thread_id(self) -> str: + if self._thread is None: + raise ValueError("Thread not initialized") + return self._thread.id + + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: + """Handle incoming messages and return a response message.""" + await self._ensure_initialized() + + for message in messages: + content = message.content.strip().lower() + if content == "reset": + await self.on_reset(cancellation_token) + elif content.startswith("upload_code "): + file_path = message.content[len("upload_code ") :].strip() + await self.on_upload_for_code_interpreter(file_path, cancellation_token) + elif content.startswith("upload_search "): + file_path = message.content[len("upload_search ") :].strip() + await self.on_upload_for_file_search(file_path, cancellation_token) + else: + await self.handle_text_message(message.content, cancellation_token) + + # Create and start 
a run + run = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.runs.create( + thread_id=self._thread_id, + assistant_id=self._assistant_id, + ) + ) + ) + + # Wait for run completion by polling + while run.status == "queued" or run.status == "in_progress": + run = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.runs.retrieve( + thread_id=self._thread_id, + run_id=run.id, + ) + ) + ) + await asyncio.sleep(0.5) + + # Get messages after run completion + messages = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.messages.list( + thread_id=self._thread_id, + order="desc", + limit=1 + ) + ) + ) + + if not messages.data: + raise ValueError("No messages received from assistant") + + # Get the last message's content + last_message = messages.data[0] + if not last_message.content: + raise ValueError(f"No content in the last message: {last_message}") + + # Extract text content + text_content = [content for content in last_message.content if content.type == "text"] + if not text_content: + raise ValueError(f"Expected text content in the last message: {last_message.content}") + + # Return the assistant's response as a ChatMessage + return TextMessage(source=self.name, content=text_content[0].text.value) + + async def handle_text_message(self, content: str, cancellation_token: CancellationToken) -> None: + """Handle regular text messages by adding them to the thread.""" + await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.messages.create( + thread_id=self._thread_id, + content=content, + role="user", + ) + ) + ) + + async def on_reset(self, cancellation_token: CancellationToken): + """Handle reset command by deleting all messages in the thread.""" + # Retrieve all message IDs in the thread + all_msgs = [] + after = None + while True: + msgs = await cancellation_token.link_future( + asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id, after=after)) + ) + for msg in msgs.data: + all_msgs.append(msg.id) + after = msg.id + if not msgs.has_next_page(): + break + + # Delete all messages + for msg_id in all_msgs: + status = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) + ) + ) + assert status.deleted is True + + async def on_upload_for_code_interpreter(self, file_path: str, cancellation_token: CancellationToken): + """Handle file uploads for the code interpreter.""" + # Read the file content + async with aiofiles.open(file_path, mode="rb") as f: + file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) + file_name = os.path.basename(file_path) + + # Upload the file + file = await cancellation_token.link_future( + asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants")) + ) + + # Update thread with the new file + thread = await cancellation_token.link_future( + asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id)) + ) + tool_resources: ToolResources = thread.tool_resources or ToolResources() + code_interpreter = tool_resources.code_interpreter or ToolResourcesFileSearch() + file_ids = code_interpreter.file_ids or [] + file_ids.append(file.id) + tool_resources.code_interpreter = ToolResourcesFileSearch(file_ids=file_ids) + + await cancellation_token.link_future( + asyncio.ensure_future( + 
self._client.beta.threads.update( + thread_id=self._thread_id, + tool_resources=tool_resources, + ) + ) + ) + + async def on_upload_for_file_search(self, file_path: str, cancellation_token: CancellationToken): + """Handle file uploads for file search.""" + # Read the file content + async with aiofiles.open(file_path, mode="rb") as f: + file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) + file_name = os.path.basename(file_path) + + # Upload the file to the vector store + await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.vector_stores.file_batches.upload_and_poll( + vector_store_id=self._vector_store_id, + files=[(file_name, file_content)], + ) + ) + ) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/__init__.py index 993343ed657..e7b2b76ae36 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/__init__.py @@ -1,4 +1,3 @@ -from ._openai._assistants_client import OpenAIAssistantClient from ._openai._openai_client import ( AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient, @@ -7,5 +6,4 @@ __all__ = [ "AzureOpenAIChatCompletionClient", "OpenAIChatCompletionClient", - "OpenAIAssistantClient", ] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_openai/_assistants_client.py b/python/packages/autogen-ext/src/autogen_ext/models/_openai/_assistants_client.py deleted file mode 100644 index 4afac803f1f..00000000000 --- a/python/packages/autogen-ext/src/autogen_ext/models/_openai/_assistants_client.py +++ /dev/null @@ -1,327 +0,0 @@ -# packages/autogen-ext/src/autogen_ext/models/_openai/_openai_assistant_client.py - -import asyncio -import logging -import os -from typing import Any, AsyncIterable, Callable, Dict, List, Mapping, Optional, Sequence, Union - -import aiofiles -from autogen_core.application.logging import EVENT_LOGGER_NAME -from autogen_core.base import CancellationToken -from autogen_core.components import Image -from autogen_core.components.models import ( - AssistantMessage, - ChatCompletionClient, - CreateResult, - LLMMessage, - ModelCapabilities, - RequestUsage, - SystemMessage, - UserMessage, -) -from autogen_core.components.tools import Tool, ToolSchema -from openai import AsyncAssistantEventHandler, AsyncClient -from openai.types.beta.assistant import Assistant -from openai.types.beta.thread import ToolResources -from openai.types.beta.threads import ( - ImageFileContent, - MessageContent, - TextContent, -) -from openai.types.beta.threads import Message as OAIMessage -from openai.types.beta.threads.runs import ( - Run as OAIRun, -) -from openai.types.beta.threads.runs import ( - RunEvent, - RunEventText, -) -from pydantic import BaseModel - -logger = logging.getLogger(EVENT_LOGGER_NAME) - - -# Helper functions to convert between autogen and OpenAI message types -def llm_message_to_oai_content(message: LLMMessage) -> List[MessageContent]: - contents = [] - if isinstance(message, UserMessage): - if isinstance(message.content, str): - contents.append(TextContent(type="text", text={"value": message.content})) - elif isinstance(message.content, list): - for part in message.content: - if isinstance(part, str): - contents.append(TextContent(type="text", text={"value": part})) - elif isinstance(part, Image): - # Convert Image to ImageFileContent - contents.append( - ImageFileContent( - type="image_file", - image_file={"file_id": part.data_uri}, 
# Assuming data_uri contains file_id - ) - ) - else: - raise ValueError(f"Unsupported content type: {type(part)}") - else: - raise ValueError(f"Unsupported content type: {type(message.content)}") - elif isinstance(message, SystemMessage): - # System messages can be added as assistant messages with role 'system' - contents.append(TextContent(type="text", text={"value": message.content})) - elif isinstance(message, AssistantMessage): - if isinstance(message.content, str): - contents.append(TextContent(type="text", text={"value": message.content})) - elif isinstance(message.content, list): - # Handle function calls if needed - contents.append(TextContent(type="text", text={"value": str(message.content)})) - else: - raise ValueError(f"Unsupported assistant content type: {type(message.content)}") - else: - raise ValueError(f"Unsupported message type: {type(message)}") - return contents - - -class BaseOpenAIAssistantClient(ChatCompletionClient): - def __init__( - self, - client: AsyncClient, - assistant_id: str, - thread_id: Optional[str] = None, - model_capabilities: Optional[ModelCapabilities] = None, - ): - self._client = client - self._assistant_id = assistant_id - self._thread_id = thread_id # If None, a new thread will be created - self._model_capabilities = model_capabilities or ModelCapabilities( - function_calling=True, - vision=False, - json_output=True, - ) - self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) - self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) - self._assistant: Optional[Assistant] = None - - async def _ensure_thread(self) -> str: - if self._thread_id is None: - thread = await self._client.beta.threads.create() - self._thread_id = thread.id - return self._thread_id - - async def reset_thread(self, cancellation_token: Optional[CancellationToken] = None): - """Reset the thread by deleting all messages in the thread.""" - thread_id = await self._ensure_thread() - all_msgs: List[str] = [] - while True: - if not all_msgs: - list_future = asyncio.ensure_future(self._client.beta.threads.messages.list(thread_id=thread_id)) - else: - list_future = asyncio.ensure_future( - self._client.beta.threads.messages.list(thread_id=thread_id, after=all_msgs[-1]) - ) - if cancellation_token: - cancellation_token.link_future(list_future) - msgs = await list_future - all_msgs.extend(msg.id for msg in msgs.data) - if not msgs.has_next_page(): - break - for msg_id in all_msgs: - delete_future = asyncio.ensure_future( - self._client.beta.threads.messages.delete(thread_id=thread_id, message_id=msg_id) - ) - if cancellation_token: - cancellation_token.link_future(delete_future) - await delete_future - - async def upload_file_to_code_interpreter( - self, file_path: str, cancellation_token: Optional[CancellationToken] = None - ): - """Upload a file to the code interpreter and update the thread.""" - thread_id = await self._ensure_thread() - # Get the file content - async with aiofiles.open(file_path, mode="rb") as f: - read_future = asyncio.ensure_future(f.read()) - if cancellation_token: - cancellation_token.link_future(read_future) - file_content = await read_future - file_name = os.path.basename(file_path) - # Upload the file - file_future = asyncio.ensure_future( - self._client.files.create(file=(file_name, file_content), purpose="assistants") - ) - if cancellation_token: - cancellation_token.link_future(file_future) - file = await file_future - # Get existing file ids from tool resources - retrieve_future = 
asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=thread_id)) - if cancellation_token: - cancellation_token.link_future(retrieve_future) - thread = await retrieve_future - tool_resources: ToolResources = thread.tool_resources if thread.tool_resources else ToolResources() - if tool_resources.code_interpreter and tool_resources.code_interpreter.file_ids: - file_ids = tool_resources.code_interpreter.file_ids + [file.id] - else: - file_ids = [file.id] - # Update thread with new file - update_future = asyncio.ensure_future( - self._client.beta.threads.update( - thread_id=thread_id, - tool_resources={ - "code_interpreter": {"file_ids": file_ids}, - }, - ) - ) - if cancellation_token: - cancellation_token.link_future(update_future) - await update_future - - async def upload_file_to_vector_store( - self, file_path: str, vector_store_id: str, cancellation_token: Optional[CancellationToken] = None - ): - """Upload a file to the vector store.""" - # Get the file content - async with aiofiles.open(file_path, mode="rb") as f: - read_future = asyncio.ensure_future(f.read()) - if cancellation_token: - cancellation_token.link_future(read_future) - file_content = await read_future - file_name = os.path.basename(file_path) - # Upload the file - upload_future = asyncio.ensure_future( - self._client.beta.vector_stores.file_batches.upload_and_poll( - vector_store_id=vector_store_id, - files=[(file_name, file_content)], - ) - ) - if cancellation_token: - cancellation_token.link_future(upload_future) - await upload_future - - async def create( - self, - messages: Sequence[LLMMessage], - tools: Sequence[Tool | ToolSchema] = [], - json_output: Optional[bool] = None, - extra_create_args: Mapping[str, Any] = {}, - cancellation_token: Optional[CancellationToken] = None, - ) -> CreateResult: - thread_id = await self._ensure_thread() - # Send messages to the thread - for message in messages: - contents = llm_message_to_oai_content(message) - await self._client.beta.threads.messages.create( - thread_id=thread_id, - content=contents, - role=message.__class__.__name__.lower(), # 'user', 'assistant', etc. 
- metadata={"sender": message.source} if hasattr(message, "source") else {}, - ) - # Run the assistant - run_future = asyncio.ensure_future( - self._client.beta.threads.create_and_run( - assistant_id=self._assistant_id, - thread={"id": thread_id}, - ) - ) - if cancellation_token is not None: - cancellation_token.link_future(run_future) - _: OAIRun = await run_future - # Get the last message - messages_result = await self._client.beta.threads.messages.list(thread_id=thread_id, order="desc", limit=1) - last_message = messages_result.data[0] - # Extract content - content = "" - for part in last_message.content: - if part.type == "text": - content += part.text.value - # Handle other content types if necessary - # Create usage data (Note: OpenAI Assistant API might not provide token usage directly) - usage = RequestUsage(prompt_tokens=0, completion_tokens=0) - result = CreateResult( - finish_reason="stop", - content=content, - usage=usage, - cached=False, - ) - return result - - async def create_stream( - self, - messages: Sequence[LLMMessage], - tools: Sequence[Tool | ToolSchema] = [], - json_output: Optional[bool] = None, - extra_create_args: Mapping[str, Any] = {}, - cancellation_token: Optional[CancellationToken] = None, - assistant_event_handler_factory: Optional[Callable[[], AsyncAssistantEventHandler]] = None, - ) -> AsyncIterable[Union[str, CreateResult]]: - thread_id = await self._ensure_thread() - # Send messages to the thread - for message in messages: - contents = llm_message_to_oai_content(message) - await self._client.beta.threads.messages.create( - thread_id=thread_id, - content=contents, - role=message.__class__.__name__.lower(), - metadata={"sender": message.source} if hasattr(message, "source") else {}, - ) - # Run the assistant with streaming - if assistant_event_handler_factory: - event_handler = assistant_event_handler_factory() - else: - event_handler = AsyncAssistantEventHandler() # default handler - stream_manager = self._client.beta.threads.create_and_run_stream( - assistant_id=self._assistant_id, - thread={"id": thread_id}, - event_handler=event_handler, - ) - stream = stream_manager.stream() - if cancellation_token is not None: - cancellation_token.link_future(asyncio.ensure_future(stream_manager.wait_until_done())) - content = "" - async for event in stream: - if isinstance(event, RunEventText): - content += event.text.value - yield event.text.value - # Handle other event types if necessary - # After the stream is done, create the final result - usage = RequestUsage(prompt_tokens=0, completion_tokens=0) - result = CreateResult( - finish_reason="stop", - content=content, - usage=usage, - cached=False, - ) - yield result - - def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: - # Implement token counting logic if possible - raise NotImplementedError("Token counting is not supported for OpenAI Assistant API Client") - - def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: - # Implement remaining tokens logic if possible - raise NotImplementedError("Remaining tokens are not supported for OpenAI Assistant API Client") - - def actual_usage(self) -> RequestUsage: - return self._actual_usage - - def total_usage(self) -> RequestUsage: - return self._total_usage - - @property - def capabilities(self) -> ModelCapabilities: - return self._model_capabilities - - -class OpenAIAssistantClient(BaseOpenAIAssistantClient): - def __init__( - self, - client: AsyncClient, - 
assistant_id: str, - thread_id: Optional[str] = None, - model_capabilities: Optional[ModelCapabilities] = None, - ): - super().__init__(client, assistant_id, thread_id, model_capabilities) - - @classmethod - def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient: - client = AsyncClient(**config.get("client_kwargs", {})) - assistant_id = config["assistant_id"] - thread_id = config.get("thread_id") - model_capabilities = config.get("model_capabilities") - return cls(client, assistant_id, thread_id, model_capabilities) From 41219afa5589ecd09d4c14184c222339e7c8a7ea Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Fri, 15 Nov 2024 15:09:37 +1000 Subject: [PATCH 04/12] update file search --- .../src/autogen_agentchat/agents/__init__.py | 3 +- .../agents/_openai_assistant_agent.py | 106 ++++++++++++------ 2 files changed, 73 insertions(+), 36 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index 3c11745b947..0f10859f80a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -2,8 +2,8 @@ from ._base_chat_agent import BaseChatAgent from ._code_executor_agent import CodeExecutorAgent from ._coding_assistant_agent import CodingAssistantAgent -from ._openai_assistant_agent import OpenAIAssistantChatAgent, OpenAIAssistantEventHandler from ._society_of_mind_agent import SocietyOfMindAgent +from ._openai_assistant_agent import OpenAIAssistantChatAgent from ._tool_use_assistant_agent import ToolUseAssistantAgent __all__ = [ @@ -13,7 +13,6 @@ "CodeExecutorAgent", "CodingAssistantAgent", "OpenAIAssistantChatAgent", - "OpenAIAssistantEventHandler", "ToolUseAssistantAgent", "SocietyOfMindAgent", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py index d8bad2cd69c..bc18248141c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py @@ -1,15 +1,15 @@ import asyncio import os -from typing import Callable, List, Optional, Sequence +from typing import Iterable, Optional, Sequence import aiofiles from autogen_core.base import CancellationToken from autogen_agentchat.messages import ChatMessage, TextMessage -from openai import AsyncAssistantEventHandler, AsyncClient +from openai import AsyncClient from openai.types.beta.thread import ToolResources, ToolResourcesFileSearch -from openai.types.beta.threads import Message, Text, TextDelta -from openai.types.beta.threads.runs import RunStep, RunStepDelta -from typing_extensions import override +from openai.types.beta.threads import Run +from openai.types.beta.assistant_tool_param import AssistantToolParam +from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam from ._base_chat_agent import BaseChatAgent @@ -24,7 +24,13 @@ def __init__( client: AsyncClient, model: str, instructions: str, - tools: Optional[List[dict]] = None, + tools: Optional[Iterable[AssistantToolParam]] = None, + assistant_id: Optional[str] = None, + metadata: Optional[object] = None, + response_format: Optional[AssistantResponseFormatOptionParam] = None, + temperature: Optional[float] = None, + 
tool_resources: Optional[dict] = None, + top_p: Optional[float] = None, ) -> None: super().__init__(name, description) if tools is None: @@ -35,27 +41,42 @@ def __init__( self._model = model self._instructions = instructions self._tools = tools + self._assistant_id = assistant_id + self._metadata = metadata + self._response_format = response_format + self._temperature = temperature + self._tool_resources = tool_resources + self._top_p = top_p + self._vector_store_id = None async def _ensure_initialized(self): """Ensure assistant and thread are created.""" if self._assistant is None: - self._assistant = await self._client.beta.assistants.create( - model=self._model, - description=self.description, - instructions=self._instructions, - tools=self._tools, - ) - + if self._assistant_id: + self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id) + else: + self._assistant = await self._client.beta.assistants.create( + model=self._model, + description=self.description, + instructions=self._instructions, + tools=self._tools, + metadata=self._metadata, + response_format=self._response_format, + temperature=self._temperature, + tool_resources=self._tool_resources, + top_p=self._top_p, + ) + if self._thread is None: self._thread = await self._client.beta.threads.create() @property - def _assistant_id(self) -> str: + def _get_assistant_id(self) -> str: if self._assistant is None: raise ValueError("Assistant not initialized") return self._assistant.id - @property + @property def _thread_id(self) -> str: if self._thread is None: raise ValueError("Thread not initialized") @@ -65,25 +86,16 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: """Handle incoming messages and return a response message.""" await self._ensure_initialized() - for message in messages: - content = message.content.strip().lower() - if content == "reset": - await self.on_reset(cancellation_token) - elif content.startswith("upload_code "): - file_path = message.content[len("upload_code ") :].strip() - await self.on_upload_for_code_interpreter(file_path, cancellation_token) - elif content.startswith("upload_search "): - file_path = message.content[len("upload_search ") :].strip() - await self.on_upload_for_file_search(file_path, cancellation_token) - else: - await self.handle_text_message(message.content, cancellation_token) + # Only process the last message and rely on the thread for context + message = messages[-1] + await self.handle_text_message(message.content, cancellation_token) # Create and start a run - run = await cancellation_token.link_future( + run: Run = await cancellation_token.link_future( asyncio.ensure_future( self._client.beta.threads.runs.create( thread_id=self._thread_id, - assistant_id=self._assistant_id, + assistant_id=self._get_assistant_id, ) ) ) @@ -100,17 +112,18 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: ) await asyncio.sleep(0.5) + if run.status == "failed": + raise ValueError(f"Run failed: {run.last_error}") + # Get messages after run completion messages = await cancellation_token.link_future( asyncio.ensure_future( - self._client.beta.threads.messages.list( - thread_id=self._thread_id, - order="desc", - limit=1 - ) + self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1) ) ) + breakpoint() + if not messages.data: raise ValueError("No messages received from assistant") @@ -196,6 +209,31 @@ async def on_upload_for_code_interpreter(self, file_path: str, 
cancellation_toke async def on_upload_for_file_search(self, file_path: str, cancellation_token: CancellationToken): """Handle file uploads for file search.""" + await self._ensure_initialized() + + # Check if file_search is enabled in tools + if not any(tool.get("type") == "file_search" for tool in self._tools): + raise ValueError( + "File search is not enabled for this assistant. Add a file_search tool when creating the assistant." + ) + + # Create vector store if not already created + if self._vector_store_id is None: + vector_store = await cancellation_token.link_future( + asyncio.ensure_future(self._client.beta.vector_stores.create()) + ) + self._vector_store_id = vector_store.id + + # Update assistant with vector store ID + await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.assistants.update( + assistant_id=self._get_assistant_id, + tool_resources={"file_search": {"vector_store_ids": [self._vector_store_id]}}, + ) + ) + ) + # Read the file content async with aiofiles.open(file_path, mode="rb") as f: file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) From 19b1fba3e1fc09954b3154ea112ccbabe893658a Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Sat, 16 Nov 2024 13:17:46 +1000 Subject: [PATCH 05/12] add delete methods and fix typing --- .../src/autogen_agentchat/agents/__init__.py | 4 +- .../agents/_openai_assistant_agent.py | 204 +++++++++++++----- 2 files changed, 148 insertions(+), 60 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index 0f10859f80a..600842010af 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -2,8 +2,8 @@ from ._base_chat_agent import BaseChatAgent from ._code_executor_agent import CodeExecutorAgent from ._coding_assistant_agent import CodingAssistantAgent +from ._openai_assistant_agent import OpenAIAssistantAgent from ._society_of_mind_agent import SocietyOfMindAgent -from ._openai_assistant_agent import OpenAIAssistantChatAgent from ._tool_use_assistant_agent import ToolUseAssistantAgent __all__ = [ @@ -12,7 +12,7 @@ "Handoff", "CodeExecutorAgent", "CodingAssistantAgent", - "OpenAIAssistantChatAgent", + "OpenAIAssistantAgent", "ToolUseAssistantAgent", "SocietyOfMindAgent", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py index bc18248141c..972c0d55b2f 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py @@ -1,20 +1,54 @@ import asyncio +import logging import os -from typing import Iterable, Optional, Sequence +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, cast import aiofiles from autogen_core.base import CancellationToken -from autogen_agentchat.messages import ChatMessage, TextMessage -from openai import AsyncClient -from openai.types.beta.thread import ToolResources, ToolResourcesFileSearch -from openai.types.beta.threads import Run -from openai.types.beta.assistant_tool_param import AssistantToolParam +from autogen_core.components.tools import Tool +from openai import NOT_GIVEN, AsyncClient, NotGiven +from openai.pagination import 
AsyncCursorPage +from openai.types import FileObject +from openai.types.beta import thread_update_params +from openai.types.beta.assistant import Assistant from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam +from openai.types.beta.assistant_tool_param import AssistantToolParam +from openai.types.beta.function_tool_param import FunctionToolParam +from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter +from openai.types.beta.threads import Message, MessageDeleted, Run +from openai.types.beta.vector_store import VectorStore +from openai.types.shared_params.function_definition import FunctionDefinition + +from autogen_agentchat.messages import ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage +from .. import EVENT_LOGGER_NAME +from ..base import Response from ._base_chat_agent import BaseChatAgent +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam: + """Convert an autogen Tool to an OpenAI Assistant function tool parameter.""" + schema = tool.schema + parameters: Dict[str, object] = {} + if "parameters" in schema: + parameters = { + "type": schema["parameters"]["type"], + "properties": schema["parameters"]["properties"], + } + if "required" in schema["parameters"]: + parameters["required"] = schema["parameters"]["required"] -class OpenAIAssistantChatAgent(BaseChatAgent): + function_def = FunctionDefinition( + name=schema["name"], + description=schema.get("description", ""), + parameters=parameters, + ) + return FunctionToolParam(type="function", function=function_def) + + +class OpenAIAssistantAgent(BaseChatAgent): """An agent implementation that uses the OpenAI Assistant API to generate responses.""" def __init__( @@ -24,32 +58,42 @@ def __init__( client: AsyncClient, model: str, instructions: str, - tools: Optional[Iterable[AssistantToolParam]] = None, + tools: Optional[Iterable[AssistantToolParam | Tool]] = None, assistant_id: Optional[str] = None, metadata: Optional[object] = None, response_format: Optional[AssistantResponseFormatOptionParam] = None, temperature: Optional[float] = None, - tool_resources: Optional[dict] = None, + tool_resources: Optional[ToolResources] = None, top_p: Optional[float] = None, ) -> None: super().__init__(name, description) if tools is None: tools = [] + + # Convert autogen Tools to OpenAI Assistant tools + converted_tools: List[AssistantToolParam] = [] + for tool in tools: + if isinstance(tool, Tool): + converted_tools.append(_convert_tool_to_function_param(tool)) + else: + converted_tools.append(tool) + self._client = client - self._assistant = None - self._thread = None + self._assistant: Optional[Assistant] = None + self._thread: Optional[Thread] = None self._model = model self._instructions = instructions - self._tools = tools + self._tools = converted_tools self._assistant_id = assistant_id self._metadata = metadata self._response_format = response_format self._temperature = temperature self._tool_resources = tool_resources self._top_p = top_p - self._vector_store_id = None + self._vector_store_id: Optional[str] = None + self._uploaded_file_ids: List[str] = [] - async def _ensure_initialized(self): + async def _ensure_initialized(self) -> None: """Ensure assistant and thread are created.""" if self._assistant is None: if self._assistant_id: @@ -61,9 +105,9 @@ async def _ensure_initialized(self): instructions=self._instructions, tools=self._tools, metadata=self._metadata, - 
response_format=self._response_format, + response_format=self._response_format if self._response_format else NOT_GIVEN, # type: ignore temperature=self._temperature, - tool_resources=self._tool_resources, + tool_resources=self._tool_resources if self._tool_resources else NOT_GIVEN, # type: ignore top_p=self._top_p, ) @@ -82,13 +126,16 @@ def _thread_id(self) -> str: raise ValueError("Thread not initialized") return self._thread.id - async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: - """Handle incoming messages and return a response message.""" + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + """Handle incoming messages and return a response.""" await self._ensure_initialized() - # Only process the last message and rely on the thread for context - message = messages[-1] - await self.handle_text_message(message.content, cancellation_token) + # Process all messages in sequence + for message in messages: + if isinstance(message, (TextMessage, MultiModalMessage)): + await self.handle_text_message(str(message.content), cancellation_token) + elif isinstance(message, (StopMessage, HandoffMessage)): + await self.handle_text_message(message.content, cancellation_token) # Create and start a run run: Run = await cancellation_token.link_future( @@ -116,19 +163,17 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: raise ValueError(f"Run failed: {run.last_error}") # Get messages after run completion - messages = await cancellation_token.link_future( + assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future( asyncio.ensure_future( self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1) ) ) - breakpoint() - - if not messages.data: + if not assistant_messages.data: raise ValueError("No messages received from assistant") # Get the last message's content - last_message = messages.data[0] + last_message = assistant_messages.data[0] if not last_message.content: raise ValueError(f"No content in the last message: {last_message}") @@ -137,8 +182,9 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: if not text_content: raise ValueError(f"Expected text content in the last message: {last_message.content}") - # Return the assistant's response as a ChatMessage - return TextMessage(source=self.name, content=text_content[0].text.value) + # Return the assistant's response as a Response + chat_message = TextMessage(source=self.name, content=text_content[0].text.value) + return Response(chat_message=chat_message) async def handle_text_message(self, content: str, cancellation_token: CancellationToken) -> None: """Handle regular text messages by adding them to the thread.""" @@ -152,13 +198,13 @@ async def handle_text_message(self, content: str, cancellation_token: Cancellati ) ) - async def on_reset(self, cancellation_token: CancellationToken): + async def on_reset(self, cancellation_token: CancellationToken) -> None: """Handle reset command by deleting all messages in the thread.""" # Retrieve all message IDs in the thread - all_msgs = [] - after = None + all_msgs: List[str] = [] + after: str | NotGiven = NOT_GIVEN while True: - msgs = await cancellation_token.link_future( + msgs: AsyncCursorPage[Message] = await cancellation_token.link_future( asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id, after=after)) ) for msg in msgs.data: @@ 
-169,46 +215,62 @@ async def on_reset(self, cancellation_token: CancellationToken): # Delete all messages for msg_id in all_msgs: - status = await cancellation_token.link_future( + status: MessageDeleted = await cancellation_token.link_future( asyncio.ensure_future( self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) ) ) assert status.deleted is True - async def on_upload_for_code_interpreter(self, file_path: str, cancellation_token: CancellationToken): + async def on_upload_for_code_interpreter( + self, file_paths: str | Iterable[str], cancellation_token: CancellationToken + ) -> None: """Handle file uploads for the code interpreter.""" - # Read the file content - async with aiofiles.open(file_path, mode="rb") as f: - file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) - file_name = os.path.basename(file_path) - - # Upload the file - file = await cancellation_token.link_future( - asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants")) - ) + if isinstance(file_paths, str): + file_paths = [file_paths] + + file_ids: List[str] = [] + for file_path in file_paths: + # Read the file content + async with aiofiles.open(file_path, mode="rb") as f: + file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) + file_name = os.path.basename(file_path) + + # Upload the file + file: FileObject = await cancellation_token.link_future( + asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants")) + ) + file_ids.append(file.id) + self._uploaded_file_ids.append(file.id) - # Update thread with the new file + # Update thread with the new files thread = await cancellation_token.link_future( asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id)) ) tool_resources: ToolResources = thread.tool_resources or ToolResources() - code_interpreter = tool_resources.code_interpreter or ToolResourcesFileSearch() - file_ids = code_interpreter.file_ids or [] - file_ids.append(file.id) - tool_resources.code_interpreter = ToolResourcesFileSearch(file_ids=file_ids) + code_interpreter: ToolResourcesCodeInterpreter = ( + tool_resources.code_interpreter or ToolResourcesCodeInterpreter() + ) + existing_file_ids: List[str] = code_interpreter.file_ids or [] + existing_file_ids.extend(file_ids) + tool_resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=existing_file_ids) await cancellation_token.link_future( asyncio.ensure_future( self._client.beta.threads.update( thread_id=self._thread_id, - tool_resources=tool_resources, + tool_resources=cast(thread_update_params.ToolResources, tool_resources.model_dump()), ) ) ) - async def on_upload_for_file_search(self, file_path: str, cancellation_token: CancellationToken): + async def on_upload_for_file_search( + self, file_paths: str | Iterable[str], cancellation_token: CancellationToken + ) -> None: """Handle file uploads for file search.""" + if isinstance(file_paths, str): + file_paths = [file_paths] + await self._ensure_initialized() # Check if file_search is enabled in tools @@ -219,7 +281,7 @@ async def on_upload_for_file_search(self, file_path: str, cancellation_token: Ca # Create vector store if not already created if self._vector_store_id is None: - vector_store = await cancellation_token.link_future( + vector_store: VectorStore = await cancellation_token.link_future( asyncio.ensure_future(self._client.beta.vector_stores.create()) ) self._vector_store_id = 
vector_store.id @@ -234,17 +296,43 @@ async def on_upload_for_file_search(self, file_path: str, cancellation_token: Ca ) ) - # Read the file content - async with aiofiles.open(file_path, mode="rb") as f: - file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) - file_name = os.path.basename(file_path) + # Read and prepare all files + files_to_upload: List[Tuple[str, bytes]] = [] + for file_path in file_paths: + async with aiofiles.open(file_path, mode="rb") as f: + file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) + file_name = os.path.basename(file_path) + files_to_upload.append((file_name, file_content)) - # Upload the file to the vector store - await cancellation_token.link_future( + # Upload all files to the vector store + batch = await cancellation_token.link_future( asyncio.ensure_future( self._client.beta.vector_stores.file_batches.upload_and_poll( vector_store_id=self._vector_store_id, - files=[(file_name, file_content)], + files=files_to_upload, ) ) ) + # Store file IDs from the batch + if batch.file_ids: + self._uploaded_file_ids.extend(batch.file_ids) + + async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None: + """Delete all files that were uploaded by this agent instance.""" + for file_id in self._uploaded_file_ids: + try: + await cancellation_token.link_future(asyncio.ensure_future(self._client.files.delete(file_id=file_id))) + except Exception as e: + event_logger.error(f"Failed to delete file {file_id}: {str(e)}") + self._uploaded_file_ids = [] + + async def delete_assistant(self, cancellation_token: CancellationToken) -> None: + """Delete the assistant if it was created by this instance.""" + if self._assistant is not None and not self._assistant_id: + try: + await cancellation_token.link_future( + asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id)) + ) + self._assistant = None + except Exception as e: + event_logger.error(f"Failed to delete assistant: {str(e)}") From 08ab8ec7cad35542a013c435de0da48c661596b0 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Sat, 16 Nov 2024 19:31:49 +1000 Subject: [PATCH 06/12] add tool execution --- .../agents/_openai_assistant_agent.py | 138 +++++++++++++----- 1 file changed, 105 insertions(+), 33 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py index 972c0d55b2f..2f3b4c277b5 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py @@ -1,10 +1,13 @@ import asyncio +import json import logging import os -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, cast +from typing import Dict, Iterable, List, Optional, Sequence, cast import aiofiles from autogen_core.base import CancellationToken +from autogen_core.components import FunctionCall +from autogen_core.components.models._types import FunctionExecutionResult from autogen_core.components.tools import Tool from openai import NOT_GIVEN, AsyncClient, NotGiven from openai.pagination import AsyncCursorPage @@ -19,7 +22,16 @@ from openai.types.beta.vector_store import VectorStore from openai.types.shared_params.function_definition import FunctionDefinition -from autogen_agentchat.messages import ChatMessage, HandoffMessage, 
MultiModalMessage, StopMessage, TextMessage +from autogen_agentchat.messages import ( + AgentMessage, + ChatMessage, + HandoffMessage, + MultiModalMessage, + StopMessage, + TextMessage, + ToolCallMessage, + ToolCallResultMessage, +) from .. import EVENT_LOGGER_NAME from ..base import Response @@ -70,10 +82,12 @@ def __init__( if tools is None: tools = [] - # Convert autogen Tools to OpenAI Assistant tools + # Store original tools and converted tools separately + self._original_tools: List[Tool] = [] converted_tools: List[AssistantToolParam] = [] for tool in tools: if isinstance(tool, Tool): + self._original_tools.append(tool) converted_tools.append(_convert_tool_to_function_param(tool)) else: converted_tools.append(tool) @@ -114,6 +128,11 @@ async def _ensure_initialized(self) -> None: if self._thread is None: self._thread = await self._client.beta.threads.create() + @property + def produced_message_types(self) -> List[type[ChatMessage]]: + """The types of messages that the assistant agent produces.""" + return [TextMessage] + @property def _get_assistant_id(self) -> str: if self._assistant is None: @@ -126,6 +145,20 @@ def _thread_id(self) -> str: raise ValueError("Thread not initialized") return self._thread.id + async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str: + """Execute a tool call and return the result.""" + try: + if not self._original_tools: + raise ValueError("No tools are available.") + tool = next((t for t in self._original_tools if t.name == tool_call.name), None) + if tool is None: + raise ValueError(f"The tool '{tool_call.name}' is not available.") + arguments = json.loads(tool_call.arguments) + result = await tool.run_json(arguments, cancellation_token) + return tool.return_value_as_string(result) + except Exception as e: + return f"Error: {e}" + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: """Handle incoming messages and return a response.""" await self._ensure_initialized() @@ -137,6 +170,9 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: elif isinstance(message, (StopMessage, HandoffMessage)): await self.handle_text_message(message.content, cancellation_token) + # Inner messages for tool calls + inner_messages: List[AgentMessage] = [] + # Create and start a run run: Run = await cancellation_token.link_future( asyncio.ensure_future( @@ -148,7 +184,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: ) # Wait for run completion by polling - while run.status == "queued" or run.status == "in_progress": + while True: run = await cancellation_token.link_future( asyncio.ensure_future( self._client.beta.threads.runs.retrieve( @@ -157,10 +193,55 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: ) ) ) - await asyncio.sleep(0.5) - if run.status == "failed": - raise ValueError(f"Run failed: {run.last_error}") + if run.status == "failed": + raise ValueError(f"Run failed: {run.last_error}") + + # If the run requires action (function calls), execute tools and continue + if run.status == "requires_action" and run.required_action is not None: + tool_calls: List[FunctionCall] = [] + for required_tool_call in run.required_action.submit_tool_outputs.tool_calls: + if required_tool_call.type == "function": + tool_calls.append( + FunctionCall( + id=required_tool_call.id, + name=required_tool_call.function.name, + arguments=required_tool_call.function.arguments, + ) + ) 
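# A minimal sketch of the run-driving protocol the hunk above implements:
# poll the run, execute any requested function calls locally, submit the
# outputs back, and keep polling until the run completes. The names here
# (`drive_run`, `execute`) are illustrative and not part of the patch.
import asyncio
import json
from typing import Awaitable, Callable

from openai import AsyncClient


async def drive_run(
    client: AsyncClient,
    thread_id: str,
    run_id: str,
    execute: Callable[[str, dict], Awaitable[str]],
) -> None:
    while True:
        run = await client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
        if run.status == "failed":
            raise ValueError(f"Run failed: {run.last_error}")
        if run.status == "requires_action" and run.required_action is not None:
            # Execute each requested function call and pair the result with its call id.
            outputs = []
            for call in run.required_action.submit_tool_outputs.tool_calls:
                if call.type == "function":
                    result = await execute(call.function.name, json.loads(call.function.arguments))
                    outputs.append({"tool_call_id": call.id, "output": result})
            # Submitting the outputs resumes the run; keep polling afterwards.
            run = await client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread_id, run_id=run_id, tool_outputs=outputs
            )
            continue
        if run.status == "completed":
            return
        await asyncio.sleep(0.5)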
+ + # Add tool call message to inner messages + tool_call_msg = ToolCallMessage(source=self.name, content=tool_calls) + inner_messages.append(tool_call_msg) + event_logger.debug(tool_call_msg) + + # Execute tool calls and get results + tool_outputs: List[FunctionExecutionResult] = [] + for tool_call in tool_calls: + result = await self._execute_tool_call(tool_call, cancellation_token) + tool_outputs.append(FunctionExecutionResult(content=result, call_id=tool_call.id)) + + # Add tool result message to inner messages + tool_result_msg = ToolCallResultMessage(source=self.name, content=tool_outputs) + inner_messages.append(tool_result_msg) + event_logger.debug(tool_result_msg) + + # Submit tool outputs back to the run + run = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.runs.submit_tool_outputs( + thread_id=self._thread_id, + run_id=run.id, + tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs], + ) + ) + ) + continue + + if run.status == "completed": + break + + await asyncio.sleep(0.5) # Get messages after run completion assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future( @@ -182,9 +263,9 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: if not text_content: raise ValueError(f"Expected text content in the last message: {last_message.content}") - # Return the assistant's response as a Response + # Return the assistant's response as a Response with inner messages chat_message = TextMessage(source=self.name, content=text_content[0].text.value) - return Response(chat_message=chat_message) + return Response(chat_message=chat_message, inner_messages=inner_messages) async def handle_text_message(self, content: str, cancellation_token: CancellationToken) -> None: """Handle regular text messages by adding them to the thread.""" @@ -222,27 +303,31 @@ async def on_reset(self, cancellation_token: CancellationToken) -> None: ) assert status.deleted is True - async def on_upload_for_code_interpreter( - self, file_paths: str | Iterable[str], cancellation_token: CancellationToken - ) -> None: - """Handle file uploads for the code interpreter.""" + async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]: + """Upload files and return their IDs.""" if isinstance(file_paths, str): file_paths = [file_paths] file_ids: List[str] = [] for file_path in file_paths: - # Read the file content async with aiofiles.open(file_path, mode="rb") as f: file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) file_name = os.path.basename(file_path) - # Upload the file file: FileObject = await cancellation_token.link_future( asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants")) ) file_ids.append(file.id) self._uploaded_file_ids.append(file.id) + return file_ids + + async def on_upload_for_code_interpreter( + self, file_paths: str | Iterable[str], cancellation_token: CancellationToken + ) -> None: + """Handle file uploads for the code interpreter.""" + file_ids = await self._upload_files(file_paths, cancellation_token) + # Update thread with the new files thread = await cancellation_token.link_future( asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id)) @@ -268,9 +353,6 @@ async def on_upload_for_file_search( self, file_paths: str | Iterable[str], cancellation_token: CancellationToken ) -> None: """Handle file uploads for 
file search.""" - if isinstance(file_paths, str): - file_paths = [file_paths] - await self._ensure_initialized() # Check if file_search is enabled in tools @@ -296,26 +378,16 @@ async def on_upload_for_file_search( ) ) - # Read and prepare all files - files_to_upload: List[Tuple[str, bytes]] = [] - for file_path in file_paths: - async with aiofiles.open(file_path, mode="rb") as f: - file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) - file_name = os.path.basename(file_path) - files_to_upload.append((file_name, file_content)) + file_ids = await self._upload_files(file_paths, cancellation_token) - # Upload all files to the vector store - batch = await cancellation_token.link_future( + # Create file batch with the file IDs + await cancellation_token.link_future( asyncio.ensure_future( - self._client.beta.vector_stores.file_batches.upload_and_poll( - vector_store_id=self._vector_store_id, - files=files_to_upload, + self._client.beta.vector_stores.file_batches.create_and_poll( + vector_store_id=self._vector_store_id, file_ids=file_ids ) ) ) - # Store file IDs from the batch - if batch.file_ids: - self._uploaded_file_ids.extend(batch.file_ids) async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None: """Delete all files that were uploaded by this agent instance.""" From 76351d93f81aff6e981c50e739a64ca93e8b430b Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Sun, 17 Nov 2024 14:19:01 +1000 Subject: [PATCH 07/12] fix tool call and add docstring --- .../agents/_openai_assistant_agent.py | 117 ++++++++++++++++-- 1 file changed, 105 insertions(+), 12 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py index 2f3b4c277b5..c59187efe16 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py @@ -2,13 +2,14 @@ import json import logging import os -from typing import Dict, Iterable, List, Optional, Sequence, cast +from typing import Dict, Generic, Iterable, List, Literal, Optional, Sequence, TypedDict, TypeVar, cast import aiofiles from autogen_core.base import CancellationToken from autogen_core.components import FunctionCall from autogen_core.components.models._types import FunctionExecutionResult from autogen_core.components.tools import Tool +from autogen_core.components.tools._base import BaseTool from openai import NOT_GIVEN, AsyncClient, NotGiven from openai.pagination import AsyncCursorPage from openai.types import FileObject @@ -16,11 +17,15 @@ from openai.types.beta.assistant import Assistant from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam from openai.types.beta.assistant_tool_param import AssistantToolParam +from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam +from openai.types.beta.file_search_tool_param import FileSearchToolParam from openai.types.beta.function_tool_param import FunctionToolParam from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter from openai.types.beta.threads import Message, MessageDeleted, Run from openai.types.beta.vector_store import VectorStore from openai.types.shared_params.function_definition import FunctionDefinition +from pydantic import BaseModel +from typing_extensions import 
Required from autogen_agentchat.messages import ( AgentMessage, @@ -38,11 +43,19 @@ from ._base_chat_agent import BaseChatAgent event_logger = logging.getLogger(EVENT_LOGGER_NAME) +ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True) +ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True) -def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam: +class BaseToolParam(TypedDict, Generic[ArgsT, ReturnT], total=False): + tool: Required[BaseTool[ArgsT, ReturnT]] + + type: Required[Literal["tool"]] + + +def _convert_tool_to_function_param(tool_param: BaseToolParam[ArgsT, ReturnT]) -> FunctionToolParam: """Convert an autogen Tool to an OpenAI Assistant function tool parameter.""" - schema = tool.schema + schema = tool_param["tool"].schema parameters: Dict[str, object] = {} if "parameters" in schema: parameters = { @@ -61,7 +74,72 @@ def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam: class OpenAIAssistantAgent(BaseChatAgent): - """An agent implementation that uses the OpenAI Assistant API to generate responses.""" + """An agent implementation that uses the OpenAI Assistant API to generate responses. + + This agent leverages the OpenAI Assistant API to create AI assistants with capabilities like: + - Code interpretation and execution + - File handling and search + - Custom function calling + - Multi-turn conversations + + The agent maintains a thread of conversation and can use various tools including: + - Code interpreter: For executing code and working with files + - File search: For searching through uploaded documents + - Custom functions: For extending capabilities with user-defined tools + + Key Features: + - Supports multiple file formats including code, documents, images + - Can handle up to 128 tools per assistant + - Maintains conversation context in threads + - Supports file uploads for code interpreter and search + - Vector store integration for efficient file search + - Automatic file parsing and embedding + + Example: + .. code-block:: python + + from openai import AsyncClient + from autogen_agentchat.agents import OpenAIAssistantAgent + + # Create an OpenAI client + client = AsyncClient(api_key="your-api-key", base_url="your-base-url") + + # Create an assistant with code interpreter + assistant = OpenAIAssistantAgent( + name="Python Helper", + description="Helps with Python programming", + client=client, + model="gpt-4", + instructions="You are a helpful Python programming assistant.", + tools=[{"type": "code_interpreter"}], + ) + + # Upload files for the assistant to use + await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token) + + # Get response from the assistant + response = await assistant.on_messages( + [TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token + ) + + # Clean up resources + await assistant.delete_uploaded_files(cancellation_token) + await assistant.delete_assistant(cancellation_token) + + Args: + name (str): Name of the assistant + description (str): Description of the assistant's purpose + client (AsyncClient): OpenAI API client instance + model (str): Model to use (e.g. 
"gpt-4") + instructions (str): System instructions for the assistant + tools (Optional[Iterable[CodeInterpreterToolParam | FileSearchToolParam | BaseToolParam[ArgsT, ReturnT]]]): Tools the assistant can use + assistant_id (Optional[str]): ID of existing assistant to use + metadata (Optional[object]): Additional metadata for the assistant + response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings + temperature (Optional[float]): Temperature for response generation + tool_resources (Optional[ToolResources]): Additional tool configuration + top_p (Optional[float]): Top p sampling parameter + """ def __init__( self, @@ -70,7 +148,9 @@ def __init__( client: AsyncClient, model: str, instructions: str, - tools: Optional[Iterable[AssistantToolParam | Tool]] = None, + tools: Optional[ + Iterable[CodeInterpreterToolParam | FileSearchToolParam | BaseToolParam[ArgsT, ReturnT]] + ] = None, assistant_id: Optional[str] = None, metadata: Optional[object] = None, response_format: Optional[AssistantResponseFormatOptionParam] = None, @@ -86,18 +166,20 @@ def __init__( self._original_tools: List[Tool] = [] converted_tools: List[AssistantToolParam] = [] for tool in tools: - if isinstance(tool, Tool): - self._original_tools.append(tool) - converted_tools.append(_convert_tool_to_function_param(tool)) + if tool.get("type") == "tool": + base_tool = cast(BaseToolParam[ArgsT, ReturnT], tool) + self._original_tools.append(base_tool["tool"]) + converted_tools.append(_convert_tool_to_function_param(base_tool)) else: - converted_tools.append(tool) + # Not runtime checked but excluded base tool in the previous condition + converted_tools.append(cast(AssistantToolParam, tool)) self._client = client self._assistant: Optional[Assistant] = None self._thread: Optional[Thread] = None self._model = model self._instructions = instructions - self._tools = converted_tools + self._api_tools = converted_tools self._assistant_id = assistant_id self._metadata = metadata self._response_format = response_format @@ -117,7 +199,7 @@ async def _ensure_initialized(self) -> None: model=self._model, description=self.description, instructions=self._instructions, - tools=self._tools, + tools=self._api_tools, metadata=self._metadata, response_format=self._response_format if self._response_format else NOT_GIVEN, # type: ignore temperature=self._temperature, @@ -356,7 +438,7 @@ async def on_upload_for_file_search( await self._ensure_initialized() # Check if file_search is enabled in tools - if not any(tool.get("type") == "file_search" for tool in self._tools): + if not any(tool.get("type") == "file_search" for tool in self._api_tools): raise ValueError( "File search is not enabled for this assistant. Add a file_search tool when creating the assistant." 
) @@ -408,3 +490,14 @@ async def delete_assistant(self, cancellation_token: CancellationToken) -> None: self._assistant = None except Exception as e: event_logger.error(f"Failed to delete assistant: {str(e)}") + + async def delete_vector_store(self, cancellation_token: CancellationToken) -> None: + """Delete the vector store if it was created by this instance.""" + if self._vector_store_id is not None: + try: + await cancellation_token.link_future( + asyncio.ensure_future(self._client.beta.vector_stores.delete(vector_store_id=self._vector_store_id)) + ) + self._vector_store_id = None + except Exception as e: + event_logger.error(f"Failed to delete vector store: {str(e)}") From e1c91fd8d7c4afa539885b3b6a4a95efdfdf0455 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Mon, 18 Nov 2024 15:26:54 +1000 Subject: [PATCH 08/12] abstract tools and support thread management --- .../agents/_openai_assistant_agent.py | 83 ++++++++++++++----- 1 file changed, 61 insertions(+), 22 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py index c59187efe16..1ef81a9b9d0 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py @@ -2,16 +2,29 @@ import json import logging import os -from typing import Dict, Generic, Iterable, List, Literal, Optional, Sequence, TypedDict, TypeVar, cast +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + TypeVar, + Union, + cast, +) import aiofiles from autogen_core.base import CancellationToken from autogen_core.components import FunctionCall from autogen_core.components.models._types import FunctionExecutionResult -from autogen_core.components.tools import Tool -from autogen_core.components.tools._base import BaseTool +from autogen_core.components.tools import FunctionTool, Tool from openai import NOT_GIVEN, AsyncClient, NotGiven from openai.pagination import AsyncCursorPage +from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads from openai.types import FileObject from openai.types.beta import thread_update_params from openai.types.beta.assistant import Assistant @@ -25,7 +38,6 @@ from openai.types.beta.vector_store import VectorStore from openai.types.shared_params.function_definition import FunctionDefinition from pydantic import BaseModel -from typing_extensions import Required from autogen_agentchat.messages import ( AgentMessage, @@ -47,15 +59,9 @@ ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True) -class BaseToolParam(TypedDict, Generic[ArgsT, ReturnT], total=False): - tool: Required[BaseTool[ArgsT, ReturnT]] - - type: Required[Literal["tool"]] - - -def _convert_tool_to_function_param(tool_param: BaseToolParam[ArgsT, ReturnT]) -> FunctionToolParam: +def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam: """Convert an autogen Tool to an OpenAI Assistant function tool parameter.""" - schema = tool_param["tool"].schema + schema = tool.schema parameters: Dict[str, object] = {} if "parameters" in schema: parameters = { @@ -111,7 +117,7 @@ class OpenAIAssistantAgent(BaseChatAgent): client=client, model="gpt-4", instructions="You are a helpful Python programming assistant.", - tools=[{"type": "code_interpreter"}], + tools=["code_interpreter"], ) # Upload files for 
the assistant to use @@ -132,7 +138,7 @@ class OpenAIAssistantAgent(BaseChatAgent): client (AsyncClient): OpenAI API client instance model (str): Model to use (e.g. "gpt-4") instructions (str): System instructions for the assistant - tools (Optional[Iterable[CodeInterpreterToolParam | FileSearchToolParam | BaseToolParam[ArgsT, ReturnT]]]): Tools the assistant can use + tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use assistant_id (Optional[str]): ID of existing assistant to use metadata (Optional[object]): Additional metadata for the assistant response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings @@ -149,9 +155,15 @@ def __init__( model: str, instructions: str, tools: Optional[ - Iterable[CodeInterpreterToolParam | FileSearchToolParam | BaseToolParam[ArgsT, ReturnT]] + Iterable[ + Union[ + Literal["code_interpreter", "file_search"], + Tool | Callable[..., Any] | Callable[..., Awaitable[Any]], + ] + ] ] = None, assistant_id: Optional[str] = None, + thread_id: Optional[str] = None, metadata: Optional[object] = None, response_format: Optional[AssistantResponseFormatOptionParam] = None, temperature: Optional[float] = None, @@ -166,17 +178,29 @@ def __init__( self._original_tools: List[Tool] = [] converted_tools: List[AssistantToolParam] = [] for tool in tools: - if tool.get("type") == "tool": - base_tool = cast(BaseToolParam[ArgsT, ReturnT], tool) - self._original_tools.append(base_tool["tool"]) - converted_tools.append(_convert_tool_to_function_param(base_tool)) + if isinstance(tool, str): + if tool == "code_interpreter": + converted_tools.append(CodeInterpreterToolParam(type="code_interpreter")) + elif tool == "file_search": + converted_tools.append(FileSearchToolParam(type="file_search")) + elif isinstance(tool, Tool): + self._original_tools.append(tool) + converted_tools.append(_convert_tool_to_function_param(tool)) + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + function_tool = FunctionTool(tool, description=description) + self._original_tools.append(function_tool) + converted_tools.append(_convert_tool_to_function_param(function_tool)) else: - # Not runtime checked but excluded base tool in the previous condition - converted_tools.append(cast(AssistantToolParam, tool)) + raise ValueError(f"Unsupported tool type: {type(tool)}") self._client = client self._assistant: Optional[Assistant] = None self._thread: Optional[Thread] = None + self._init_thread_id = thread_id self._model = model self._instructions = instructions self._api_tools = converted_tools @@ -208,13 +232,28 @@ async def _ensure_initialized(self) -> None: ) if self._thread is None: - self._thread = await self._client.beta.threads.create() + if self._init_thread_id: + self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id) + else: + self._thread = await self._client.beta.threads.create() @property def produced_message_types(self) -> List[type[ChatMessage]]: """The types of messages that the assistant agent produces.""" return [TextMessage] + @property + def threads(self) -> AsyncThreads: + return self._client.beta.threads + + @property + def runs(self) -> AsyncRuns: + return self._client.beta.threads.runs + + @property + def messages(self) -> AsyncMessages: + return self._client.beta.threads.messages + @property def _get_assistant_id(self) -> str: if 
self._assistant is None: From b037604a4f97cc615ef456d3abee34c1155f8bf8 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Mon, 18 Nov 2024 15:27:14 +1000 Subject: [PATCH 09/12] add tests --- .../autogen-agentchat/tests/test_openai_assistant_agent.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py diff --git a/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py new file mode 100644 index 00000000000..e69de29bb2d From 4400d313f70eba87c103bc747269ca274f17e736 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Mon, 18 Nov 2024 15:33:21 +1000 Subject: [PATCH 10/12] removed unused typevars --- .../src/autogen_agentchat/agents/_openai_assistant_agent.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py index 1ef81a9b9d0..b9b8d8bf5d6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_openai_assistant_agent.py @@ -12,7 +12,6 @@ Literal, Optional, Sequence, - TypeVar, Union, cast, ) @@ -37,7 +36,6 @@ from openai.types.beta.threads import Message, MessageDeleted, Run from openai.types.beta.vector_store import VectorStore from openai.types.shared_params.function_definition import FunctionDefinition -from pydantic import BaseModel from autogen_agentchat.messages import ( AgentMessage, @@ -55,8 +53,6 @@ from ._base_chat_agent import BaseChatAgent event_logger = logging.getLogger(EVENT_LOGGER_NAME) -ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True) -ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True) def _convert_tool_to_function_param(tool: Tool) -> FunctionToolParam: From d4df3312e76e94341ab03e358d6d32336ac5f7be Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Tue, 19 Nov 2024 05:37:44 +1000 Subject: [PATCH 11/12] add unsaved test changes --- .../tests/test_openai_assistant_agent.py | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py index e69de29bb2d..73d7bcbf3cc 100644 --- a/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py @@ -0,0 +1,134 @@ +import asyncio +import os +from enum import Enum +from typing import List, Optional + +import pytest +from autogen_agentchat.agents import OpenAIAssistantAgent +from autogen_agentchat.messages import TextMessage +from autogen_core.base import CancellationToken +from autogen_core.components.tools._base import BaseTool +from openai import AsyncAzureOpenAI +from pydantic import BaseModel + + +class QuestionType(str, Enum): + MULTIPLE_CHOICE = "MULTIPLE_CHOICE" + FREE_RESPONSE = "FREE_RESPONSE" + + +class Question(BaseModel): + question_text: str + question_type: QuestionType + choices: Optional[List[str]] = None + + +class DisplayQuizArgs(BaseModel): + title: str + questions: List[Question] + + +class DisplayQuizTool(BaseTool[DisplayQuizArgs, List[str]]): + def __init__(self): + super().__init__( + args_type=DisplayQuizArgs, + return_type=List[str], + name="display_quiz", + description=( + 
"Displays a quiz to the student and returns the student's responses. " + "A single quiz can have multiple questions." + ), + ) + + async def run(self, args: DisplayQuizArgs, cancellation_token: CancellationToken) -> List[str]: + responses = [] + for q in args.questions: + if q.question_type == QuestionType.MULTIPLE_CHOICE: + response = q.choices[0] if q.choices else "" + elif q.question_type == QuestionType.FREE_RESPONSE: + response = "Sample free response" + else: + response = "" + responses.append(response) + return responses + + +@pytest.fixture +def client(): + azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") + api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview") + api_key = os.getenv("AZURE_OPENAI_API_KEY") + + if not all([azure_endpoint, api_key]): + pytest.skip("Azure OpenAI credentials not found in environment variables") + + return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key) + + +@pytest.fixture +def agent(client): + tools = [ + "code_interpreter", + "file_search", + DisplayQuizTool(), + ] + + return OpenAIAssistantAgent( + name="assistant", + instructions="Help the user with their task.", + model="gpt-4o-mini", + description="OpenAI Assistant Agent", + client=client, + tools=tools, + ) + + +@pytest.fixture +def cancellation_token(): + return CancellationToken() + + +@pytest.mark.asyncio +async def test_file_retrieval(agent, cancellation_token): + file_path = r"C:\Users\lpinheiro\Github\autogen-test\data\SampleBooks\jungle_book.txt" + await agent.on_upload_for_file_search(file_path, cancellation_token) + + message = TextMessage(source="user", content="What is the first sentence of the jungle scout book?") + response = await agent.on_messages([message], cancellation_token) + + assert response.chat_message.content is not None + assert isinstance(response.chat_message.content, str) + assert len(response.chat_message.content) > 0 + + await agent.delete_uploaded_files(cancellation_token) + await agent.delete_vector_store(cancellation_token) + await agent.delete_assistant(cancellation_token) + + +@pytest.mark.asyncio +async def test_code_interpreter(agent, cancellation_token): + message = TextMessage(source="user", content="I need to solve the equation `3x + 11 = 14`. 
Can you help me?") + response = await agent.on_messages([message], cancellation_token) + + assert response.chat_message.content is not None + assert isinstance(response.chat_message.content, str) + assert len(response.chat_message.content) > 0 + assert "x = 1" in response.chat_message.content.lower() + + await agent.delete_assistant(cancellation_token) + + +@pytest.mark.asyncio +async def test_quiz_creation(agent, cancellation_token): + message = TextMessage( + source="user", + content="Create a short quiz about basic math with one multiple choice question and one free response question.", + ) + response = await agent.on_messages([message], cancellation_token) + + assert response.chat_message.content is not None + assert isinstance(response.chat_message.content, str) + assert len(response.chat_message.content) > 0 + assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content")) + + await agent.delete_assistant(cancellation_token) From 9eca674bc495a274cfcd557fb52b8a5f99ee5b4b Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Tue, 19 Nov 2024 07:44:53 +1000 Subject: [PATCH 12/12] test typing fixes --- .../tests/test_openai_assistant_agent.py | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py index 73d7bcbf3cc..648c982bc98 100644 --- a/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_openai_assistant_agent.py @@ -1,13 +1,12 @@ -import asyncio import os from enum import Enum -from typing import List, Optional +from typing import List, Literal, Optional, Union import pytest from autogen_agentchat.agents import OpenAIAssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core.base import CancellationToken -from autogen_core.components.tools._base import BaseTool +from autogen_core.components.tools._base import BaseTool, Tool from openai import AsyncAzureOpenAI from pydantic import BaseModel @@ -28,11 +27,15 @@ class DisplayQuizArgs(BaseModel): questions: List[Question] -class DisplayQuizTool(BaseTool[DisplayQuizArgs, List[str]]): - def __init__(self): +class QuizResponses(BaseModel): + responses: List[str] + + +class DisplayQuizTool(BaseTool[DisplayQuizArgs, QuizResponses]): + def __init__(self) -> None: super().__init__( args_type=DisplayQuizArgs, - return_type=List[str], + return_type=QuizResponses, name="display_quiz", description=( "Displays a quiz to the student and returns the student's responses. 
" @@ -40,8 +43,8 @@ def __init__(self): ), ) - async def run(self, args: DisplayQuizArgs, cancellation_token: CancellationToken) -> List[str]: - responses = [] + async def run(self, args: DisplayQuizArgs, cancellation_token: CancellationToken) -> QuizResponses: + responses: List[str] = [] for q in args.questions: if q.question_type == QuestionType.MULTIPLE_CHOICE: response = q.choices[0] if q.choices else "" @@ -50,11 +53,11 @@ async def run(self, args: DisplayQuizArgs, cancellation_token: CancellationToken else: response = "" responses.append(response) - return responses + return QuizResponses(responses=responses) @pytest.fixture -def client(): +def client() -> AsyncAzureOpenAI: azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview") api_key = os.getenv("AZURE_OPENAI_API_KEY") @@ -62,12 +65,14 @@ def client(): if not all([azure_endpoint, api_key]): pytest.skip("Azure OpenAI credentials not found in environment variables") + assert azure_endpoint is not None + assert api_key is not None return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key) @pytest.fixture -def agent(client): - tools = [ +def agent(client: AsyncAzureOpenAI) -> OpenAIAssistantAgent: + tools: List[Union[Literal["code_interpreter", "file_search"], Tool]] = [ "code_interpreter", "file_search", DisplayQuizTool(), @@ -84,12 +89,12 @@ def agent(client): @pytest.fixture -def cancellation_token(): +def cancellation_token() -> CancellationToken: return CancellationToken() @pytest.mark.asyncio -async def test_file_retrieval(agent, cancellation_token): +async def test_file_retrieval(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None: file_path = r"C:\Users\lpinheiro\Github\autogen-test\data\SampleBooks\jungle_book.txt" await agent.on_upload_for_file_search(file_path, cancellation_token) @@ -106,7 +111,7 @@ async def test_file_retrieval(agent, cancellation_token): @pytest.mark.asyncio -async def test_code_interpreter(agent, cancellation_token): +async def test_code_interpreter(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None: message = TextMessage(source="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?") response = await agent.on_messages([message], cancellation_token) @@ -119,7 +124,7 @@ async def test_code_interpreter(agent, cancellation_token): @pytest.mark.asyncio -async def test_quiz_creation(agent, cancellation_token): +async def test_quiz_creation(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None: message = TextMessage( source="user", content="Create a short quiz about basic math with one multiple choice question and one free response question.", @@ -129,6 +134,7 @@ async def test_quiz_creation(agent, cancellation_token): assert response.chat_message.content is not None assert isinstance(response.chat_message.content, str) assert len(response.chat_message.content) > 0 + assert isinstance(response.inner_messages, list) assert any(tool_msg.content for tool_msg in response.inner_messages if hasattr(tool_msg, "content")) await agent.delete_assistant(cancellation_token)