Refactor code to improve readability and add new features
pufanyi committed Apr 1, 2024
1 parent f180d08 commit 4e1188a
Showing 2 changed files with 82 additions and 34 deletions.
91 changes: 63 additions & 28 deletions lmms_eval/api/samplers.py
@@ -1,24 +1,65 @@
 import datasets
 from typing import Callable, Iterable, Optional


-class LazyLoadedImages(object):
-    def __init__(self, data_frame, index):
-        self.data_frame: datasets.Dataset = data_frame
-        self.index = index
-
-    def get_images(self, doc_to_visual):
-        return doc_to_visual(self.data_frame[self.index])
+from abc import ABC, abstractmethod


-class ContextProcessors(ABC):
+class ContextObject(ABC):
     @abstractmethod
-    def process(self, question, answer = None):
+    def get_text(self):
         raise NotImplementedError

-    def __call__(self, question, answer = None):
-        return self.process(question, answer)
+    def __str__(self):
+        return self.get_text()

+class QAPairs(ContextObject):
+    def __init__(self, question: str, answer: Optional[str] = None, delimiter="\n", role_question: str = "USER: ", role_answer: str = "ASSISTANT: "):
+        self.question = question
+        self.answer = answer
+        self.delimiter = delimiter
+        self.role_question = role_question
+        self.role_answer = role_answer
+
+    def get_text(self):
+        if self.answer is None:
+            return self.role_question + self.question + self.delimiter
+        else:
+            return self.role_question + self.question + self.delimiter + self.role_answer + self.answer
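
A quick sketch of what QAPairs renders via __str__, using the defaults above; illustration only, not part of the commit:

# Hypothetical usage of QAPairs as defined in this diff.
shot = QAPairs("What color is the sky?", "Blue.")
str(shot)   # "USER: What color is the sky?\nASSISTANT: Blue."
final = QAPairs("What color is grass?")
str(final)  # "USER: What color is grass?\n"  (no answer yet, so no ASSISTANT turn)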

+class LazyLoadedImages(ContextObject):
+    def __init__(self, data_frame, index, doc_to_visual: Callable, image_tokens="<image>"):
+        self.data_frame: datasets.Dataset = data_frame
+        self.index = index
+        self.image_lens = None
+        self.images = None
+        self.doc_to_visual = doc_to_visual
+        self.image_tokens = image_tokens
+
+    def get_images(self, lazy_save=False):
+        if self.images is not None:
+            return self.images
+        images = self.doc_to_visual(self.data_frame[self.index])
+        self.image_lens = len(images)
+        if lazy_save:
+            self.images = images
+        return images
+
+    def get_num_images(self, lazy_save=False):
+        if self.image_lens is None:
+            images = self.get_images(lazy_save)
+            if lazy_save:
+                self.images = images
+            self.image_lens = len(images)
+        return self.image_lens
+
+    def clear(self, clear_all=False):
+        self.images = None
+        if clear_all:
+            self.image_lens = None
+
+    def get_text(self, lazy: bool = True):
+        if lazy:
+            return self.image_tokens
+        else:
+            return " ".join([self.image_tokens] * self.get_num_images())
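
LazyLoadedImages defers image decoding until the images are actually needed, so get_text() can emit a single "<image>" placeholder without touching the dataset. A minimal sketch, where ds and to_visual are hypothetical stand-ins for a datasets.Dataset and a task's doc_to_visual callable:

# Illustration only, not part of the commit.
lazy = LazyLoadedImages(ds, 0, to_visual)
lazy.get_text()                         # "<image>", nothing decoded yet
n = lazy.get_num_images()               # decodes once, caches only the count
imgs = lazy.get_images(lazy_save=True)  # decodes and keeps the images
lazy.clear()                            # drops cached images, keeps image_lens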

 class Context(object):
     def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str = "\n", description=None):
@@ -52,33 +93,27 @@ def get_target(self, doc):
             else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
         )

-    def add_in_context_example(self, doc, data_frame=None, index=None, context_processors: Optional[ContextProcessors] = None):
+    def add_in_context_example(self, doc, data_frame=None, index=None):
         question = self.get_question(doc)
         if data_frame and index:
-            visual = LazyLoadedImages(data_frame, index)
+            visual = LazyLoadedImages(data_frame, index, self.doc_to_visual)
         else:
             visual = None
         target = self.doc_to_target(doc)
         if visual:
             self.contexts.append(visual)
-        if context_processors:
-            question, target = context_processors(question, target)
-        self.contexts.append(question)
-        self.contexts.append(self.target_delimiter)
-        self.contexts.append(target)
+        self.contexts.append(QAPairs(question, target, self.target_delimiter))
         self.contexts.append(self.few_shot_delimiter)

-    def add_question(self, doc, data_frame=None, index=None, context_processors: Optional[ContextProcessors] = None):
+    def add_question(self, doc, data_frame=None, index=None):
         question = self.get_question(doc)
         if data_frame and index:
-            visual = LazyLoadedImages(data_frame, index)
+            visual = LazyLoadedImages(data_frame, index, self.doc_to_visual)
         else:
             visual = None
         if visual:
             self.contexts.append(visual)
-        if context_processors:
-            question, _ = context_processors(question)
-        self.contexts.append(question)
+        self.contexts.append(QAPairs(question))
         self.contexts.append(self.target_delimiter)
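
Together these two methods build an interleaved context: each call appends an optional LazyLoadedImages, then a QAPairs, then a delimiter. A sketch of the resulting state, where task, ds, shot_doc, and test_doc are hypothetical:

# Illustration only: assembling a one-shot context.
ctx = Context(task, few_shot_delimiter="\n\n", target_delimiter="\n")
ctx.add_in_context_example(shot_doc, data_frame=ds, index=1)
ctx.add_question(test_doc, data_frame=ds, index=2)
# ctx.contexts is now roughly:
#   [LazyLoadedImages, QAPairs(q1, a1, "\n"), "\n\n",
#    LazyLoadedImages, QAPairs(q2), "\n"]
# Caveat: `if data_frame and index:` is falsy for index=0, so the first row of a
# dataset silently gets no visuals attached.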

     def get_text(self, *, image_tokens="<image>", lazy=True):
@@ -103,14 +138,14 @@ def get_text(self, *, image_tokens="<image>", lazy=True):
                 else:
                     texts.append(image_tokens(context))
             else:
-                texts.append(context)
+                texts.append(str(context))
         if lazy:
             return "".join(texts)
         else:
             return "".join(texts), vision

     def get_visions(self):
-        return [context.get_images(self.doc_to_visual) for context in self.contexts if isinstance(context, LazyLoadedImages)]
+        return sum([context.get_images() for context in self.contexts if isinstance(context, LazyLoadedImages)], start=[])
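
The sum(..., start=[]) idiom concatenates the per-document image lists into one flat list; passing start by keyword requires Python 3.8+. A standalone illustration with hypothetical values:

# Flattening a list of lists with sum().
batches = [["img_a", "img_b"], ["img_c"]]
flat = sum(batches, start=[])  # ["img_a", "img_b", "img_c"]
# list(itertools.chain.from_iterable(batches)) is equivalent and stays linear;
# sum() copies the accumulator on every +, so it is quadratic in the number of lists.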

     def extend(self, context):
         if isinstance(context, list):
25 changes: 19 additions & 6 deletions lmms_eval/models/llava.py
@@ -271,7 +271,8 @@ def _collate(x):
             task = task[0]
             split = split[0]
             # batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]  # [B, N]
-            contexts_texts, batched_visuals = zip(*[context.get_text(image_tokens=DEFAULT_IMAGE_TOKEN, lazy=False) for context in contexts])  # [B, N]
+            batched_visuals = [context.get_visions() for context in contexts]  # [B, N]
+            # contexts_texts, batched_visuals = zip(*[context.get_text(image_tokens=DEFAULT_IMAGE_TOKEN, lazy=False) for context in contexts])  # [B, N]
             flattened_visuals = self.flatten(batched_visuals)  # [B*N]
             # batched_visuals = context.get_visions()  # [B, N]
             # flattened_visuals = contexts[0].get_visions()  # [B*N]
@@ -316,7 +317,7 @@ def _collate(x):

             question_input = []

-            for context in contexts_texts:
+            for context in contexts:
                 # if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
                 # """
                 # Three scenarios:
@@ -332,12 +333,24 @@
                 # else:
                 #     question = context

-                # conv = conv_templates[self.conv_template].copy()
+                conv = conv_templates[self.conv_template].copy()
+
+                num_image_tokens = 0
+                from lmms_eval.api.samplers import LazyLoadedImages, QAPairs
+                for obj in context.contexts:
+                    if isinstance(obj, LazyLoadedImages):
+                        num_image_tokens += obj.get_num_images()
+                    elif isinstance(obj, QAPairs):
+                        question = " ".join(num_image_tokens * [DEFAULT_IMAGE_TOKEN]) + "\n" + obj.question
+                        answer = obj.answer
+                        conv.append_message(conv.roles[0], question)
+                        conv.append_message(conv.roles[1], answer)
+
                 # conv.append_message(conv.roles[0], question)
                 # conv.append_message(conv.roles[1], None)
                 # prompt_question = conv.get_prompt()
-                # question_input.append(prompt_question)
-                question_input.append(contexts)
+                prompt_question = conv.get_prompt()
+                question_input.append(prompt_question)

                 # The above for loop has bugs. When there are no visuals, e.g. pure text,
                 # there will be no loop iterations, resulting in an empty question_input (because no visuals).
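
For each Context, the loop above counts the images seen so far and prepends that many DEFAULT_IMAGE_TOKEN markers to the next question before pushing the user/assistant turns into the conversation template; the final QAPairs carries answer=None, which leaves the assistant slot open so get_prompt() ends at the generation point. One possible issue: num_image_tokens is never reset once a question consumes it, so earlier images would be repeated before every later question. A variant with a per-question reset, as an assumption rather than the committed behavior:

# Hypothetical variant of the committed loop.
num_image_tokens = 0
for obj in context.contexts:
    if isinstance(obj, LazyLoadedImages):
        num_image_tokens += obj.get_num_images()
    elif isinstance(obj, QAPairs):
        question = " ".join([DEFAULT_IMAGE_TOKEN] * num_image_tokens) + "\n" + obj.question
        num_image_tokens = 0  # assumption: pending images belong only to the next question
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], obj.answer)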
