raise exception if cannot parse answer, do not return an arbitrary answer
mertyg committed Jul 12, 2024
1 parent 368241d commit 652e13c
Showing 1 changed file with 6 additions and 13 deletions.
textgrad/tasks/multimodal/mathvista.py (19 changes: 6 additions & 13 deletions)
@@ -11,10 +11,6 @@
 except ImportError:
     raise ImportError("Please install the Levenshtein package using 'pip install python-Levenshtein' to use mathvista.")
 
-from textgrad.engine.openai import ChatOpenAI
-
-local_llm_engine = ChatOpenAI(model_string="gpt-3.5-turbo", is_multimodal=False)
-print("Local OpenAI engine initialized.\n")
 
 def compress_image(decoded_image, max_size_bytes=3.6*1024*1024):
     # Convert image to RGB if it's in a mode that JPEG does not support
@@ -31,7 +27,6 @@ def compress_image(decoded_image, max_size_bytes=3.6*1024*1024):
 
     width, height = decoded_image.size
     while size > max_size_bytes:
-        print(f"Compressing image to {width}x{height}...")
         width = int(width * 0.9)
         height = int(height * 0.9)
         resized_image = decoded_image.resize((width, height), Image.LANCZOS)
@@ -131,26 +126,24 @@ def extract_answer(response, problem, quick_extract=False):
 
     # quick extraction
     if quick_extract:
-        print("Quickly extracting answer...")
         # The answer is "text". -> "text"
         try:
             result = re.search(r'The answer is "(.*)"\.', response)
             if result:
                 extraction = result.group(1)
                 return extraction
-        except:
-            pass
+        except Exception as e:
+            raise Exception(f"Error in extracting answer for {pid}: {e}. Remove this line responsibly.")
 
     # general extraction
     try:
+        from textgrad.engine.openai import ChatOpenAI
+        local_llm_engine = ChatOpenAI(model_string="gpt-3.5-turbo", is_multimodal=False)
+
         full_prompt = create_test_prompt(demo_prompt, query, response)
         extraction = local_llm_engine(full_prompt)
         return extraction
     except Exception as e:
-        print(e)
-        print(f"Error in extracting answer for {pid}")
-
-        return ""
+        raise Exception(f"Error in extracting answer for {pid}: {e}")
 
 
 def get_most_similar(prediction, choices):
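
Since extract_answer now raises on parse failures instead of returning an empty string, callers have to decide what a failure means. Below is a minimal caller-side sketch, not part of this commit: it assumes the function is importable as textgrad.tasks.multimodal.mathvista.extract_answer, and safe_extract plus the logging fallback are hypothetical names chosen for illustration.

# Hypothetical caller-side handling of the new behavior: parsing failures
# surface as exceptions, so the caller picks an explicit fallback instead of
# silently treating "" as the model's answer.
import logging

from textgrad.tasks.multimodal.mathvista import extract_answer

logger = logging.getLogger(__name__)

def safe_extract(response, problem):
    try:
        return extract_answer(response, problem)
    except Exception as e:
        # Log the failure and return an explicit sentinel (assumption: None
        # is an acceptable "no answer" marker for the surrounding evaluation).
        logger.warning("Could not parse an answer: %s", e)
        return None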