Commit

make contexts lists only have qa pairs
pufanyi committed Apr 1, 2024
1 parent 7607c28 commit d5ac624
Showing 2 changed files with 147 additions and 98 deletions.
222 changes: 136 additions & 86 deletions lmms_eval/api/samplers.py
@@ -1,33 +1,9 @@
import datasets
from typing import Callable, Iterable, Optional
from typing import Callable, Iterable, Optional, List
from abc import ABC, abstractmethod


class ContextObject(ABC):
@abstractmethod
def get_text(self):
raise NotImplementedError

def __str__(self):
return self.get_text()


class QAPairs(ContextObject):
def __init__(self, question: str, answer: Optional[str] = None, delimiter="\n", role_question: str = "USER: ", role_answer: str = "ASSISTANT: "):
self.question = question
self.answer = answer
self.delimiter = delimiter
self.role_question = role_question
self.role_answer = role_answer

def get_text(self):
if self.answer is None:
return self.role_question + self.question + self.delimiter
else:
return self.role_question + self.question + self.delimiter + self.role_answer + self.answer


class LazyLoadedImages(ContextObject):
class LazyLoadedImages(object):
def __init__(self, data_frame, index, doc_to_visual: Callable, image_tokens="<image>"):
self.data_frame: datasets.Dataset = data_frame
self.index = index
@@ -65,6 +41,80 @@ def get_text(self, lazy: bool = True):
return " ".join([self.image_tokens] * self.get_num_images())


class QAPairs(object):
def __init__(
self,
data_frame,
index,
*,
doc=None,
include_answer: bool = True,
doc_to_text: Callable,
doc_to_target: Optional[Callable] = None,
doc_to_choice: Optional[Callable] = None,
doc_to_visual: Optional[Callable] = None,
target_delimiter="\n",
delimiter="\n",
image_tokens="<image>",
role_question="USER: ",
role_answer="ASSISTANT: ",
config=None,
):
self.data_frame: datasets.Dataset = data_frame
self.index = index
self.target_delimiter = target_delimiter
self.doc_to_text = doc_to_text
self.doc_to_target = doc_to_target
self.doc_to_choice = doc_to_choice
self.delimiter = delimiter
if doc_to_visual:
self.vision = LazyLoadedImages(data_frame, index, doc_to_visual, image_tokens)
else:
self.vision = None
self.role_question = role_question
self.role_answer = role_answer
if doc is None:
doc = data_frame[index]
self.config = config
self.question = self._get_question(doc)
self.answer = self._get_target(doc) if include_answer else None

def _get_question(self, doc):
text = self.doc_to_text(doc)
return text if (self.doc_to_choice is None or isinstance(text, str)) else self.doc_to_choice(doc)[text]

def _get_target(self, doc):
return (
str(self.doc_to_target(doc)[0])
if type(self.doc_to_target(doc)) is list
else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
)

def get_text(self):
if self.answer is None:
return self.role_question + self.question + self.delimiter
else:
return self.role_question + self.question + self.delimiter + self.role_answer + self.answer

def __str__(self):
return self.get_text()

def get_visions(self):
if self.vision:
return self.vision.get_images()
else:
return []

def already_have_image_token(self, image_token):
return image_token in self.question or (self.answer and image_token in self.answer)

def num_images(self):
if self.vision:
return self.vision.get_num_images()
else:
return 0


class Context(object):
def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str = "\n", description=None):
self.task = task
@@ -78,78 +128,78 @@ def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str
self.target_delimiter = target_delimiter
self.few_shot_delimiter = few_shot_delimiter

self.contexts = []
self.contexts: List[QAPairs] = []

if description:
self.add_description(description)
self.description = description

def add_description(self, description):
self.contexts = [description] + self.contexts

def get_question(self, doc):
text = self.doc_to_text(doc)
return text if (self.doc_to_choice is None or isinstance(text, str)) else self.doc_to_choice(doc)[text]

def get_target(self, doc):
return (
str(self.doc_to_target(doc)[0])
if type(self.doc_to_target(doc)) is list
else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
self.description = description

def add_in_context_example(self, doc, data_frame, index):
# question = self.get_question(doc)
# if data_frame and index:
# visual = LazyLoadedImages(data_frame, index, self.doc_to_visual)
# else:
# visual = None
# target = self.doc_to_target(doc)
# if visual:
# self.contexts.append(visual)
self.contexts.append(
QAPairs(
data_frame,
index,
doc=doc,
doc_to_text=self.doc_to_text,
doc_to_target=self.doc_to_target,
doc_to_choice=self.doc_to_choice,
doc_to_visual=self.doc_to_visual,
delimiter=self.target_delimiter,
config=self.config,
)
)

def add_in_context_example(self, doc, data_frame=None, index=None):
question = self.get_question(doc)
if data_frame and index:
visual = LazyLoadedImages(data_frame, index, self.doc_to_visual)
else:
visual = None
target = self.doc_to_target(doc)
if visual:
self.contexts.append(visual)
self.contexts.append(QAPairs(question, target, self.target_delimiter))
self.contexts.append(self.few_shot_delimiter)
# self.contexts.append(self.few_shot_delimiter)

def add_question(self, doc, data_frame=None, index=None):
question = self.get_question(doc)
if data_frame and index:
visual = LazyLoadedImages(data_frame, index, self.doc_to_visual)
else:
visual = None
if visual:
self.contexts.append(visual)
self.contexts.append(QAPairs(question))
# question = self.get_question(doc)
# if data_frame and index:
# visual = LazyLoadedImages(data_frame, index, self.doc_to_visual)
# else:
# visual = None
# if visual:
# self.contexts.append(visual)
self.contexts.append(
QAPairs(
data_frame,
index,
doc=doc,
doc_to_text=self.doc_to_text,
doc_to_target=self.doc_to_target,
doc_to_choice=self.doc_to_choice,
doc_to_visual=self.doc_to_visual,
delimiter=self.target_delimiter,
include_answer=False,
config=self.config,
)
)
# self.contexts.append(self.target_delimiter)

def already_have_image_token(self, image_token):
for context in self.contexts:
if context.already_have_image_token(image_token):
return True
return False

def get_text(self, *, image_tokens="<image>", lazy=True):
def get_text(self):
texts = []
vision = []
already_have_images = False
for context in self.contexts:
if isinstance(context, str) and image_tokens in context:
already_have_images = True
break
if already_have_images:
image_tokens = ""
for context in self.contexts:
if isinstance(context, LazyLoadedImages):
if isinstance(image_tokens, str):
if lazy:
texts.append(image_tokens)
else:
now_vision = context.get_images(self.doc_to_visual)
vision.extend(now_vision)
texts.append(image_tokens * len(now_vision))
else:
texts.append(image_tokens(context))
else:
texts.append(str(context))
if lazy:
return "".join(texts)
else:
return "".join(texts), vision
texts.append(str(context))
return "".join(texts)

def get_visions(self):
return sum([context.get_images(self.doc_to_visual) for context in self.contexts if isinstance(context, LazyLoadedImages)], start=[])
visions = []
for context in self.contexts:
visions.extend(context.get_visions())
return visions

def extend(self, context):
if isinstance(context, list):
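To see how the new QAPairs object behaves on its own, here is a minimal usage sketch. The toy dataset, its "question"/"answer" field names, and the stand-in config object are assumptions made for this example; real tasks supply these through the lmms-eval task configuration.

from types import SimpleNamespace

import datasets

from lmms_eval.api.samplers import QAPairs

# Toy single-row dataset standing in for a real benchmark split.
data = datasets.Dataset.from_list([{"question": "What color is the sky?", "answer": "Blue"}])

# _get_target consults config.doc_to_choice, so a tiny stand-in config is needed here.
toy_config = SimpleNamespace(doc_to_choice=None)

qa = QAPairs(
    data,
    0,
    doc_to_text=lambda doc: doc["question"],
    doc_to_target=lambda doc: doc["answer"],
    config=toy_config,
)

print(qa.get_text())
# USER: What color is the sky?
# ASSISTANT: Blue
print(qa.num_images())  # 0, since no doc_to_visual was supplied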
23 changes: 11 additions & 12 deletions lmms_eval/models/llava.py
@@ -335,21 +335,20 @@ def _collate(x):

conv = conv_templates[self.conv_template].copy()

num_image_tokens = 0
from lmms_eval.api.samplers import LazyLoadedImages, QAPairs

already_have_image_token = context.already_have_image_token(DEFAULT_IMAGE_TOKEN)

for obj in context.contexts:
if isinstance(obj, LazyLoadedImages):
num_image_tokens += obj.get_num_images()
elif isinstance(obj, QAPairs):
if num_image_tokens == 0:
question = obj.question
else:
question = " ".join(num_image_tokens * [DEFAULT_IMAGE_TOKEN]) + "\n" + obj.question
answer = obj.answer
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], answer)
num_image_tokens = 0
if already_have_image_token or obj.num_images() == 0:
question = obj.question
else:
question = " ".join(obj.num_images() * [DEFAULT_IMAGE_TOKEN]) + "\n" + obj.question
if context.description:
question = context.description + "\n" + question
answer = obj.answer
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], answer)

# conv.append_message(conv.roles[0], question)
# conv.append_message(conv.roles[1], None)
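For readers without the llava conversation machinery at hand, the rewritten loop can be paraphrased as a self-contained function. build_prompt is a hypothetical helper written for illustration only, and the USER:/ASSISTANT: role strings stand in for conv.roles; the control flow mirrors the diff above.

DEFAULT_IMAGE_TOKEN = "<image>"

def build_prompt(context):
    # context is a lmms_eval.api.samplers.Context whose .contexts list now holds only QAPairs.
    already_have_image_token = context.already_have_image_token(DEFAULT_IMAGE_TOKEN)
    turns = []
    for obj in context.contexts:
        if already_have_image_token or obj.num_images() == 0:
            question = obj.question
        else:
            # Prepend one image token per image attached to this QA pair.
            question = " ".join(obj.num_images() * [DEFAULT_IMAGE_TOKEN]) + "\n" + obj.question
        if context.description:
            # The task description, when present, is prefixed to every question.
            question = context.description + "\n" + question
        turns.append("USER: " + question)
        if obj.answer is not None:
            # The final query of a request has no answer yet, so nothing is appended for it.
            turns.append("ASSISTANT: " + obj.answer)
    return "\n".join(turns)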
