diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py index e1a3dda21..c64b28399 100644 --- a/lmms_eval/api/samplers.py +++ b/lmms_eval/api/samplers.py @@ -1,33 +1,9 @@ import datasets -from typing import Callable, Iterable, Optional +from typing import Callable, Iterable, Optional, List from abc import ABC, abstractmethod -class ContextObject(ABC): - @abstractmethod - def get_text(self): - raise NotImplementedError - - def __str__(self): - return self.get_text() - - -class QAPairs(ContextObject): - def __init__(self, question: str, answer: Optional[str] = None, delimiter="\n", role_question: str = "USER: ", role_answer: str = "ASSISTANT: "): - self.question = question - self.answer = answer - self.delimiter = delimiter - self.role_question = role_question - self.role_answer = role_answer - - def get_text(self): - if self.answer is None: - return self.role_question + self.question + self.delimiter - else: - return self.role_question + self.question + self.delimiter + self.role_answer + self.answer - - -class LazyLoadedImages(ContextObject): +class LazyLoadedImages(object): def __init__(self, data_frame, index, doc_to_visual: Callable, image_tokens=""): self.data_frame: datasets.Dataset = data_frame self.index = index @@ -65,6 +41,80 @@ def get_text(self, lazy: bool = True): return " ".join([self.image_tokens] * self.get_num_images()) +class QAPairs(object): + def __init__( + self, + data_frame, + index, + *, + doc=None, + include_answer: bool = True, + doc_to_text: Callable, + doc_to_target: Optional[Callable] = None, + doc_to_choice: Optional[Callable] = None, + doc_to_visual: Optional[Callable] = None, + target_delimiter="\n", + delimiter="\n", + image_tokens="", + role_question="USER: ", + role_answer="ASSISTANT: ", + config=None, + ): + self.data_frame: datasets.Dataset = data_frame + self.index = index + self.target_delimiter = target_delimiter + self.doc_to_text = doc_to_text + self.doc_to_target = doc_to_target + self.doc_to_choice = doc_to_choice + self.delimiter = delimiter + if doc_to_visual: + self.vision = LazyLoadedImages(data_frame, index, doc_to_visual, image_tokens) + else: + self.vision = None + self.role_question = role_question + self.role_answer = role_answer + if doc is None: + doc = data_frame[index] + self.config = config + self.question = self._get_question(doc) + self.answer = self._get_target(doc) if include_answer else None + + def _get_question(self, doc): + text = self.doc_to_text(doc) + return text if (self.doc_to_choice is None or isinstance(text, str)) else self.doc_to_choice(doc)[text] + + def _get_target(self, doc): + return ( + str(self.doc_to_target(doc)[0]) + if type(self.doc_to_target(doc)) is list + else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)]) + ) + + def get_text(self): + if self.answer is None: + return self.role_question + self.question + self.delimiter + else: + return self.role_question + self.question + self.delimiter + self.role_answer + self.answer + + def __str__(self): + return self.get_text() + + def get_visions(self): + if self.vision: + return self.vision.get_images() + else: + return [] + + def already_have_image_token(self, image_token): + return image_token in self.question or (self.answer and image_token in self.answer) + + def num_images(self): + if self.vision: + return self.vision.get_num_images() + else: + return 0 + + class Context(object): def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str = "\n", description=None): self.task = task @@ -78,78 +128,78 @@ def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str self.target_delimiter = target_delimiter self.few_shot_delimiter = few_shot_delimiter - self.contexts = [] + self.contexts: List[QAPairs] = [] - if description: - self.add_description(description) + self.description = description def add_description(self, description): - self.contexts = [description] + self.contexts - - def get_question(self, doc): - text = self.doc_to_text(doc) - return text if (self.doc_to_choice is None or isinstance(text, str)) else self.doc_to_choice(doc)[text] - - def get_target(self, doc): - return ( - str(self.doc_to_target(doc)[0]) - if type(self.doc_to_target(doc)) is list - else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)]) + self.description = description + + def add_in_context_example(self, doc, data_frame, index): + # question = self.get_question(doc) + # if data_frame and index: + # visual = LazyLoadedImages(data_frame, index, self.doc_to_visual) + # else: + # visual = None + # target = self.doc_to_target(doc) + # if visual: + # self.contexts.append(visual) + self.contexts.append( + QAPairs( + data_frame, + index, + doc=doc, + doc_to_text=self.doc_to_text, + doc_to_target=self.doc_to_target, + doc_to_choice=self.doc_to_choice, + doc_to_visual=self.doc_to_visual, + delimiter=self.target_delimiter, + config=self.config, + ) ) - - def add_in_context_example(self, doc, data_frame=None, index=None): - question = self.get_question(doc) - if data_frame and index: - visual = LazyLoadedImages(data_frame, index, self.doc_to_visual) - else: - visual = None - target = self.doc_to_target(doc) - if visual: - self.contexts.append(visual) - self.contexts.append(QAPairs(question, target, self.target_delimiter)) - self.contexts.append(self.few_shot_delimiter) + # self.contexts.append(self.few_shot_delimiter) def add_question(self, doc, data_frame=None, index=None): - question = self.get_question(doc) - if data_frame and index: - visual = LazyLoadedImages(data_frame, index, self.doc_to_visual) - else: - visual = None - if visual: - self.contexts.append(visual) - self.contexts.append(QAPairs(question)) + # question = self.get_question(doc) + # if data_frame and index: + # visual = LazyLoadedImages(data_frame, index, self.doc_to_visual) + # else: + # visual = None + # if visual: + # self.contexts.append(visual) + self.contexts.append( + QAPairs( + data_frame, + index, + doc=doc, + doc_to_text=self.doc_to_text, + doc_to_target=self.doc_to_target, + doc_to_choice=self.doc_to_choice, + doc_to_visual=self.doc_to_visual, + delimiter=self.target_delimiter, + include_answer=False, + config=self.config, + ) + ) # self.contexts.append(self.target_delimiter) + + def already_have_image_token(self, image_token): + for context in self.contexts: + if context.already_have_image_token(image_token): + return True + return False - def get_text(self, *, image_tokens="", lazy=True): + def get_text(self): texts = [] - vision = [] - already_have_images = False for context in self.contexts: - if isinstance(context, str) and image_tokens in context: - already_have_images = True - break - if already_have_images: - image_tokens = "" - for context in self.contexts: - if isinstance(context, LazyLoadedImages): - if isinstance(image_tokens, str): - if lazy: - texts.append(image_tokens) - else: - now_vision = context.get_images(self.doc_to_visual) - vision.extend(now_vision) - texts.append(image_tokens * len(now_vision)) - else: - texts.append(image_tokens(context)) - else: - texts.append(str(context)) - if lazy: - return "".join(texts) - else: - return "".join(texts), vision + texts.append(str(context)) + return "".join(texts) def get_visions(self): - return sum([context.get_images(self.doc_to_visual) for context in self.contexts if isinstance(context, LazyLoadedImages)], start=[]) + visions = [] + for context in self.contexts: + visions.extend(context.get_visions()) + return visions def extend(self, context): if isinstance(context, list): diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index a830fd8d9..80134af78 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -335,21 +335,20 @@ def _collate(x): conv = conv_templates[self.conv_template].copy() - num_image_tokens = 0 from lmms_eval.api.samplers import LazyLoadedImages, QAPairs + already_have_image_token = context.already_have_image_token(DEFAULT_IMAGE_TOKEN) + for obj in context.contexts: - if isinstance(obj, LazyLoadedImages): - num_image_tokens += obj.get_num_images() - elif isinstance(obj, QAPairs): - if num_image_tokens == 0: - question = obj.question - else: - question = " ".join(num_image_tokens * [DEFAULT_IMAGE_TOKEN]) + "\n" + obj.question - answer = obj.answer - conv.append_message(conv.roles[0], question) - conv.append_message(conv.roles[1], answer) - num_image_tokens = 0 + if already_have_image_token or obj.num_images() == 0: + question = obj.question + else: + question = " ".join(obj.num_images() * [DEFAULT_IMAGE_TOKEN]) + "\n" + obj.question + if context.description: + question = context.description + "\n" + question + answer = obj.answer + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], answer) # conv.append_message(conv.roles[0], question) # conv.append_message(conv.roles[1], None)