Skip to content

Commit

Permalink
Fix file assignment in gradio interface #125
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed May 27, 2024
1 parent d58ee7c commit 3296f85
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 17 deletions.
76 changes: 65 additions & 11 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import queue
import threading
import time
import uuid
from enum import Enum
from typing import List, TypedDict, Callable, Any, Dict, Literal, Union, Optional
Expand All @@ -19,6 +20,7 @@
from agency_swarm.threads import Thread
from agency_swarm.tools import BaseTool, FileSearch, CodeInterpreter
from agency_swarm.user import User
from agency_swarm.util.files import determine_file_type
from agency_swarm.util.shared_state import SharedState

from agency_swarm.util.streaming import AgencyEventHandler
Expand Down Expand Up @@ -222,8 +224,10 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs):
else:
js = js.replace("{theme}", "light")

message_file_ids = []
attachments = []
images = []
message_file_names = None
uploading_files = False
recipient_agents = [agent.name for agent in self.main_recipients]
recipient_agent = self.main_recipients[0]

Expand All @@ -236,17 +240,20 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs):
value=recipient_agent.name)
msg = gr.Textbox(label="Your Message", lines=4)
with gr.Column(scale=1):
file_upload = gr.Files(label="Files", type="filepath")
file_upload = gr.Files(label="OpenAI Files", type="filepath")
button = gr.Button(value="Send", variant="primary")

def handle_dropdown_change(selected_option):
nonlocal recipient_agent
recipient_agent = self._get_agent_by_name(selected_option)

def handle_file_upload(file_list):
nonlocal message_file_ids
nonlocal attachments
nonlocal message_file_names
message_file_ids = []
nonlocal uploading_files
nonlocal images
uploading_files = True
attachments = []
message_file_names = []
if file_list:
try:
Expand All @@ -257,19 +264,46 @@ def handle_file_upload(file_list):
file=f,
purpose="assistants"
)
message_file_ids.append(file.id)

file_type = determine_file_type(file_obj.name)

if file_type == "assistants.code_interpreter":
attachments.append({
"file_id": file.id,
"tools": [{"type": "code_interpreter"}]
})
elif file_type == "vision":
images.append({
"type": "image_file",
"image_file": {"file_id": file.id}
})
else:
attachments.append({
"file_id": file.id,
"tools": [{"type": "file_search"}]
})

message_file_names.append(file.filename)
print(f"Uploaded file ID: {file.id}")
return message_file_ids
return attachments
except Exception as e:
print(f"Error: {e}")
return str(e)
finally:
uploading_files = False

uploading_files = False
return "No files uploaded"

def user(user_message, history):
if not user_message.strip():
return user_message, history

nonlocal message_file_names
nonlocal uploading_files
nonlocal images
nonlocal attachments
nonlocal recipient_agent

if history is None:
history = []
Expand Down Expand Up @@ -386,18 +420,38 @@ def bot(original_message, history):
if not original_message:
return "", history

nonlocal message_file_ids
nonlocal attachments
nonlocal message_file_names
nonlocal recipient_agent
print("Message files: ", message_file_ids)
# Replace this with your actual chatbot logic
nonlocal images
nonlocal uploading_files

if uploading_files:
history.append([None, "Uploading files... Please wait."])
yield "", history
return "", history

print("Message files: ", attachments)
print("Images: ", images)

if images and len(images) > 0:
original_message = [
{
"type": "text",
"text": original_message,
},
*images
]


completion_thread = threading.Thread(target=self.get_completion_stream, args=(
original_message, GradioEventHandler, message_file_ids, recipient_agent))
original_message, GradioEventHandler, [], recipient_agent, "", attachments, None))
completion_thread.start()

message_file_ids = []
attachments = []
message_file_names = []
images = []
uploading_files = False

new_message = True
while True:
Expand Down
2 changes: 1 addition & 1 deletion agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_completion_stream(self,
yield_messages=False)

def get_completion(self,
message: str,
message: str | List[dict],
message_files: List[str] = None,
attachments: Optional[List[dict]] = None,
recipient_agent=None,
Expand Down
3 changes: 2 additions & 1 deletion agency_swarm/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .cli.create_agent_template import create_agent_template
from .cli.import_agent import import_agent
from .oai import set_openai_key, get_openai_client, set_openai_client
from .oai import set_openai_key, get_openai_client, set_openai_client
from .files import determine_file_type
19 changes: 19 additions & 0 deletions agency_swarm/util/files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import mimetypes

def determine_file_type(file_path):
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type:
if mime_type in [
'application/json', 'text/csv', 'application/xml',
'application/vnd.ms-excel', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/zip'
]:
return "assistants.code_interpreter"
elif mime_type in [
'text/plain', 'text/markdown', 'application/pdf',
'application/msword', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
]:
return "assistants.file_search"
elif mime_type.startswith('image/'):
return "vision"
return "assistants.file_search"
Empty file added tests/demos/__init__.py
Empty file.
8 changes: 4 additions & 4 deletions tests/demos/demo_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from agency_swarm import set_openai_key, Agent

from agency_swarm.agency.agency import Agency
from agency_swarm.tools.oai import FileSearch
from agency_swarm.tools.oai import FileSearch, CodeInterpreter

ceo = Agent(name="CEO",
description="Responsible for client communication, task planning and management.",
instructions="Analyze uploaded files with myfiles_browser tool.", # can be a file like ./instructions.md
tools=[FileSearch])
tools=[FileSearch, CodeInterpreter])


test_agent = Agent(name="Test Agent1",
Expand All @@ -31,7 +31,7 @@
ceo, test_agent, test_agent2
], shared_instructions="")

# agency.demo_gradio()
agency.demo_gradio()

print(agency.get_completion("Hello", recipient_agent=test_agent, yield_messages=False))
# print(agency.get_completion("Hello", recipient_agent=test_agent, yield_messages=False))

0 comments on commit 3296f85

Please sign in to comment.