diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py
index e710fed7..c6c519e5 100644
--- a/lmms_eval/api/samplers.py
+++ b/lmms_eval/api/samplers.py
@@ -68,20 +68,29 @@ def add_question(self, doc, data_frame=None, index=None):
         self.contexts.append(question)
         self.contexts.append(self.target_delimiter)
 
-    def get_text(self, image_tokens=""):
+    def get_text(self, *, image_tokens="", lazy=True):
         texts = []
+        vision = []
         for context in self.contexts:
             if isinstance(context, LazyLoadedImages):
                 if isinstance(image_tokens, str):
-                    texts.append(image_tokens)
+                    if lazy:
+                        texts.append(image_tokens)
+                    else:
+                        now_vision = context.get_images(self.doc_to_visual)
+                        vision.extend(now_vision)
+                        texts.append(image_tokens * len(now_vision))
                 else:
                     texts.append(image_tokens(context))
             else:
                 texts.append(context)
-        return "".join(texts)
+        if lazy:
+            return "".join(texts)
+        else:
+            return "".join(texts), vision
 
     def get_visions(self):
-        return [context.get_images() for context in self.contexts if isinstance(context, LazyLoadedImages)]
+        return [context.get_images(self.doc_to_visual) for context in self.contexts if isinstance(context, LazyLoadedImages)]
 
     def extend(self, context):
         if isinstance(context, list):
@@ -94,6 +103,9 @@ def extend(self, context):
     def append(self, context):
         self.contexts.append(context)
 
+    def __str__(self):
+        return self.get_text()
+
     def __lt__(self, other):
         if not isinstance(other, Context):
             return NotImplemented
diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index 3a9c8407..79ca0853 100644
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -256,7 +256,7 @@ def _collate(x):
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-            toks = self.tok_encode(x[0])
+            toks = self.tok_encode(str(x[0]))
            return -len(toks), x[0]
 
        # we group requests by their generation_kwargs,
@@ -270,8 +270,11 @@ def _collate(x):
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
-            batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]  # [B, N]
+            # batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]  # [B, N]
+            contexts_texts, batched_visuals = zip(*[context.get_text(lazy=False) for context in contexts])  # [B, N]
            flattened_visuals = self.flatten(batched_visuals)  # [B*N]
+            # batched_visuals = context.get_visions()  # [B, N]
+            # flattened_visuals = contexts[0].get_visions()  # [B*N]
            ############### for debugging ###################
            # TODO: remove this block
            # if len(visuals) > 1:
@@ -313,7 +316,7 @@ def _collate(x):
 
            question_input = []
 
-            for visual, context in zip(batched_visuals, contexts):
+            for visual, context in zip(batched_visuals, contexts_texts):
                if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
                    """
                    Three senarios:
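
For reference, a minimal usage sketch of the reworked get_text() contract introduced by this patch. This is not part of the diff itself; ctx and the "<image>" placeholder are illustrative stand-ins for a populated Context object and a model's image token:

    # Lazy mode (default): no images are decoded. One placeholder string is
    # appended per LazyLoadedImages entry and only the prompt text is
    # returned. This is also what the new __str__ delegates to, which is why
    # _collate can now call self.tok_encode(str(x[0])) for length sorting.
    prompt = ctx.get_text(image_tokens="<image>")

    # Eager mode: each LazyLoadedImages entry is resolved through
    # self.doc_to_visual, the placeholder is repeated once per loaded image,
    # and the decoded visuals come back alongside the prompt. This is the
    # form llava.py now unzips into contexts_texts and batched_visuals.
    prompt, visuals = ctx.get_text(image_tokens="<image>", lazy=False)

Note the asymmetry between the two modes: lazy mode emits a single placeholder per LazyLoadedImages entry regardless of how many images it holds, while eager mode multiplies the placeholder by the actual number of images returned, so prompts built lazily and eagerly can differ for multi-image entries.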