Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE NOW] [Feat] Few shot eval #39

Closed
wants to merge 41 commits into from
Closed
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
2571960
few_shot dataset loading
pufanyi Mar 27, 2024
5dca4e1
Fix fewshot_config sampler constructor error
pufanyi Mar 27, 2024
5efc4cb
few shot dataset lazy load
pufanyi Mar 27, 2024
89e89e0
Fix sampler constructor in ConfigurableTask
pufanyi Mar 27, 2024
96129e0
Add same_as_eval parameter to FewShotDataset constructor
pufanyi Mar 28, 2024
93f56d9
fix sampling
pufanyi Mar 29, 2024
dd76e53
remove dulpilicated code
pufanyi Mar 29, 2024
a4d86da
fix a small bug
pufanyi Mar 29, 2024
70636cf
Merge commit '08d4151cea53725b3e016bf546b58bece8d51c38'
pufanyi Mar 29, 2024
c507389
Merge commit '08d4151cea53725b3e016bf546b58bece8d51c38' into pufanyi/…
pufanyi Mar 29, 2024
c9f759e
lint
pufanyi Mar 29, 2024
9979b03
update context sampler
pufanyi Mar 30, 2024
8d56856
Refactor get_question method in Context class
pufanyi Mar 30, 2024
b677bf3
fix
pufanyi Mar 30, 2024
a2e6e0e
Refactor get_context method in ContextSampler class
pufanyi Mar 30, 2024
f6fc367
Add description to Context class and update ConfigurableTask to inclu…
pufanyi Mar 30, 2024
826e5fe
lint
pufanyi Mar 30, 2024
e4cb4e6
fix visuals
pufanyi Mar 31, 2024
0e26684
lint
pufanyi Mar 31, 2024
b202dad
textvqa doc_to_target
pufanyi Mar 31, 2024
117082f
Merge remote-tracking branch 'origin/main' into pufanyi/few_shot
pufanyi Mar 31, 2024
c2e3a03
Update lmms_eval/api/samplers.py and lmms_eval/api/task.py
pufanyi Mar 31, 2024
83b3477
Add append method to Context class and refactor example handling in C…
pufanyi Mar 31, 2024
961158e
Refactor sorting logic in lmms_eval/utils.py
pufanyi Mar 31, 2024
d8dcd0a
Refactor get_text method in Context class
pufanyi Mar 31, 2024
3e2b24f
Refactor image token handling in LMMS evaluation code
pufanyi Mar 31, 2024
7b4b4fa
lint
pufanyi Mar 31, 2024
f180d08
Refactor context and question addition methods
pufanyi Mar 31, 2024
4e1188a
Refactor code to improve readability and add new features
pufanyi Apr 1, 2024
61236aa
lint
pufanyi Apr 1, 2024
aeaefc2
why so many bugs
pufanyi Apr 1, 2024
6812878
fix bug
pufanyi Apr 1, 2024
0b51d45
llava-textvqa done
pufanyi Apr 1, 2024
7607c28
Update construct_requests method signature
pufanyi Apr 1, 2024
d5ac624
make contexts lists only have qa pairs
pufanyi Apr 1, 2024
89513b7
add vila
pufanyi Apr 3, 2024
6ae82fe
Refactor LMMS API and models
pufanyi Apr 4, 2024
7c03d17
add docvqa fewshot
pufanyi Apr 17, 2024
60e0ba9
Merge remote-tracking branch 'origin/main' into pufanyi/few_shot
pufanyi Apr 17, 2024
c688a72
fix a small bug
pufanyi Apr 17, 2024
d1ba561
Fix condition for checking accelerator.num_processes in Llava class
pufanyi Apr 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lmms_eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
except Exception as e:
traceback.print_exc()
eval_logger.error(f"Error during evaluation: {e}")
traceback.print_exc()
# traceback.print_exc()
results_list.append(None)

for args, results in zip(args_list, results_list):
Expand Down
2 changes: 1 addition & 1 deletion lmms_eval/api/instance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Literal, Tuple
from typing import Literal, Tuple, Iterable, Callable


@dataclass
Expand Down
199 changes: 174 additions & 25 deletions lmms_eval/api/samplers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,166 @@
import datasets
from typing import Callable, Iterable, Optional


class LazyLoadedImages(object):
    """Deferred handle to the visuals of a single dataset row.

    Only the container and the row index are stored; the row is read and
    converted to visuals when `get_images` is called.
    """

    def __init__(self, data_frame, index):
        # `data_frame` only needs to support `data_frame[index]`.
        self.index = index
        self.data_frame: datasets.Dataset = data_frame

    def get_images(self, doc_to_visual):
        """Fetch the stored row and return `doc_to_visual(row)`."""
        row = self.data_frame[self.index]
        return doc_to_visual(row)

from abc import ABC, abstractmethod

class ContextProcessors(ABC):
    """Abstract hook for rewriting a question/answer pair.

    Subclasses implement `process`; instances are callable and forward to it.
    Callers in this file unpack the result as a `(question, answer)` pair.
    """

    @abstractmethod
    def process(self, question, answer=None):
        """Transform `question` (and optionally `answer`); must be overridden."""
        raise NotImplementedError

    def __call__(self, question, answer=None):
        """Make the processor callable by delegating to `process`."""
        processed = self.process(question, answer)
        return processed

class Context(object):
    """Ordered list of prompt fragments (text and lazily loaded visuals).

    `contexts` holds, in prompt order, plain strings (description, questions,
    delimiters, targets) and `LazyLoadedImages` placeholders. `get_text`
    flattens the list into a single prompt string.

    Args:
        task: Task object providing `_config`, `doc_to_visual`, `doc_to_text`,
            `doc_to_target` and `doc_to_choice`.
        few_shot_delimiter: Appended after every in-context example.
        target_delimiter: Inserted between a question and its target.
        description: Optional text placed at the very front of the context.
    """

    def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str = "\n", description=None):
        self.task = task
        self.config = task._config

        self.doc_to_visual = self.task.doc_to_visual
        self.doc_to_text = self.task.doc_to_text
        self.doc_to_target = self.task.doc_to_target
        self.doc_to_choice = self.task.doc_to_choice

        self.target_delimiter = target_delimiter
        self.few_shot_delimiter = few_shot_delimiter

        self.contexts = []

        if description:
            self.add_description(description)

    def add_description(self, description):
        """Prepend `description` to the context list."""
        self.contexts = [description] + self.contexts

    def get_question(self, doc):
        """Return the question for `doc`.

        The output of `doc_to_text` is returned unchanged when it is a string
        or when the task config defines no `doc_to_choice`; otherwise it is
        used as an index into `doc_to_choice(doc)`.
        """
        text = self.doc_to_text(doc)
        # Check the *config* entry, as `get_target` does: `self.doc_to_choice`
        # is the task attribute itself and does not tell whether choices are
        # configured.
        if self.config.doc_to_choice is None or isinstance(text, str):
            return text
        return self.doc_to_choice(doc)[text]

    def get_target(self, doc):
        """Return the target for `doc`.

        A list target yields its first element as a string. A string target,
        or any target when the config defines no `doc_to_choice`, is returned
        unchanged. Otherwise the target indexes into `doc_to_choice(doc)` and
        the selected choice is returned as a string.
        """
        target = self.doc_to_target(doc)  # evaluate once instead of per branch
        if isinstance(target, list):
            return str(target[0])
        if self.config.doc_to_choice is None or isinstance(target, str):
            return target
        return str(self.doc_to_choice(doc)[target])

    @staticmethod
    def _lazy_visual(data_frame, index):
        """Build a lazy visual handle, or return None if either part is missing."""
        # Explicit `is None` checks: index 0 is a valid row and must not be
        # treated as "no index".
        if data_frame is None or index is None:
            return None
        return LazyLoadedImages(data_frame, index)

    def add_in_context_example(self, doc, data_frame=None, index=None, context_processors: Optional["ContextProcessors"] = None):
        """Append one solved example: [visual], question, delimiter, target, delimiter.

        Args:
            doc: The example document.
            data_frame: Container the visuals can be read from lazily.
            index: Row of `doc` inside `data_frame`.
            context_processors: Optional callable receiving the question and
                the raw `doc_to_target` output and returning the pair to store.
        """
        question = self.get_question(doc)
        visual = self._lazy_visual(data_frame, index)
        if context_processors:
            # Processors keep receiving the raw target, exactly as before.
            question, target = context_processors(question, self.doc_to_target(doc))
        else:
            # Without a processor, normalise list/choice targets so that
            # `get_text` can join the fragments.
            target = self.get_target(doc)
        if visual is not None:
            self.contexts.append(visual)
        self.contexts.extend([question, self.target_delimiter, target, self.few_shot_delimiter])

    def add_question(self, doc, data_frame=None, index=None, context_processors: Optional["ContextProcessors"] = None):
        """Append the unanswered query: [visual], question, target delimiter.

        Args:
            doc: The document being evaluated.
            data_frame: Container the visuals can be read from lazily.
            index: Row of `doc` inside `data_frame`.
            context_processors: Optional callable applied to the question; only
                the first element of its result is used.
        """
        question = self.get_question(doc)
        visual = self._lazy_visual(data_frame, index)
        if visual is not None:
            self.contexts.append(visual)
        if context_processors:
            question, _ = context_processors(question)
        self.contexts.append(question)
        self.contexts.append(self.target_delimiter)

    def get_text(self, *, image_tokens="<image>", lazy=True):
        """Flatten the context into a prompt string.

        Args:
            image_tokens: Either a string inserted for every visual placeholder,
                or a callable mapping a `LazyLoadedImages` to its text.
            lazy: When True, visuals are not loaded and only the text is
                returned. When False, string `image_tokens` are repeated once
                per loaded image and `(text, images)` is returned.

        Returns:
            The prompt string, or a `(prompt, images)` tuple when `lazy` is False.
        """
        # If the text fragments already contain the token, do not insert it
        # again. The membership test is only meaningful for string tokens; a
        # callable here would raise TypeError.
        if isinstance(image_tokens, str) and any(isinstance(fragment, str) and image_tokens in fragment for fragment in self.contexts):
            image_tokens = ""

        texts = []
        vision = []
        for fragment in self.contexts:
            if not isinstance(fragment, LazyLoadedImages):
                texts.append(fragment)
            elif not isinstance(image_tokens, str):
                texts.append(image_tokens(fragment))
            elif lazy:
                texts.append(image_tokens)
            else:
                now_vision = fragment.get_images(self.doc_to_visual)
                vision.extend(now_vision)
                texts.append(image_tokens * len(now_vision))

        prompt = "".join(texts)
        return prompt if lazy else (prompt, vision)

    def get_visions(self):
        """Load and return the visuals of every placeholder, one entry per placeholder."""
        return [context.get_images(self.doc_to_visual) for context in self.contexts if isinstance(context, LazyLoadedImages)]

    def extend(self, context):
        """Append all fragments from a list or from another `Context`.

        Raises:
            ValueError: If `context` is neither a list nor a `Context`.
        """
        if isinstance(context, list):
            self.contexts.extend(context)
        elif isinstance(context, Context):
            self.contexts.extend(context.contexts)
        else:
            raise ValueError(f"Cannot extend context with object of type {type(context)}")

    def append(self, context):
        """Append a single fragment to the context list."""
        self.contexts.append(context)

    def __str__(self):
        return self.get_text()

    def __lt__(self, other):
        # Order contexts by their flattened text.
        if not isinstance(other, Context):
            return NotImplemented
        return self.get_text() < other.get_text()


class FewShotDataset(object):
    """Lazily loaded source of few-shot examples.

    Either wraps an already loaded `dataset`, or stores the arguments needed
    to load one with `datasets.load_dataset` on first access.

    Args:
        dataset: An already loaded dataset. Mutually exclusive with the
            loading arguments below.
        dataset_path: `path` forwarded to `datasets.load_dataset`.
        dataset_name: `name` forwarded to `datasets.load_dataset`.
        split: `split` forwarded to `datasets.load_dataset`.
        dataset_kwargs: Extra keyword arguments for `datasets.load_dataset`.
        same_as_eval: Whether the few-shot data is the split being evaluated.

    Raises:
        ValueError: If `dataset` is combined with any loading argument.
    """

    def __init__(self, dataset=None, *, dataset_path: str = None, dataset_name: str = None, split: str = None, dataset_kwargs: dict = None, same_as_eval: bool = False):
        if dataset is not None and (dataset_path is not None or dataset_name is not None or split is not None or dataset_kwargs is not None):
            raise ValueError("Cannot provide both `dataset` and other dataset arguments!")
        self.dataset_path = dataset_path
        self.dataset_name = dataset_name
        self.split = split
        self.dataset = dataset
        self.dataset_kwargs = dataset_kwargs if dataset_kwargs is not None else {}
        self.same_as_eval = same_as_eval
        # May be assigned by the owner after construction to restrict the
        # examples to a subset of rows.
        self.fewshot_indices = None
        # Cache of the subset selected by `fewshot_indices`.
        self._subset = None
        self._subset_indices = None

    def get_dataset(self) -> "datasets.Dataset":
        """Return the few-shot dataset, loading it on first use.

        When `fewshot_indices` is set, the corresponding subset is returned.
        The subset is computed once per distinct index list. Re-selecting on
        every call would apply the indices to an already reduced dataset.
        """
        if self.dataset is None:
            self.dataset = datasets.load_dataset(path=self.dataset_path, name=self.dataset_name, split=self.split, download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, **self.dataset_kwargs)
        if not self.fewshot_indices:
            return self.dataset
        wanted = list(self.fewshot_indices)
        if self._subset is None or self._subset_indices != wanted:
            self._subset = self.dataset.select(wanted)
            self._subset_indices = wanted
        return self._subset

    def sample(self, n, rnd):
        """Draw `n` distinct rows using the random generator `rnd`.

        Returns:
            A `(indices, rows)` tuple, where `indices` are positions inside the
            dataset returned by `get_dataset` and `rows` is the matching
            selection.
        """
        dataset = self.get_dataset()
        indices = rnd.sample(range(len(dataset)), n)
        return indices, dataset.select(indices)

    def __getitem__(self, item):
        return self.get_dataset()[item]


class ContextSampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
def __init__(self, docs: FewShotDataset, task, fewshot_indices=None, rnd=None) -> None:
self.rnd = rnd
assert self.rnd, "must pass rnd to FewShotSampler!"

Expand All @@ -13,37 +174,25 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.doc_to_target = self.task.doc_to_target
self.doc_to_choice = self.task.doc_to_choice

self.docs = docs # HF dataset split, provided by task._fewshot_docs()
self.docs: FewShotDataset = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
self.docs = self.docs.select(fewshot_indices)
self.docs.fewshot_indices = fewshot_indices

def get_context(self, doc, num_fewshot):
def get_context(self, doc, num_fewshot) -> Context:
# draw an extra fewshot sample if using same split as evaluating on
n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
n_samples = num_fewshot + 1 if self.docs.same_as_eval else num_fewshot

# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
indices, fewshotex = self.sample(n_samples)

# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

labeled_examples = (
self.fewshot_delimiter.join(
[
# TODO: is separating doc_to_text and doc_to_target by one space always desired?
(self.doc_to_text(doc) if (self.config.doc_to_choice is None or type(self.doc_to_text(doc)) is str) else self.doc_to_choice(doc)[self.doc_to_text(doc)])
+ self.target_delimiter
+ (
str(self.doc_to_target(doc)[0])
if type(self.doc_to_target(doc)) is list
else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
)
for doc in selected_docs
]
)
+ self.fewshot_delimiter
)
selected_docs = [(idx, x) for idx, x in zip(indices, fewshotex) if x != doc][:num_fewshot]

labeled_examples = Context(self.task, self.fewshot_delimiter, self.target_delimiter)

for idx, doc in selected_docs:
labeled_examples.add_in_context_example(doc, self.docs, idx)

return labeled_examples

Expand All @@ -52,7 +201,7 @@ def sample(self, n):
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""

return self.rnd.sample(self.docs, n)
return self.docs.sample(n, self.rnd)


class FirstNSampler(ContextSampler):
Expand Down
61 changes: 38 additions & 23 deletions lmms_eval/api/task.py
pufanyi marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from lmms_eval import utils
from lmms_eval.api import samplers
from lmms_eval.api.instance import Instance
from lmms_eval.api.samplers import FewShotDataset

from lmms_eval.filters import build_filter_ensemble
from lmms_eval.api.registry import (
Expand Down Expand Up @@ -64,7 +65,7 @@ class TaskConfig(dict):
training_split: str = None
validation_split: str = None
test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable = None
Expand Down Expand Up @@ -305,13 +306,13 @@ def fewshot_docs(self):
An iterable of any object, that doc_to_text can handle
"""
if self.has_training_docs():
return self.training_docs()
return FewShotDataset(self.training_docs())
elif self.has_validation_docs():
return self.validation_docs()
return FewShotDataset(self.validation_docs())
else:
if self.config.num_fewshot is not None:
eval_logger.warning("has_training_docs and has_validation_docs are False" ", using test_docs as fewshot_docs but this is not recommended.")
return self.test_docs()
return FewShotDataset(self.test_docs(), same_as_eval=True)

def _process_doc(self, doc):
"""
Expand Down Expand Up @@ -550,8 +551,20 @@ def __init__(self, model_name) -> None: # TODO no super() call here
self._filters.append(filter_pipeline)
else:
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
##########################################
pufanyi marked this conversation as resolved.
Show resolved Hide resolved
# TODO: for test, will delete later
if self.config.task == "textvqa_test":
pass
else:
pass
###########################################
if self.config.fewshot_config is not None:
self.sampler = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")(list(self.fewshot_docs()), self, rnd=random.Random(1234))
try:
random_seed = self.config.fewshot_config.get("random_seed", 1234)
sampler_constructor = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")
self.sampler = sampler_constructor(self.fewshot_docs(), self, rnd=random.Random(random_seed))
except Exception as e:
eval_logger.error(f"Error in fewshot_config: {e}")

if self.has_test_docs():
self.task_docs = self.test_docs()
Expand Down Expand Up @@ -742,8 +755,19 @@ def test_docs(self) -> datasets.Dataset:
return self.dataset[self.config.test_split]

def fewshot_docs(self):
if self.config.fewshot_split is not None:
return self.dataset[self.config.fewshot_split]

if "fewshot_dataset" in self.config.fewshot_config:
fewshot_dataset_config = self.config.fewshot_config["fewshot_dataset"]
if "dataset_path" not in fewshot_dataset_config:
fewshot_dataset_config["dataset_path"] = self.config.dataset_path
same_as_eval = self.config.dataset_path == fewshot_dataset_config["dataset_path"] and self.config.dataset_name == fewshot_dataset_config.get("dataset_name", None) and self.config.test_split == fewshot_dataset_config["split"]
return FewShotDataset(
dataset_path=fewshot_dataset_config["dataset_path"],
dataset_name=fewshot_dataset_config.get("dataset_name", None),
split=fewshot_dataset_config["split"],
dataset_kwargs=fewshot_dataset_config.get("dataset_kwargs", {}),
same_as_eval=same_as_eval,
)
else:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning(f"Task '{self.config.task}': " "num_fewshot > 0 but fewshot_split is None. " "using preconfigured rule.")
Expand All @@ -761,23 +785,14 @@ def fewshot_context(self, doc_id, num_fewshot, split):
:returns: str
The fewshot context.
"""
from lmms_eval.api.samplers import Context

doc = self.dataset_no_image[split][doc_id]
if num_fewshot == 0:
# always prepend the (possibly empty) task description
labeled_examples = self.config.description
else:
labeled_examples = self.config.description + self.sampler.get_context(doc, num_fewshot)
example = self.doc_to_text(doc)
if type(example) == str:
return labeled_examples + example
elif type(example) == list:
return [labeled_examples + ex for ex in example]
elif type(example) == int:
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
labeled_examples = Context(self, self.config.fewshot_delimiter, self.config.target_delimiter, self.config.description)
if num_fewshot != 0:
labeled_examples.extend(self.sampler.get_context(doc, num_fewshot))
labeled_examples.add_question(doc)
return labeled_examples

def apply_filters(self):
if hasattr(self, "_filters"):
Expand Down
Loading
Loading