From 83859b7f00e67c7bbd1863db01ea0d9f089081c1 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 28 May 2024 09:32:49 +0400 Subject: [PATCH] Fix compat issues when initializing assistants created with v1 api, remove udpates when initializing assistant from id #128 --- agency_swarm/agency/agency.py | 30 +++++++++++++++++------------- agency_swarm/agents/agent.py | 21 +++++++++++++++++---- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 096be21f..c8020814 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -22,6 +22,7 @@ from agency_swarm.user import User from agency_swarm.util.files import determine_file_type from agency_swarm.util.shared_state import SharedState +from openai.types.beta.threads.runs.tool_call import ToolCall from agency_swarm.util.streaming import AgencyEventHandler @@ -258,21 +259,18 @@ def handle_file_upload(file_list): if file_list: try: for file_obj in file_list: + file_type = determine_file_type(file_obj.name) + purpose = "assistants" if file_type != "vision" else "vision" + tools = [{"type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [{"type": "file_search"}] + with open(file_obj.name, 'rb') as f: # Upload the file to OpenAI file = self.main_thread.client.files.create( file=f, - purpose="assistants" + purpose=purpose ) - - file_type = determine_file_type(file_obj.name) - if file_type == "assistants.code_interpreter": - attachments.append({ - "file_id": file.id, - "tools": [{"type": "code_interpreter"}] - }) - elif file_type == "vision": + if file_type == "vision": images.append({ "type": "image_file", "image_file": {"file_id": file.id} @@ -280,9 +278,9 @@ def handle_file_upload(file_list): else: attachments.append({ "file_id": file.id, - "tools": [{"type": "file_search"}] + "tools": tools }) - + message_file_names.append(file.filename) print(f"Uploaded file ID: {file.id}") return attachments @@ -358,7 +356,10 @@ def on_text_delta(self, delta, snapshot): chatbot_queue.put(delta.value) @override - def on_tool_call_created(self, tool_call): + def on_tool_call_created(self, tool_call: ToolCall): + if isinstance(tool_call, dict): + tool_call = ToolCall(**tool_call) + # TODO: add support for code interpreter and retirieval tools if tool_call.type == "function": chatbot_queue.put("[new_message]") @@ -367,7 +368,10 @@ def on_tool_call_created(self, tool_call): chatbot_queue.put(self.message_output.get_formatted_header() + "\n") @override - def on_tool_call_done(self, snapshot): + def on_tool_call_done(self, snapshot: ToolCall): + if isinstance(snapshot, dict): + snapshot = ToolCall(**snapshot) + self.message_output = None # TODO: add support for code interpreter and retirieval tools diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index 67678423..7530e431 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -194,9 +194,22 @@ def init_oai(self): self.model = self.model or self.assistant.model self.tool_resources = self.tool_resources or self.assistant.tool_resources.model_dump() - # update assistant if parameters are different - if not self._check_parameters(self.assistant.model_dump()): - self._update_assistant() + for tool in self.assistant.tools: + if tool.type == "function": + # function tools must be added manually + continue + elif tool.type == "file_search": + self.add_tool(FileSearch) + elif tool.type == "code_interpreter": + self.add_tool(CodeInterpreter) + elif tool.type == "retrieval": + self.add_tool(Retrieval) + else: + raise Exception("Invalid tool type.") + + # # update assistant if parameters are different + # if not self._check_parameters(self.assistant.model_dump()): + # self._update_assistant() return self @@ -213,7 +226,7 @@ def init_oai(self): # update assistant if parameters are different if not self._check_parameters(self.assistant.model_dump()): - print("Updating assistant... " + self.name) + print("Updating agent... " + self.name) self._update_assistant() if self.assistant.tool_resources: