diff --git a/textgrad/tasks/multimodal/__init__.py b/textgrad/tasks/multimodal/__init__.py new file mode 100644 index 0000000..267779f --- /dev/null +++ b/textgrad/tasks/multimodal/__init__.py @@ -0,0 +1,16 @@ +from textgrad.engine import EngineLM + + +def load_multimodal_instance_task(task_name: str, evaluation_api: EngineLM, *args, **kwargs): + if task_name == "mathvista": + from textgrad.tasks.multimodal.mathvista import MathVistaDataset + test_set = MathVistaDataset(evaluation_api=evaluation_api, split="testmini", *args, **kwargs) + return test_set + + elif task_name == "scienceqa": + from textgrad.tasks.multimodal.scienceqa import ScienceQADataset + test_set = ScienceQADataset(evaluation_api=evaluation_api, split="test", *args, **kwargs) + return test_set + + else: + raise ValueError(f"Instance task {task_name} not found.") diff --git a/textgrad/tasks/multimodal/mathvista.py b/textgrad/tasks/multimodal/mathvista.py new file mode 100644 index 0000000..ce88220 --- /dev/null +++ b/textgrad/tasks/multimodal/mathvista.py @@ -0,0 +1,300 @@ +import re +import io +import platformdirs +from PIL import Image + +from textgrad.tasks.base import Dataset +from textgrad.loss import ImageQALoss +from textgrad.variable import Variable +try: + from Levenshtein import distance +except ImportError: + raise ImportError("Please install the Levenshtein package using 'pip install python-Levenshtein' to use mathvista.") + +def compress_image(decoded_image, max_size_bytes=3.6*1024*1024): + # Convert image to RGB if it's in a mode that JPEG does not support + if decoded_image.mode not in ['RGB', 'L']: + decoded_image = decoded_image.convert('RGB') + + buffer = io.BytesIO() + decoded_image.save(buffer, format='JPEG') + size = buffer.tell() + + if size <= max_size_bytes: + buffer.seek(0) + return buffer.getvalue() + + width, height = decoded_image.size + while size > max_size_bytes: + print(f"Compressing image to {width}x{height}...") + width = int(width * 0.9) + height = int(height * 0.9) + resized_image = decoded_image.resize((width, height), Image.LANCZOS) + + buffer = io.BytesIO() + resized_image.save(buffer, format='JPEG') + size = buffer.tell() + + if width <= 1 or height <= 1: + raise ValueError("Unable to compress image to the desired size without excessive loss of resolution") + + buffer.seek(0) + return buffer.getvalue() + + +def verify_extraction(extraction): + extraction = extraction.strip() + if extraction == "" or extraction == None: + return False + return True + + +def create_test_prompt(demo_prompt, query, response): + demo_prompt = demo_prompt.strip() + test_prompt = f"{query}\n\n{response}" + full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: " + return full_prompt + + +def extract_answer(response, problem, quick_extract=False): + question_type = problem['question_type'] + answer_type = problem['answer_type'] + choices = problem['choices'] + query = problem['query'] + pid = problem['pid'] + + if response == "": + return "" + + if question_type == 'multi_choice' and response in choices: + return response + + if answer_type == "integer": + try: + extraction = int(response) + return str(extraction) + except: + pass + + if answer_type == "float": + try: + extraction = str(float(response)) + return extraction + except: + pass + + # quick extraction + if quick_extract: + print("Quickly extracting answer...") + try: + result = re.search(r'The answer is "(.*)"\.', response) + if result: + extraction = result.group(1) + return extraction + except: + pass + + else: + raise NotImplementedError("Extraction using LLMs are to-be-implemented.") + + +def get_most_similar(prediction, choices): + """ + Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction + """ + distances = [distance(prediction, choice) for choice in choices] + ind = distances.index(min(distances)) + return choices[ind] + # return min(choices, key=lambda choice: distance(prediction, choice)) + + +def normalize_extracted_answer(extraction, question_data): + """ + Normalize the extracted answer to match the answer type + """ + choices = question_data["choices"] + question_type = question_data["question_type"] + answer_type = question_data["answer_type"] + precision = question_data["precision"] + + if question_type == 'multi_choice': + # make sure the extraction is a string + if isinstance(extraction, str): + extraction = extraction.strip() + else: + try: + extraction = str(extraction) + except: + extraction = "" + + # extract "A" from "(A) text" + letter = re.findall(r'\(([a-zA-Z])\)', extraction) + if len(letter) > 0: + extraction = letter[0].upper() + + options = [chr(ord('A') + i) for i in range(len(choices))] + + if extraction in options: + # convert option letter to text, e.g. "A" -> "text" + ind = options.index(extraction) + extraction = choices[ind] + else: + # select the most similar option + extraction = get_most_similar(extraction, choices) + assert extraction in choices + + elif answer_type == 'integer': + try: + extraction = str(int(float(extraction))) + except: + extraction = None + + elif answer_type == 'float': + try: + extraction = str(round(float(extraction), int(precision))) + except: + extraction = None + + elif answer_type == 'list': + try: + extraction = str(extraction) + except: + extraction = None + + return extraction + + +def safe_equal(prediction, answer): + """ + Check if the prediction is equal to the answer, even if they are of different types + """ + try: + if prediction == answer: + return True + return False + except Exception as e: + print(e) + return False + + +class MathVistaDataset(Dataset): + def __init__(self, evaluation_api:str, root: str=None, split: str="testmini", task_instruction: str=None, evaluation_instruction: str=None, *args, **kwargs): + """MathVista dataset from HF.""" + from datasets import load_dataset + if root is None: + root = platformdirs.user_cache_dir("textgrad") + self.root = root + assert split in ["testmini", "test"] + self.data = load_dataset("AI4Math/MathVista", cache_dir=root, split=split) + self.split = split + self.evaluation_api = evaluation_api + self.anwer_extraction_openai_engine = "gpt-3.5-turbo" # robust enough for answer extraction + self.task_instruction = self.get_default_task_instruction(task_instruction) # NOTE: check the task instruction + self.evaluation_instruction = self.get_default_evaluation_instruction(evaluation_instruction) # NOTE: check the evaluation instruction + + def __getitem__(self, index): + row = self.data[index] + pid = row["pid"] + # question = row["question"] + # image = row["image"] + decoded_image = row["decoded_image"] + choices = row["choices"] + unit = row["unit"] + precision = row["precision"] + answer = row["answer"] + question_type = row["question_type"] + answer_type = row["answer_type"] + metadata = row["metadata"] + query = row["query"] + query = f"{self.task_instruction}\n{query}" # NOTE: Add the task description + + # NOTE: convert image to bytes + if "claude" in self.evaluation_api.model_string: + # TODO @lupantech This does not seem neat. + image_bytes = compress_image(decoded_image) + else: + buffer = io.BytesIO() + decoded_image.save(buffer, format='png') + image_bytes = buffer.getvalue() + buffer.close() + + # NOTE: ques_data stores other fields that might be useful later + ques_data = { + "pid": pid, + "query": query, + "choices": choices, + "unit": unit, + "precision": precision, + "answer": answer, + "question_type": question_type, + "answer_type": answer_type, + "metadata": metadata + } + test_time_objective = self._get_instance_test_time_objective(query, image_bytes) + instance_eval_fn = self._get_instance_eval_fn(query, answer, ques_data) + return image_bytes, query, answer, ques_data, test_time_objective, instance_eval_fn # NOTE: check the sample format + + def __len__(self): + return len(self.data) + + def get_default_task_instruction(self, instruction): + if instruction is not None: + print("Using user-defined task instruction:\n", instruction, "\n") + task_instruction = instruction + else: + task_instruction = "You will answer a mathematical reasoning question based on an image. Please ensure you accurately interpret the image and think step by step." + return task_instruction + + def get_default_evaluation_instruction(self, instruction): + if instruction is not None: + print("Using user-defined evaluation instruction:\n", instruction, "\n") + evaluation_instruction = instruction + else: + evaluation_instruction = "Please evaluate the existing answer to the visual math problem without solving it yourself. Verify that the answer provides accurate reasoning logic to address the question." + return evaluation_instruction + + def create_test_prompt(demo_prompt, query, response): + demo_prompt = demo_prompt.strip() + test_prompt = f"{query}\n\n{response}" + full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: " + return full_prompt + + def _get_instance_test_time_objective(self, question: str, image: bytes): + """Define the loss function for the test time optimization.""" + eval_fn = ImageQALoss(evaluation_instruction=self.evaluation_instruction, engine=self.evaluation_api) + + def test_time_objective(instance: Variable): + var_image = Variable(image, role_description="image input", requires_grad=False) + var_question = Variable(question, role_description="question input", requires_grad=False) + return eval_fn(question=var_question, image=var_image, response=instance) + + return test_time_objective + + def eval_extraction_and_matching(self, response_text, correct_answer, question_data): + # Extract the precited answer text from the response + extracted_answer = extract_answer(response_text, question_data) + + # Normalize the extracted answer to match the answer type + normalized_answer = normalize_extracted_answer(extracted_answer, question_data) + + # Verify the prediction is true or false + true_false = safe_equal(normalized_answer, correct_answer) + + # Calculate the score and store the result data + # NOTE: check the result data format + score = 1 if true_false else 0 + result_data = { + "extracted_answer": extracted_answer, + "normalized_answer": normalized_answer, + "true_false": true_false + } + return score, result_data + + def _get_instance_eval_fn(self, question_prompt: str, answer: str, ques_data: dict): + """ + Define the evaluation function for scoring the response. + Extraxct the short answer from the response and compare it with the ground truth. + """ + # NOTE: check the evaluation function format + eval_extraction_based_fn = lambda response: self.eval_extraction_and_matching(response.value, answer, ques_data) + return eval_extraction_based_fn diff --git a/textgrad/tasks/multimodal/scienceqa.py b/textgrad/tasks/multimodal/scienceqa.py new file mode 100644 index 0000000..18a4633 --- /dev/null +++ b/textgrad/tasks/multimodal/scienceqa.py @@ -0,0 +1,227 @@ +import platformdirs +import re +import os +import io +from collections import Counter +from PIL import Image + +from textgrad.tasks.base import Dataset +from textgrad.loss import ImageQALoss +from textgrad.variable import Variable + + +def compress_image(decoded_image, max_size_bytes=3.6*1024*1024): + # First, try saving as PNG without any compression + buffer = io.BytesIO() + decoded_image.save(buffer, format='PNG') + size = buffer.tell() + + # If the original PNG is already small enough, return it + if size <= max_size_bytes: + buffer.seek(0) + return buffer.getvalue() + + # If PNG is too large, resize the image + width, height = decoded_image.size + while size > max_size_bytes: + print(f"Compressing image to {width}x{height}...") + width = int(width * 0.9) + height = int(height * 0.9) + resized_image = decoded_image.resize((width, height), Image.LANCZOS) + + buffer = io.BytesIO() + resized_image.save(buffer, format='PNG') + size = buffer.tell() + + if width <= 1 or height <= 1: + raise ValueError("Unable to compress image to the desired size without excessive loss of resolution") + + buffer.seek(0) + return buffer.getvalue() + +def preprocess_text(text): + # Remove punctuation and convert to lowercase + text = re.sub(r'[^\w\s]', '', text).lower() + return text + +def find_most_similar_choice(text, choices): + # Preprocess the given text + text = preprocess_text(text) + text_words = Counter(text.split()) + scores = [] + for choice in choices: + choice_text = preprocess_text(choice) + choice_words = Counter(choice_text.split()) + common_words = sum((text_words & choice_words).values()) # the number of common words + scores.append(common_words) + most_similar_index = scores.index(max(scores)) + return most_similar_index + +def extract_answer(response_text): + """ + Extract the answer from the response. + For example, "xxxxx. Answer: (A) a gas." -> "A" + If extraction fails, return the entire string after "Answer: ". + """ + # Attempt to match the format "Answer: (A)" or "Answer: (a)" in case-sensitive manner + match = re.search(r"Answer: \(([A-Z])\)", response_text) + if match: + return match.group(1).upper() # Return as uppercase + else: + # Fallback: match the format "Answer: " followed by any characters until the next period or end of line + fallback_match = re.search(r"Answer: ([^\.]+)", response_text) + if fallback_match: + return fallback_match.group(1).strip() + return response_text + +def normalize_extracted_answer(extracted_answer, question_data, options): + # Normalize the extracted answer + choices = question_data["choices"] + options = options[:len(choices)] + + # 'A' -> one of the choices + if extracted_answer in options: + normalized_answer = options.index(extracted_answer) + return normalized_answer + + # '(a) a gas.' + for choice in choices: + if choice.lower() in extracted_answer.lower(): + normalized_answer = choices.index(choice) + return normalized_answer + + # find the most similar choice + normalized_answer = find_most_similar_choice(extracted_answer, choices) + return normalized_answer + +def safe_equal(a, b): + # Check if two intergers are equal + return a == b + +class ScienceQADataset(Dataset): + def __init__(self, evaluation_api:str, root: str=None, split: str="test", task_instruction: str=None, evaluation_instruction: str=None, *args, **kwargs): + """ScienceQA dataset from HF.""" + if root is None: + root = platformdirs.user_cache_dir("textgrad") + self.root = root + self.split = split + assert split in ["test"] + self.data = self.load_scienceqa_data() + self.options = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"] + self.evaluation_api = evaluation_api + self.task_instruction = self.get_default_task_instruction(task_instruction) # NOTE: check the task instruction + self.evaluation_instruction = self.get_default_evaluation_instruction(evaluation_instruction) # NOTE: check the evaluation instruction + + def __getitem__(self, index): + row = self.data[index] + pid = row["pid"] + image = row["image"] + question = row["question"] + choices = row["choices"] + answer = row["answer"] + hint = row["hint"] + + query = f"{self.task_instruction}" # NOTE: Add the task description + if hint is not None and len(hint) > 0: + query += f"\nContext: {hint}" + query += f"\nQuestion: {question}" + if choices: + choice_list = [] + for i, c in enumerate(choices): + choice_list.append("({}) {}".format(self.options[i], c)) + choice_txt = " ".join(choice_list) + query += f"\nChoices: {choice_txt}" + + # NOTE: convert image to bytes + if "claude" in self.evaluation_api.model_string: + image_bytes = compress_image(image) + # print("Image size:", len(image_bytes)) + else: + buffer = io.BytesIO() + image.save(buffer, format='png') + image_bytes = buffer.getvalue() + buffer.close() + + # NOTE: ques_data stores other fields that might be useful later + ques_data = { + "pid": pid, + "question": question, + "choices": choices, + "hint": hint + } + test_time_objective = self._get_instance_test_time_objective(query, image_bytes) + instance_eval_fn = self._get_instance_eval_fn(query, answer, ques_data) + return image_bytes, query, answer, ques_data, test_time_objective, instance_eval_fn # NOTE: check the sample format + + def __len__(self): + return len(self.data) + + def get_default_task_instruction(self, instruction): + if instruction is not None: + print("Using user-defined task instruction:\n", instruction, "\n") + task_instruction = instruction + else: + task_instruction = "You will answer a scientific question based on an image. Please ensure you accurately interpret the image and think step by step. The last line of your answer should be formatted as follows: 'Answer: (X) Your Option.'" + return task_instruction + + def load_scienceqa_data(self): + scienceqa_dir = os.path.join(self.root, "scienceqa") + try: + from datasets import Dataset + data = Dataset.load_from_disk(scienceqa_dir) + print("Loaded ScienceQA dataset from cache.") + return data + except FileNotFoundError: + from datasets import load_dataset + data = load_dataset("derek-thomas/ScienceQA", split=self.split) + data_img = data.filter(lambda x: x['image'] is not None) # select examples with a non-empty question image + data_img = data_img.map(lambda x, i: {'pid': str(i), **x}, with_indices=True) # add index ID (string) for each example + data_img.save_to_disk(scienceqa_dir) + print("Loaded ScienceQA dataset from HF.") + return data_img + + def get_default_evaluation_instruction(self, instruction): + if instruction is not None: + print("Using user-defined evaluation instruction:\n", instruction, "\n") + evaluation_instruction = instruction + else: + evaluation_instruction = "Please evaluate the existing answer to the visual scientific problem without solving it yourself. Verify that the answer accurately understands the image, provides appropriate knowledge and reasoning logic to address the question." + return evaluation_instruction + + def _get_instance_test_time_objective(self, question: str, image: bytes): + """Define the loss function for the test time optimization.""" + eval_fn = ImageQALoss(evaluation_instruction=self.evaluation_instruction, engine=self.evaluation_api) + def test_time_objective(instance: Variable): + var_image = Variable(image, role_description="image input", requires_grad=False) + var_question = Variable(question, role_description="question input", requires_grad=False) + return eval_fn(question=var_question, image=var_image, response=instance) + return test_time_objective + + def eval_extraction_and_matching(self, response_text, correct_answer, question_data): + # Extract the precited answer text from the response + extracted_answer = extract_answer(response_text) + + # Normalize the extracted answer to match the answer type + normalized_answer = normalize_extracted_answer(extracted_answer, question_data, self.options) + + # Verify the prediction is true or false + true_false = safe_equal(normalized_answer, correct_answer) + + # Calculate the score and store the result data + # NOTE: check the result data format + score = 1 if true_false else 0 + result_data = { + "extracted_answer": extracted_answer, + "normalized_answer": normalized_answer, + "true_false": true_false + } + return score, result_data + + def _get_instance_eval_fn(self, question_prompt: str, answer: str, ques_data: dict): + """ + Define the evaluation function for scoring the response. + Extraxct the short answer from the response and compare it with the ground truth. + """ + # NOTE: check the evaluation function format + eval_extraction_based_fn = lambda response: self.eval_extraction_and_matching(response.value, answer, ques_data) + return eval_extraction_based_fn