Skip to content

Commit

Permalink
Refactor get_text method in Context class
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 31, 2024
1 parent 961158e commit d8dcd0a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
20 changes: 16 additions & 4 deletions lmms_eval/api/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,29 @@ def add_question(self, doc, data_frame=None, index=None):
self.contexts.append(question)
self.contexts.append(self.target_delimiter)

def get_text(self, *, image_tokens="<image>", lazy=True):
    """Render this context as text, optionally loading images eagerly.

    Keyword-only args:
        image_tokens: either a placeholder string inserted for each image
            (e.g. "<image>"), or a callable mapping a LazyLoadedImages
            context to its token text.
        lazy: when True (default), images are NOT loaded and a single
            placeholder is emitted per lazy-image context; when False,
            images are loaded via ``self.doc_to_visual`` and one
            placeholder per loaded image is emitted.

    Returns:
        str when ``lazy`` is True, otherwise a ``(text, images)`` tuple
        where ``images`` is the flat list of loaded visuals.
        NOTE(review): when ``image_tokens`` is a callable and ``lazy`` is
        False, images are not collected for that context — confirm this
        is intended.
    """
    parts = []
    images = []
    for item in self.contexts:
        if not isinstance(item, LazyLoadedImages):
            # Plain text context: passed through verbatim.
            parts.append(item)
            continue
        if not isinstance(image_tokens, str):
            # Callable form: delegate token rendering to the caller.
            parts.append(image_tokens(item))
            continue
        if lazy:
            parts.append(image_tokens)
        else:
            loaded = item.get_images(self.doc_to_visual)
            images.extend(loaded)
            # One placeholder token per actually-loaded image.
            parts.append(image_tokens * len(loaded))
    text = "".join(parts)
    return text if lazy else (text, images)

def get_visions(self):
    """Load and return the images of every lazy-image context.

    Returns a list with one entry per ``LazyLoadedImages`` element of
    ``self.contexts``; each entry is whatever that element's
    ``get_images(self.doc_to_visual)`` yields (so the result is nested,
    not flattened).
    """
    visions = []
    for item in self.contexts:
        if isinstance(item, LazyLoadedImages):
            visions.append(item.get_images(self.doc_to_visual))
    return visions

def extend(self, context):
if isinstance(context, list):
Expand All @@ -94,6 +103,9 @@ def extend(self, context):
def append(self, context):
    """Add a single context element (text or lazy images) to the sequence."""
    self.contexts += [context]

def __str__(self):
    # Delegate to get_text() with its defaults (lazy rendering with
    # "<image>" placeholders), so printing a Context yields its text form.
    return self.get_text()

def __lt__(self, other):
if not isinstance(other, Context):
return NotImplemented
Expand Down
9 changes: 6 additions & 3 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _collate(x):
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
toks = self.tok_encode(str(x[0]))
return -len(toks), x[0]

# we group requests by their generation_kwargs,
Expand All @@ -270,8 +270,11 @@ def _collate(x):
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N]
# batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N]
contexts_texts, batched_visuals = zip(*[context.get_text(lazy=False) for context in contexts]) # [B, N]
flattened_visuals = self.flatten(batched_visuals) # [B*N]
# batched_visuals = context.get_visions() # [B, N]
# flattened_visuals = contexts[0].get_visions() # [B*N]
############### for debugging ###################
# TODO: remove this block
# if len(visuals) > 1:
Expand Down Expand Up @@ -313,7 +316,7 @@ def _collate(x):

question_input = []

for visual, context in zip(batched_visuals, contexts):
for visual, context in zip(batched_visuals, contexts_texts):
if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
"""
Three scenarios:
Expand Down

0 comments on commit d8dcd0a

Please sign in to comment.