Refactor distributed gathering of logged samples and metrics #253

Merged · 2 commits · Sep 15, 2024
199 changes: 161 additions & 38 deletions lmms_eval/api/task.py
@@ -48,6 +48,7 @@
     get_metric_aggregation,
     is_higher_better,
 )
+from lmms_eval.caching.cache import load_from_cache, save_to_cache
 from lmms_eval.filters import build_filter_ensemble

 # HuggingfaceM4/NoCaps contains truncated image in test split
@@ -376,7 +377,20 @@ def doc_to_target(self, doc):
         pass

     # @profile
-    def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
+    def build_all_requests(
+        self,
+        *,
+        limit: Union[int, None] = None,
+        rank: int = 0,
+        world_size: int = 1,
+        cache_requests: bool = False,
+        rewrite_requests_cache: bool = False,
+        system_instruction: Optional[str] = None,
+        apply_chat_template: bool = False,
+        fewshot_as_multiturn: bool = False,
+        chat_template: Optional[Callable] = None,
+        tokenizer_name: str = "",
+    ) -> None:
         """Build a set of Instances for a task, and store them in task.instances"""
         if self.has_test_docs():
             docs = self.test_docs()
@@ -387,35 +401,76 @@ def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
         else:
             assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"

-        eval_logger.info(f"Building contexts for task {self._config.task} on rank {rank}...")
+        # used with caching
+        og_limit = limit
+
+        cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
+        cache_key += "-chat_template" if apply_chat_template else ""
+        cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
+        cache_key += f"-system_prompt_hash{utils.hash_string(system_instruction)}" if system_instruction is not None else ""
+        cache_key += f"-tokenizer{tokenizer_name}"
+
+        cached_instances = load_from_cache(file_name=cache_key)
+
+        if cache_requests and cached_instances and not rewrite_requests_cache:
+            cached_instances = cached_instances[:limit]
+
+            flattened_instances = [instance for instance_group in cached_instances for instance in instance_group]
+
+            self._instances = flattened_instances
+            return
+
+        eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

         instances = []
-        doc_id_iterator = utils.create_iterator([i for i in range(len(docs))], rank, world_size, limit)
-        doc_id_iterator, doc_id_iterator_counting = itertools.tee(doc_id_iterator)
-        total_docs = sum(1 for _ in doc_id_iterator_counting)
-        pbar = tqdm(total=total_docs, desc=f"Building context", disable=(rank != 0))
-        for doc_id in doc_id_iterator:
+
+        # process all documents when caching is specified for simplicity
+        if cache_requests and (not cached_instances or rewrite_requests_cache) and limit is not None:
+            limit = None
+
+        doc_id_docs = list(self.doc_iterator(rank=rank, limit=limit, world_size=world_size))
+
+        num_docs = len(doc_id_docs)
+
+        for doc_id, doc in tqdm(
+            doc_id_docs,
+            total=num_docs,
+        ):
             # sample fewshot context #TODO: need to offset doc_id by rank now!
             fewshot_ctx = self.fewshot_context(
-                doc_id, 0 if self.config.num_fewshot is None else self.config.num_fewshot, split
-            )  # TODO: avoid doc_id inconsistency between test and train, but wondering why selecting docs from test set, not train set
+                doc,
+                0 if self.config.num_fewshot is None else self.config.num_fewshot,
+                system_instruction,
+                apply_chat_template,
+                fewshot_as_multiturn,
+                chat_template,
+            )

             # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
-            per_task_metadata = {"task": self.config["task"], "doc_id": doc_id, "repeats": self.config.repeats}
-            if self.config.metadata and type(self.config.metadata) == dict:  # TODO: temporary fix for metadata loading, ignore the list of dict type.
+            per_task_metadata = {"task": self.config["task"], "doc_id": doc_id, "repeats": self.config.repeats, "split": split}
+            if self.config.metadata:
                 per_task_metadata.update(self.config.metadata)

-            inst = self.construct_requests(doc_id=doc_id, ctx=fewshot_ctx, metadata=per_task_metadata, split=split)
+            inst = self.construct_requests(doc_id=doc_id, ctx=fewshot_ctx, metadata=per_task_metadata)

             if not isinstance(inst, list):
                 inst = [inst]

-            instances.extend(inst)
-            pbar.update(1)
+            instances.append(inst)

-        pbar.close()
-        self._instances = instances
-        assert len(self._instances) != 0, "task.build_requests() did not find any docs!"
+        # now flatten, this is to allow slicing to work with pickles
+        sliced_instances = instances[:og_limit]
+
+        flattened_instances = [instance for instance_group in sliced_instances for instance in instance_group]
+
+        self._instances = flattened_instances
+
+        if len(self._instances) == 0:
+            raise ValueError("task.build_requests() did not find any docs!")
+
+        if cache_requests and (not cached_instances or rewrite_requests_cache):
+            save_to_cache(file_name=cache_key, obj=instances)

     @abc.abstractmethod
     def construct_requests(self, doc_id, ctx, **kwargs):
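
Review note: the cache key folds in every input that can change the built requests — task, shot count, rank and world size, the chat-template and multiturn flags, a hash of the system prompt, and the tokenizer name — so a hit is only possible under an identical configuration. A minimal sketch of the round trip, with hypothetical key values not taken from this diff:

```python
# Sketch of the request-cache round trip; the key fields are illustrative.
from lmms_eval.caching.cache import load_from_cache, save_to_cache

cache_key = "requests-mme-0shot-rank0-world_size1"  # base key
cache_key += "-chat_template"   # appended only when apply_chat_template is set
cache_key += "-tokenizerllava"  # tokenizer_name suffix

cached = load_from_cache(file_name=cache_key)  # returns None on a miss
if cached is None:
    instances = [["req_a", "req_b"], ["req_c"]]  # one group per document
    save_to_cache(file_name=cache_key, obj=instances)
```

Note that the un-sliced, grouped `instances` list is what gets cached, which is why both the cache-hit path and the build path re-apply `[:limit]`/`[:og_limit]` and flatten before assigning `self._instances`.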
@@ -1017,36 +1072,104 @@ def fewshot_docs(self):
             return super().fewshot_docs()

     @utils.positional_deprecated
-    def fewshot_context(self, doc_id, num_fewshot, split):
+    def fewshot_context(
+        self,
+        doc: str,
+        num_fewshot: int,
+        system_instruction: Optional[str] = None,
+        apply_chat_template: bool = False,
+        fewshot_as_multiturn: bool = False,
+        chat_template: Optional[Callable] = None,
+    ) -> str:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.

-        :param doc_id: str
-            The document id as returned from training_docs, validation_docs, or test_docs.
+        :param doc: str
+            The document as returned from training_docs, validation_docs, or test_docs.
         :param num_fewshot: int
             The number of fewshot examples to provide in the returned context string.
+        :param system_instruction: str
+            System instruction to be applied to the prompt.
+        :param apply_chat_template: bool
+            Whether to apply the chat template to the fewshot context.
+        :param fewshot_as_multiturn: bool
+            Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
+        :param chat_template:
+            callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
         :returns: str
             The fewshot context.
         """
-        doc = self.dataset_no_image[split][doc_id]
-
-        if num_fewshot == 0:
-            # always prepend the (possibly empty) task description
-            labeled_examples = self.config.description
+        if apply_chat_template:
+            labeled_examples = []
         else:
-            labeled_examples = self.config.description + self.sampler.get_context(doc, num_fewshot)
+            labeled_examples = ""

-        example = self.doc_to_text(doc)
-        if type(example) == str:
-            return labeled_examples + example
-        elif type(example) == list:
-            return [labeled_examples + ex for ex in example]
-        elif type(example) == int:
-            if self.config.doc_to_choice is not None:
-                choices = self.doc_to_choice(doc)
-                return labeled_examples + choices[example]
+        # get task description
+        if description := self.config.description:
+            description = utils.apply_template(self.config.description, doc)
+
+        # create system prompt based on the provided system instruction and description
+        if system_instruction is not None and description:
+            system_prompt = f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
+        elif system_instruction is not None:
+            system_prompt = system_instruction
+        elif description:
+            system_prompt = description
+        else:
+            system_prompt = ""
+
+        # add system prompt if specified
+        if system_prompt:
+            if apply_chat_template:
+                labeled_examples.append({"role": "system", "content": system_prompt})
+            else:
+                labeled_examples = system_prompt
+
+        # if few-shot - append examples after the system prompt
+        if num_fewshot > 0:
+            if apply_chat_template:
+                labeled_examples.extend(self.sampler.get_chat_context(doc, num_fewshot, fewshot_as_multiturn))
             else:
-                return labeled_examples + str(example)
+                labeled_examples += self.sampler.get_context(doc, num_fewshot)
+
+        example = self.doc_to_text(doc)
+        if apply_chat_template:
+            if self.multiple_input:
+                return chat_template(labeled_examples)
+            if isinstance(example, str):
+                self.append_target_question(labeled_examples, example, fewshot_as_multiturn)
+            # for loglikelihood create a list of questions with appended choices
+            elif isinstance(example, list):
+                labeled_examples_list = []
+                # copy chat history for each example and append the answer
+                for ex in example:
+                    chat = deepcopy(labeled_examples)
+                    self.append_target_question(chat, ex, fewshot_as_multiturn)
+                    labeled_examples_list.append(chat_template(chat))
+                return labeled_examples_list
+            # if example is an integer, append the choice or convert to string
+            elif isinstance(example, int):
+                if self.config.doc_to_choice is not None:
+                    choices = self.doc_to_choice(doc)
+                    self.append_target_question(labeled_examples, choices[example], fewshot_as_multiturn)
+                else:
+                    self.append_target_question(labeled_examples, str(example), fewshot_as_multiturn)
+            # return lm.apply_chat_template(labeled_examples)
+            return chat_template(labeled_examples)
+        else:
+            if self.multiple_input:
+                return labeled_examples
+            if isinstance(example, str):
+                return labeled_examples + example
+            elif isinstance(example, list):
+                return [labeled_examples + ex for ex in example]
+            elif isinstance(example, int):
+                if self.config.doc_to_choice is not None:
+                    choices = self.doc_to_choice(doc)
+                    return labeled_examples + choices[example]
+                else:
+                    return labeled_examples + str(example)

     def apply_filters(self):
         if hasattr(self, "_filters"):
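
Review note: the refactored method now has two rendering paths — a plain-string path that concatenates description, few-shot examples, and the target question, and a chat path that accumulates `{"role": ..., "content": ...}` turns and renders them with the caller-supplied `chat_template`. A usage sketch, where `task` and `doc` are hypothetical and the template is a toy stand-in for `lm.apply_chat_template`:

```python
# Toy renderer; in practice this callable comes from lm.apply_chat_template.
def toy_chat_template(chat):
    return "\n".join(f"{turn['role']}: {turn['content']}" for turn in chat)

# Plain-string path: description + few-shot examples + target question.
ctx = task.fewshot_context(doc, 2)

# Chat path: turns accumulate as dicts and are rendered at the end.
ctx = task.fewshot_context(
    doc,
    2,
    system_instruction="You are a helpful assistant.",
    apply_chat_template=True,
    fewshot_as_multiturn=True,
    chat_template=toy_chat_template,
)
```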
@@ -1198,8 +1321,8 @@ def doc_to_choice(self, doc: Any) -> List[str]:
             raise TypeError

     def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Instance], Instance]:
-        split = kwargs.get("split")
-        kwargs.pop("split")
+        split = kwargs.get("metadata").get("split")
+        # kwargs.pop("split")
         if self.OUTPUT_TYPE == "loglikelihood":
             arguments = (ctx, self.doc_to_target, self.doc_to_visual, doc_id, self.config.task, split)
         elif self.OUTPUT_TYPE == "multiple_choice":
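
Review note: `split` now rides inside the `metadata` dict instead of being a separate keyword. A caller-side sketch with illustrative values:

```python
# split travels in per-task metadata; construct_requests reads it back out.
per_task_metadata = {"task": "mme", "doc_id": 0, "repeats": 1, "split": "test"}
inst = task.construct_requests(doc_id=0, ctx=fewshot_ctx, metadata=per_task_metadata)
```

One caveat: `kwargs.get("metadata").get("split")` assumes `metadata` is always passed; if it were omitted, the chained `.get` would raise `AttributeError` on `None` rather than fall back gracefully.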
Empty file added lmms_eval/caching/__init__.py
54 changes: 54 additions & 0 deletions lmms_eval/caching/cache.py
@@ -0,0 +1,54 @@
+import hashlib
+import os
+
+import dill
+
+from lmms_eval.utils import eval_logger
+
+MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
+
+OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")
+
+
+PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"
+
+# This should be sufficient for uniqueness
+HASH_INPUT = "EleutherAI-lm-evaluation-harness"
+
+HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
+
+FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
+
+
+def load_from_cache(file_name):
+    try:
+        path = f"{PATH}/{file_name}{FILE_SUFFIX}"
+
+        with open(path, "rb") as file:
+            cached_task_dict = dill.loads(file.read())
+            return cached_task_dict
+
+    except Exception:
+        eval_logger.debug(f"{file_name} is not cached, generating...")
+        pass
+
+
+def save_to_cache(file_name, obj):
+    if not os.path.exists(PATH):
+        os.mkdir(PATH)
+
+    file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
+
+    eval_logger.debug(f"Saving {file_path} to cache...")
+    with open(file_path, "wb") as file:
+        file.write(dill.dumps(obj))
+
+
+# NOTE the "key" param is to allow for flexibility
+def delete_cache(key: str = ""):
+    files = os.listdir(PATH)
+
+    for file in files:
+        if file.startswith(key) and file.endswith(FILE_SUFFIX):
+            file_path = f"{PATH}/{file}"
+            os.unlink(file_path)
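
Review note: the module stores each entry as a dill pickle under `$LM_HARNESS_CACHE_PATH` (or a `.cache` directory beside the module), suffixed with a fixed SHA-256 hash to avoid clashing with unrelated files. A usage sketch with an arbitrary key and payload:

```python
from lmms_eval.caching.cache import delete_cache, load_from_cache, save_to_cache

key = "requests-mme-0shot-rank0-world_size1"  # hypothetical key
save_to_cache(file_name=key, obj=[["req_a"], ["req_b"]])
restored = load_from_cache(file_name=key)  # [["req_a"], ["req_b"]]
delete_cache(key="requests-mme")  # unlinks every matching *.pickle entry
```

One caveat: `os.mkdir(PATH)` fails when the parent directory does not exist, so a deep `LM_HARNESS_CACHE_PATH` override would need `os.makedirs(PATH, exist_ok=True)` instead.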
10 changes: 9 additions & 1 deletion lmms_eval/evaluator.py
@@ -6,11 +6,13 @@
 import random
 import sys
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Union

 import numpy as np
 import torch
+import torch.distributed as dist
 from datasets import Image, Sequence
 from loguru import logger as eval_logger
 from tqdm import tqdm
@@ -253,6 +255,10 @@ def _adjust_config(task_dict):
         cli_args=cli_args,
     )

+    if hasattr(lm, "_model"):
+        del lm._model
+        torch.cuda.empty_cache()
+
     if lm.rank == 0:
         if isinstance(model, str):
             model_name = model
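
Review note: dropping the model reference before result aggregation releases GPU memory that the (CPU-bound) post-processing no longer needs; `empty_cache()` then returns the cached blocks to the driver. The same pattern in isolation:

```python
import gc

import torch

del model  # drop the last reference to the model
gc.collect()  # optional: prompt Python to free it immediately
torch.cuda.empty_cache()  # hand cached allocator blocks back to CUDA
```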
@@ -415,7 +421,7 @@ def evaluate(
             # chat_template=getattr(lm, "apply_chat_template") if apply_chat_template else None,
             # tokenizer_name=getattr(lm, "tokenizer_name", "") if apply_chat_template else "",
         )
-        eval_logger.debug(f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}")
+        eval_logger.debug(f"Task: {task_output.task_name}; number of requests on this rank: {len(task._instances)}")
         if write_out:
             print_writeout(task)
         # aggregate Instances by LM method requested to get output.
@@ -552,6 +558,8 @@ def evaluate(
             if RANK == 0:
                 task_output.sample_metrics[metrics] = list(itertools.chain.from_iterable(metric_list))

+        dist.barrier()  # Ensure all processes are synced before proceeding
+
         if RANK == 0:
             ### Aggregate results over all datapoints ###
             # aggregate results ; run bootstrap CIs
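
Review note: this barrier is central to the refactor — every rank must finish contributing its per-sample metrics before rank 0 aggregates and bootstraps, otherwise rank 0 could read an incomplete gather. A minimal sketch of the gather-then-sync pattern (illustrative, not a verbatim excerpt; `WORLD_SIZE` is assumed alongside `RANK` from the enclosing scope):

```python
import torch.distributed as dist

per_rank_metrics = [0.5, 0.75]  # metrics computed locally on this rank

# Rank 0 supplies one destination slot per rank; other ranks pass None.
gathered = [None] * WORLD_SIZE if RANK == 0 else None
dist.gather_object(per_rank_metrics, gathered, dst=0)

dist.barrier()  # all ranks sync before rank 0 proceeds
if RANK == 0:
    flat = [m for rank_list in gathered for m in rank_list]
```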