Commit 47778a3

joshbickett committed Jan 7, 2024
2 parents d3d2ca4 + 23d9058

Showing 22 changed files with 722 additions and 316 deletions.
12 changes: 2 additions & 10 deletions .gitignore
@@ -162,13 +162,5 @@ cython_debug/
.DS_Store

# Avoid sending testing screenshots up
screenshot.png
screenshot_with_grid.png
screenshot_with_labeled_grid.png
screenshot_mini.png
screenshot_mini_with_grid.png
grid_screenshot.png
grid_reflection_screenshot.png
reflection_screenshot.png
summary_screenshot.png
operate/screenshots/
*.png
operate/screenshots/
221 changes: 179 additions & 42 deletions operate/actions/api_interactions.py → operate/actions.py
@@ -3,36 +3,65 @@
import json
import base64
import re
import io
import asyncio
import aiohttp

from PIL import Image
from ultralytics import YOLO
import google.generativeai as genai
from operate.config.settings import Config
from operate.exceptions.exceptions import ModelNotRecognizedException
from operate.utils.screenshot_util import capture_screen_with_cursor, add_grid_to_image, capture_mini_screenshot_with_cursor
from operate.utils.action_util import get_last_assistant_message
from operate.utils.prompt_util import format_vision_prompt, format_accurate_mode_vision_prompt, format_summary_prompt
from operate.settings import Config
from operate.exceptions import ModelNotRecognizedException
from operate.utils.screenshot import (
capture_screen_with_cursor,
add_grid_to_image,
capture_mini_screenshot_with_cursor,
)
from operate.utils.os import get_last_assistant_message
from operate.prompts import (
format_vision_prompt,
format_accurate_mode_vision_prompt,
format_summary_prompt,
format_decision_prompt,
format_label_prompt,
)


from operate.utils.label import (
add_labels,
parse_click_content,
get_click_position_in_percent,
get_label_coordinates,
)
from operate.utils.style import (
ANSI_GREEN,
ANSI_RED,
ANSI_RESET,
)


# Load configuration
config = Config()

client = config.initialize_openai_client()

yolo_model = YOLO("./operate/model/weights/best.pt") # Load your trained model
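# (These weights drive the Set-of-Mark labeling used by call_gpt_4_v_labeled below.)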

def get_next_action(model, messages, objective, accurate_mode):
if model == "gpt-4-vision-preview":
content = get_next_action_from_openai(
messages, objective, accurate_mode)
return content

async def get_next_action(model, messages, objective):
if model == "gpt-4":
return call_gpt_4_v(messages, objective)
if model == "gpt-4-with-som":
return await call_gpt_4_v_labeled(messages, objective)
elif model == "agent-1":
return "coming soon"
elif model == "gemini-pro-vision":
content = get_next_action_from_gemini_pro_vision(
messages, objective
)
return content
return call_gemini_pro_vision(messages, objective)

raise ModelNotRecognizedException(model)
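
# Usage sketch (not part of this commit): get_next_action is now a coroutine,
# so a synchronous caller is assumed to drive it with asyncio, e.g.:
#
#     action = asyncio.run(get_next_action("gpt-4-with-som", messages, objective))
#
# The actual call site elsewhere in the repo is not shown in this diff.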


def get_next_action_from_openai(messages, objective, accurate_mode):
def call_gpt_4_v(messages, objective):
"""
Get the next action for Self-Operating Computer
"""
@@ -95,32 +124,14 @@ def get_next_action_from_openai(messages, objective, accurate_mode):

content = response.choices[0].message.content

if accurate_mode:
if content.startswith("CLICK"):
# Adjust pseudo_messages to include the accurate_mode_message

click_data = re.search(r"CLICK \{ (.+) \}", content).group(1)
click_data_json = json.loads(f"{{{click_data}}}")
prev_x = click_data_json["x"]
prev_y = click_data_json["y"]

if config.debug:
print(
f"Previous coords before accurate tuning: prev_x {prev_x} prev_y {prev_y}"
)
content = accurate_mode_double_check(
"gpt-4-vision-preview", pseudo_messages, prev_x, prev_y
)
assert content != "ERROR", "ERROR: accurate_mode_double_check failed"

return content

except Exception as e:
print(f"Error parsing JSON: {e}")
return "Failed take action after looking at the screenshot"


def get_next_action_from_gemini_pro_vision(messages, objective):
def call_gemini_pro_vision(messages, objective):
"""
Get the next action for Self-Operating Computer using Gemini Pro Vision
"""
@@ -172,14 +183,13 @@ def get_next_action_from_gemini_pro_vision(messages, objective):
return "Failed take action after looking at the screenshot"


# This function is not used. `-accurate` mode was removed for now until a new PR fixes it.
def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
"""
    Reprompt OpenAI with an additional mini screenshot centered on the cursor to further fine-tune the click location
"""
print("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check")
try:
screenshot_filename = os.path.join(
"screenshots", "screenshot_mini.png")
screenshot_filename = os.path.join("screenshots", "screenshot_mini.png")
capture_mini_screenshot_with_cursor(
file_path=screenshot_filename, x=prev_x, y=prev_y
)
@@ -191,8 +201,7 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
with open(new_screenshot_filename, "rb") as img_file:
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")

accurate_vision_prompt = format_accurate_mode_vision_prompt(
prev_x, prev_y)
accurate_vision_prompt = format_accurate_mode_vision_prompt(prev_x, prev_y)

accurate_mode_message = {
"role": "user",
@@ -234,7 +243,7 @@ def summarize(model, messages, objective):
capture_screen_with_cursor(screenshot_filename)

summary_prompt = format_summary_prompt(objective)

if model == "gpt-4-vision-preview":
with open(screenshot_filename, "rb") as img_file:
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
@@ -266,7 +275,135 @@ )
)
content = summary_message.text
return content

except Exception as e:
print(f"Error in summarize: {e}")
return "Failed to summarize the workflow"
return "Failed to summarize the workflow"


async def call_gpt_4_v_labeled(messages, objective):
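    # Note: time.sleep blocks the event loop; await asyncio.sleep(1) would be the non-blocking equivalent in this coroutine.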
time.sleep(1)
try:
screenshots_dir = "screenshots"
if not os.path.exists(screenshots_dir):
os.makedirs(screenshots_dir)

screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
# Call the function to capture the screen with the cursor
capture_screen_with_cursor(screenshot_filename)

with open(screenshot_filename, "rb") as img_file:
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")

previous_action = get_last_assistant_message(messages)

img_base64_labeled, img_base64_original, label_coordinates = add_labels(
img_base64, yolo_model
)

decision_prompt = format_decision_prompt(objective, previous_action)
labeled_click_prompt = format_label_prompt(objective)

click_message = {
"role": "user",
"content": [
{"type": "text", "text": labeled_click_prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{img_base64_labeled}"
},
},
],
}
decision_message = {
"role": "user",
"content": [
{"type": "text", "text": decision_prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{img_base64_original}"
},
},
],
}

click_messages = messages.copy()
click_messages.append(click_message)
decision_messages = messages.copy()
decision_messages.append(decision_message)

click_future = fetch_openai_response_async(click_messages)
decision_future = fetch_openai_response_async(decision_messages)

click_response, decision_response = await asyncio.gather(
click_future, decision_future
)

        # Extract the message content from the raw chat-completions JSON payload
click_content = click_response.get("choices")[0].get("message").get("content")

decision_content = (
decision_response.get("choices")[0].get("message").get("content")
)

if not decision_content.startswith("CLICK"):
return decision_content

label_data = parse_click_content(click_content)

if label_data and "label" in label_data:
coordinates = get_label_coordinates(label_data["label"], label_coordinates)
image = Image.open(
io.BytesIO(base64.b64decode(img_base64))
) # Load the image to get its size
image_size = image.size # Get the size of the image (width, height)
click_position_percent = get_click_position_in_percent(
coordinates, image_size
)
if not click_position_percent:
print(
f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Failed to get click position in percent. Trying another method {ANSI_RESET}"
)
return call_gpt_4_v(messages, objective)

x_percent = f"{click_position_percent[0]:.2f}%"
y_percent = f"{click_position_percent[1]:.2f}%"
click_action = f'CLICK {{ "x": "{x_percent}", "y": "{y_percent}", "description": "{label_data["decision"]}", "reason": "{label_data["reason"]}" }}'

else:
print(
f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] No label found. Trying another method {ANSI_RESET}"
)
return call_gpt_4_v(messages, objective)

return click_action

except Exception as e:
print(
f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Something went wrong. Trying another method {ANSI_RESET}"
)
return call_gpt_4_v(messages, objective)
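
# Reference sketch (not in this commit): the CLICK string built by
# call_gpt_4_v_labeled embeds a JSON object, so a downstream consumer could
# recover the coordinates with the same pattern the removed accurate-mode
# code used:
#
#     match = re.search(r"CLICK \{ (.+) \}", content)
#     payload = json.loads("{" + match.group(1) + "}")  # {"x": "52.10%", ...}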


async def fetch_openai_response_async(messages):
url = "https://api.openai.com/v1/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {config.openai_api_key}",
}
data = {
"model": "gpt-4-vision-preview",
"messages": messages,
"frequency_penalty": 1,
"presence_penalty": 1,
"temperature": 0.7,
"max_tokens": 300,
}

async with aiohttp.ClientSession() as session:
async with session.post(
url, headers=headers, data=json.dumps(data)
) as response:
return await response.json()
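
# Usage sketch (assumes a valid config.openai_api_key and a messages list in
# the chat-completions format): the helper returns the raw JSON payload, so
# the reply text lives under choices[0].message.content, e.g.:
#
#     async def demo(messages):
#         response = await fetch_openai_response_async(messages)
#         return response["choices"][0]["message"]["content"]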
Empty file removed operate/actions/__init__.py
Empty file removed operate/config/__init__.py