Skip to content

Commit

Permalink
Merge pull request #64 from zou-group/textgrad-vision-tasks
Browse files Browse the repository at this point in the history
Adding textgrad multimodal tasks
  • Loading branch information
mertyg authored Jul 13, 2024
2 parents da28c3d + 67b7417 commit a15e7b6
Show file tree
Hide file tree
Showing 4 changed files with 592 additions and 7 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,13 @@ We are grateful for all the help we got from our contributors!
<sub><b>Mert Yuksekgonul</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/nihalnayak">
<img src="https://avatars.githubusercontent.com/u/5679782?v=4" width="100;" alt="nihalnayak"/>
<br />
<sub><b>Nihal Nayak</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/sugatoray">
<img src="https://avatars.githubusercontent.com/u/10201242?v=4" width="100;" alt="sugatoray"/>
Expand All @@ -347,13 +354,6 @@ We are grateful for all the help we got from our contributors!
<br />
<sub><b>David Ruan</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/nihalnayak">
<img src="https://avatars.githubusercontent.com/u/5679782?v=4" width="100;" alt="nihalnayak"/>
<br />
<sub><b>Nihal Nayak</b></sub>
</a>
</td>
</tr>
<tr>
Expand Down
16 changes: 16 additions & 0 deletions textgrad/tasks/multimodal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from textgrad.engine import EngineLM


def load_multimodal_instance_task(task_name: str, evaluation_api: EngineLM, *args, **kwargs):
    """Build the evaluation dataset for a named multimodal instance task.

    Args:
        task_name: Either "mathvista" or "scienceqa".
        evaluation_api: Engine forwarded to the dataset as ``evaluation_api``.
        *args: Extra positional arguments forwarded to the dataset constructor.
        **kwargs: Extra keyword arguments forwarded to the dataset constructor.

    Returns:
        The MathVista "testmini" split or the ScienceQA "test" split.

    Raises:
        ValueError: If ``task_name`` is not one of the supported tasks.
    """
    if task_name not in ("mathvista", "scienceqa"):
        raise ValueError(f"Instance task {task_name} not found.")

    # Dataset modules are imported lazily so that only the requested task's
    # dependencies have to be importable.
    if task_name == "mathvista":
        from textgrad.tasks.multimodal.mathvista import MathVistaDataset
        return MathVistaDataset(evaluation_api=evaluation_api, split="testmini", *args, **kwargs)

    from textgrad.tasks.multimodal.scienceqa import ScienceQADataset
    return ScienceQADataset(evaluation_api=evaluation_api, split="test", *args, **kwargs)
342 changes: 342 additions & 0 deletions textgrad/tasks/multimodal/mathvista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
import re
import io
import platformdirs
from PIL import Image

from textgrad.tasks.base import Dataset
from textgrad.loss import ImageQALoss
from textgrad.variable import Variable
try:
from Levenshtein import distance
except ImportError:
raise ImportError("Please install the Levenshtein package using 'pip install python-Levenshtein' to use mathvista.")


def compress_image(decoded_image, max_size_bytes=3.6*1024*1024):
    """Encode an image as JPEG, downscaling until it fits in ``max_size_bytes``.

    Args:
        decoded_image: Image object exposing ``mode``, ``size``, ``convert``,
            ``resize`` and ``save`` (a PIL image in practice).
        max_size_bytes: Upper bound on the encoded size, in bytes.

    Returns:
        The JPEG-encoded image as ``bytes``.

    Raises:
        ValueError: If a side shrinks to one pixel or less while downscaling.
    """
    def _encode_jpeg(img):
        # Returns the buffer plus the number of bytes written to it.
        out = io.BytesIO()
        img.save(out, format='JPEG')
        return out, out.tell()

    # Modes other than RGB / grayscale are converted to RGB before encoding.
    if decoded_image.mode not in ['RGB', 'L']:
        decoded_image = decoded_image.convert('RGB')

    encoded, n_bytes = _encode_jpeg(decoded_image)

    if n_bytes > max_size_bytes:
        new_w, new_h = decoded_image.size
        while n_bytes > max_size_bytes:
            # Shrink both sides by 10% per round, always resizing from the
            # full-resolution image rather than the previous downscale.
            new_w = int(new_w * 0.9)
            new_h = int(new_h * 0.9)
            shrunk = decoded_image.resize((new_w, new_h), Image.LANCZOS)
            encoded, n_bytes = _encode_jpeg(shrunk)

            if min(new_w, new_h) <= 1:
                raise ValueError("Unable to compress image to the desired size without excessive loss of resolution")

    return encoded.getvalue()


# Few-shot demonstrations (pids = 852, 104, 824, 506, 540 from MathVista) for
# LLM-based answer extraction. `extract_answer` passes this string to
# `create_test_prompt`, which strips it and prepends it to the query/response
# pair so the extraction model completes the final "Extracted answer:" line.
# The five demos cover integer, one- and two-decimal float, list, and
# multiple-choice answers.
demo_prompt = """
Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.
Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.
Question: Which number is missing?
Model response: The number missing in the sequence is 14.
Extracted answer: 14
Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.
Question: What is the fraction of females facing the camera?
Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.
Extracted answer: 0.6
Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.
Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)
Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.
Extracted answer: 1.45
Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.
Question: Between which two years does the line graph saw its maximum peak?
Model response: The line graph saw its maximum peak between 2007 and 2008.
Extracted answer: [2007, 2008]
Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5
Model response: The correct answer is (B) 8/11.
Extracted answer: B
"""


def verify_extraction(extraction):
    """Return True when an extracted answer is usable.

    Args:
        extraction: The extracted answer string, or None when extraction
            produced nothing.

    Returns:
        False for None and for empty or whitespace-only strings, True otherwise.
    """
    # None has to be checked before calling .strip(); checking afterwards
    # would raise AttributeError instead of reporting an invalid extraction.
    if extraction is None:
        return False
    return extraction.strip() != ""


def create_test_prompt(demo_prompt, query, response):
    """Assemble the answer-extraction prompt.

    The stripped demonstrations, the query/response pair and the trailing
    "Extracted answer: " cue are joined with blank lines, in that order.
    """
    sections = (
        demo_prompt.strip(),
        f"{query}\n\n{response}",
        "Extracted answer: ",
    )
    return "\n\n".join(sections)


def extract_answer(response, problem, quick_extract=False):
    """Extract the short final answer from a model response.

    Strategies are tried from cheapest to most expensive:

    1. an empty response is returned as-is;
    2. a multiple-choice response that is literally one of the choices is
       returned unchanged;
    3. for integer / float answer types, the whole response is parsed as a
       number;
    4. with ``quick_extract``, the pattern ``The answer is "<...>".`` is
       searched for;
    5. otherwise a "gpt-3.5-turbo" engine is prompted with the few-shot
       ``demo_prompt`` to extract the answer.

    Args:
        response: Raw model response text.
        problem: Mapping with the keys ``question_type``, ``answer_type``,
            ``choices``, ``query`` and ``pid``.
        quick_extract: Try the regex shortcut before falling back to the LLM.

    Returns:
        The extracted answer as a string.

    Raises:
        Exception: If the regex search or the LLM-based extraction fails; the
            original error is attached as ``__cause__``.
    """
    question_type = problem['question_type']
    answer_type = problem['answer_type']
    choices = problem['choices']
    query = problem['query']
    pid = problem['pid']

    if response == "":
        return ""

    if question_type == 'multi_choice' and response in choices:
        return response

    # Only parsing failures are tolerated here; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit.
    if answer_type == "integer":
        try:
            return str(int(response))
        except (TypeError, ValueError, OverflowError):
            pass

    if answer_type == "float":
        try:
            return str(float(response))
        except (TypeError, ValueError, OverflowError):
            pass

    if quick_extract:
        try:
            result = re.search(r'The answer is "(.*)"\.', response)
        except Exception as e:
            raise Exception(f"Error in extracting answer for {pid}: {e}. Remove this line responsibly.") from e
        if result:
            return result.group(1)

    # Last resort: ask an LLM to pull the answer out of the free-form response.
    try:
        from textgrad.engine.openai import ChatOpenAI
        local_llm_engine = ChatOpenAI(model_string="gpt-3.5-turbo", is_multimodal=False)

        full_prompt = create_test_prompt(demo_prompt, query, response)
        return local_llm_engine(full_prompt)
    except Exception as e:
        raise Exception(f"Error in extracting answer for {pid}: {e}") from e


def get_most_similar(prediction, choices):
    """
    Return the choice with the smallest Levenshtein (edit) distance to the prediction.
    Ties are resolved in favour of the earliest choice.
    """
    return min(choices, key=lambda candidate: distance(prediction, candidate))


def normalize_extracted_answer(extraction, question_data):
    """
    Normalize the extracted answer to match the answer type.

    Args:
        extraction: Answer produced by the extraction step.
        question_data: Mapping with the keys ``choices``, ``question_type``,
            ``answer_type`` and ``precision``.

    Returns:
        * multiple choice: the text of the selected choice. An option letter
          ("B" or "(B) ...") is mapped to its choice; anything else falls back
          to the closest choice by edit distance.
        * integer: the value as an integer string, or None if not parseable.
        * float: the value rounded to ``precision`` decimals as a string, or
          None if not parseable.
        * list: ``str(extraction)``.
        * any other answer type: ``extraction`` unchanged.
    """
    choices = question_data["choices"]
    question_type = question_data["question_type"]
    answer_type = question_data["answer_type"]
    precision = question_data["precision"]

    if question_type == 'multi_choice':
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except Exception:
                # A custom __str__ may raise anything; treat it as "no answer".
                extraction = ""

        # extract "A" from "(A) text"
        letter = re.findall(r'\(([a-zA-Z])\)', extraction)
        if len(letter) > 0:
            extraction = letter[0].upper()

        options = [chr(ord('A') + i) for i in range(len(choices))]

        if extraction in options:
            # convert option letter to text, e.g. "A" -> "text"
            extraction = choices[options.index(extraction)]
        else:
            # select the most similar option
            extraction = get_most_similar(extraction, choices)
        assert extraction in choices

    elif answer_type == 'integer':
        try:
            extraction = str(int(float(extraction)))
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers int(float("inf")).
            extraction = None

    elif answer_type == 'float':
        try:
            extraction = str(round(float(extraction), int(precision)))
        except (TypeError, ValueError, OverflowError):
            extraction = None

    elif answer_type == 'list':
        try:
            extraction = str(extraction)
        except Exception:
            extraction = None

    return extraction


def safe_equal(prediction, answer):
    """
    Compare prediction and answer with ``==`` without letting the comparison raise.
    Any error raised by the comparison is printed and reported as "not equal".
    """
    try:
        matched = bool(prediction == answer)
    except Exception as err:
        print(err)
        return False
    return matched


class MathVistaDataset(Dataset):
    """MathVista dataset (``AI4Math/MathVista`` on the Hugging Face hub).

    Each item is a tuple of the encoded image bytes, the instruction-prefixed
    query, the ground-truth answer, a dict with the remaining question fields,
    a test-time objective (an ``ImageQALoss`` closure) and an evaluation
    function that scores a response by answer extraction and exact matching.
    """

    def __init__(self, evaluation_api:str, root: str=None, split: str="testmini", task_instruction: str=None, evaluation_instruction: str=None, *args, **kwargs):
        """Load the requested MathVista split.

        Args:
            evaluation_api: Engine used for the test-time objective. Despite
                the ``str`` annotation it is used as an engine object: its
                ``model_string`` attribute is read and it is passed as
                ``engine=`` to ``ImageQALoss``.
            root: Cache directory for the dataset; defaults to the "textgrad"
                user cache directory.
            split: "testmini" or "test".
            task_instruction: Optional override for the default task instruction.
            evaluation_instruction: Optional override for the default
                evaluation instruction.
            *args: Accepted for signature compatibility; not used.
            **kwargs: Accepted for signature compatibility; not used.
        """
        from datasets import load_dataset
        if root is None:
            root = platformdirs.user_cache_dir("textgrad")
        self.root = root
        assert split in ["testmini", "test"], f"Unsupported MathVista split: {split!r}"
        self.data = load_dataset("AI4Math/MathVista", cache_dir=root, split=split)
        self.split = split
        self.evaluation_api = evaluation_api
        # NOTE: this attribute is not read anywhere in this class. The
        # misspelled name is kept for backward compatibility; the correctly
        # spelled alias below holds the same value.
        self.anwer_extraction_openai_engine = "gpt-3.5-turbo"  # robust enough for answer extraction
        self.answer_extraction_openai_engine = self.anwer_extraction_openai_engine
        self.task_instruction = self.get_default_task_instruction(task_instruction)  # NOTE: check the task instruction
        self.evaluation_instruction = self.get_default_evaluation_instruction(evaluation_instruction)  # NOTE: check the evaluation instruction

    def __getitem__(self, index):
        """Return ``(image_bytes, query, answer, ques_data, test_time_objective, instance_eval_fn)``."""
        row = self.data[index]
        pid = row["pid"]
        decoded_image = row["decoded_image"]
        choices = row["choices"]
        unit = row["unit"]
        precision = row["precision"]
        answer = row["answer"]
        question_type = row["question_type"]
        answer_type = row["answer_type"]
        metadata = row["metadata"]
        query = row["query"]
        query = f"{self.task_instruction}\n{query}"  # NOTE: Add the task description

        # NOTE: convert image to bytes
        if "claude" in self.evaluation_api.model_string:
            # Claude engines get a size-capped JPEG
            # (presumably because of a request size limit -- TODO confirm).
            image_bytes = compress_image(decoded_image)
        else:
            # The context manager releases the buffer even if encoding fails.
            with io.BytesIO() as buffer:
                decoded_image.save(buffer, format='png')
                image_bytes = buffer.getvalue()

        # NOTE: ques_data stores other fields that might be useful later
        ques_data = {
            "pid": pid,
            "query": query,
            "choices": choices,
            "unit": unit,
            "precision": precision,
            "answer": answer,
            "question_type": question_type,
            "answer_type": answer_type,
            "metadata": metadata
        }
        test_time_objective = self._get_instance_test_time_objective(query, image_bytes)
        instance_eval_fn = self._get_instance_eval_fn(query, answer, ques_data)
        return image_bytes, query, answer, ques_data, test_time_objective, instance_eval_fn  # NOTE: check the sample format

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.data)

    def get_default_task_instruction(self, instruction):
        """Return ``instruction`` if given (and announce it), else the default task instruction."""
        if instruction is not None:
            print("Using user-defined task instruction:\n", instruction, "\n")
            task_instruction = instruction
        else:
            task_instruction = "You will answer a mathematical reasoning question based on an image. Please ensure you accurately interpret the image and think step by step."
        return task_instruction

    def get_default_evaluation_instruction(self, instruction):
        """Return ``instruction`` if given (and announce it), else the default evaluation instruction."""
        if instruction is not None:
            print("Using user-defined evaluation instruction:\n", instruction, "\n")
            evaluation_instruction = instruction
        else:
            evaluation_instruction = "Please evaluate the existing answer to the visual math problem without solving it yourself. Verify that the answer provides accurate reasoning logic to address the question."
        return evaluation_instruction

    @staticmethod
    def create_test_prompt(demo_prompt, query, response):
        """Assemble the answer-extraction prompt from demos, query and response.

        Declared as a static method because it takes no ``self``; without the
        decorator, a call on an instance would bind the instance to
        ``demo_prompt``.
        """
        demo_prompt = demo_prompt.strip()
        test_prompt = f"{query}\n\n{response}"
        full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: "
        return full_prompt

    def _get_instance_test_time_objective(self, question: str, image: bytes):
        """Define the loss function for the test time optimization."""
        eval_fn = ImageQALoss(evaluation_instruction=self.evaluation_instruction, engine=self.evaluation_api)

        def test_time_objective(instance: Variable):
            # The image and question are fixed inputs; only the response
            # (``instance``) is passed through as given.
            var_image = Variable(image, role_description="image input", requires_grad=False)
            var_question = Variable(question, role_description="question input", requires_grad=False)
            return eval_fn(question=var_question, image=var_image, response=instance)

        return test_time_objective

    def eval_extraction_and_matching(self, response_text, correct_answer, question_data):
        """Score a response: extract the answer, normalize it, compare with the ground truth.

        Returns:
            ``(score, result_data)`` where ``score`` is 1 or 0 and
            ``result_data`` holds the extracted answer, the normalized answer
            and the boolean match result.
        """
        # Extract the predicted answer text from the response
        extracted_answer = extract_answer(response_text, question_data)

        # Normalize the extracted answer to match the answer type
        normalized_answer = normalize_extracted_answer(extracted_answer, question_data)

        # Verify the prediction is true or false
        true_false = safe_equal(normalized_answer, correct_answer)

        # Calculate the score and store the result data
        # NOTE: check the result data format
        score = 1 if true_false else 0
        result_data = {
            "extracted_answer": extracted_answer,
            "normalized_answer": normalized_answer,
            "true_false": true_false
        }
        return score, result_data

    def _get_instance_eval_fn(self, question_prompt: str, answer: str, ques_data: dict):
        """
        Define the evaluation function for scoring the response.
        Extract the short answer from the response and compare it with the ground truth.
        """
        # NOTE: check the evaluation function format
        def eval_extraction_based_fn(response):
            return self.eval_extraction_and_matching(response.value, answer, ques_data)

        return eval_extraction_based_fn
Loading

0 comments on commit a15e7b6

Please sign in to comment.