From fa4300aabe9bcf39e096c9da869d7b119b110221 Mon Sep 17 00:00:00 2001 From: Bo Li Date: Sat, 14 Sep 2024 08:55:05 +0000 Subject: [PATCH 1/2] Refactor distributed gathering of logged samples and metrics --- lmms_eval/evaluator.py | 69 +++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py index 4ce17d46..12db7fb2 100755 --- a/lmms_eval/evaluator.py +++ b/lmms_eval/evaluator.py @@ -6,11 +6,13 @@ import random import sys import time +from collections import defaultdict from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import torch +import torch.distributed as dist from datasets import Image, Sequence from loguru import logger as eval_logger from tqdm import tqdm @@ -415,7 +417,7 @@ def evaluate( # chat_template=getattr(lm, "apply_chat_template") if apply_chat_template else None, # tokenizer_name=getattr(lm, "tokenizer_name", "") if apply_chat_template else "", ) - eval_logger.debug(f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}") + eval_logger.debug(f"Task: {task_output.task_name}; number of requests on this rank: {len(task._instances)}") if write_out: print_writeout(task) # aggregate Instances by LM method requested to get output. @@ -522,35 +524,52 @@ def evaluate( pbar.close() if WORLD_SIZE > 1: - # if multigpu, then gather data across all ranks to rank 0 - # first gather logged samples across all ranks for task_output in eval_tasks: if log_samples: - # for task_name, task_samples in list(samples.items()): - full_samples = [None] * WORLD_SIZE if RANK == 0 else None - per_rank_samples = [] - for sample in task_output.logged_samples: - per_rank_samples.append(sample) - - torch.distributed.gather_object( - obj=per_rank_samples, - object_gather_list=full_samples, - dst=0, - ) + # Gather logged samples + all_samples = [[] for _ in range(WORLD_SIZE)] if RANK == 0 else None + local_samples = task_output.logged_samples + + # Gather sample counts first + sample_counts = torch.tensor([len(local_samples)], dtype=torch.long, device="cuda") + all_counts = [torch.zeros(1, dtype=torch.long, device="cuda") for _ in range(WORLD_SIZE)] + dist.all_gather(all_counts, sample_counts) + + # Pad local samples to max count + max_count = max(count.item() for count in all_counts) + local_samples += [None] * (max_count - len(local_samples)) + + # Gather samples + dist.all_gather_object(all_samples, local_samples) if RANK == 0: - task_output.logged_samples = list(itertools.chain.from_iterable(full_samples)) - - # then collect metrics across all ranks - for metrics in task_output.sample_metrics: - metric_list = [None] * WORLD_SIZE if RANK == 0 else None - torch.distributed.gather_object( - obj=task_output.sample_metrics[metrics], - object_gather_list=metric_list, - dst=0, - ) + # Flatten and remove padding + task_output.logged_samples = [sample for samples, count in zip(all_samples, all_counts) for sample in samples[: count.item()]] + + # Gather metrics + all_metrics = defaultdict(list) + for metric_key, local_metrics in task_output.sample_metrics.items(): + # Gather metric counts + metric_counts = torch.tensor([len(local_metrics)], dtype=torch.long, device="cuda") + all_counts = [torch.zeros(1, dtype=torch.long, device="cuda") for _ in range(WORLD_SIZE)] + dist.all_gather(all_counts, metric_counts) + + # Pad local metrics to max count + max_count = max(count.item() for count in all_counts) + local_metrics += [None] * (max_count - 
len(local_metrics)) + + # Gather metrics + gathered_metrics = [None] * WORLD_SIZE + dist.all_gather_object(gathered_metrics, local_metrics) + if RANK == 0: - task_output.sample_metrics[metrics] = list(itertools.chain.from_iterable(metric_list)) + # Flatten and remove padding + all_metrics[metric_key] = [metric for metrics, count in zip(gathered_metrics, all_counts) for metric in metrics[: count.item()]] + + if RANK == 0: + task_output.sample_metrics = dict(all_metrics) + + dist.barrier() # Ensure all processes are synced before proceeding if RANK == 0: ### Aggregate results over all datapoints ### From c84c4883c8faa214f28a2886121f31a787093b7f Mon Sep 17 00:00:00 2001 From: Bo Li Date: Sat, 14 Sep 2024 12:54:52 +0000 Subject: [PATCH 2/2] Refactor caching module for LM evaluation harness --- lmms_eval/api/task.py | 199 +++++++++++++++++++++++++++------- lmms_eval/caching/__init__.py | 0 lmms_eval/caching/cache.py | 54 +++++++++ lmms_eval/evaluator.py | 67 +++++------- 4 files changed, 243 insertions(+), 77 deletions(-) create mode 100644 lmms_eval/caching/__init__.py create mode 100644 lmms_eval/caching/cache.py diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index 8b312756..01c49199 100755 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -48,6 +48,7 @@ get_metric_aggregation, is_higher_better, ) +from lmms_eval.caching.cache import load_from_cache, save_to_cache from lmms_eval.filters import build_filter_ensemble # HuggingfaceM4/NoCaps contains truncated image in test split @@ -376,7 +377,20 @@ def doc_to_target(self, doc): pass # @profile - def build_all_requests(self, limit=None, rank=None, world_size=None) -> None: + def build_all_requests( + self, + *, + limit: Union[int, None] = None, + rank: int = 0, + world_size: int = 1, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + tokenizer_name: str = "", + ) -> None: """Build a set of Instances for a task, and store them in task.instances""" if self.has_test_docs(): docs = self.test_docs() @@ -387,35 +401,76 @@ def build_all_requests(self, limit=None, rank=None, world_size=None) -> None: else: assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" 
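
The first commit above replaces the rank-0 gather with an all-gather that pads every rank's list to a common length and strips the padding afterwards. Below is a minimal, self-contained sketch of that scheme; the helper name gather_uneven_lists is hypothetical, and it assumes torch.distributed is already initialized (use device="cpu" with the gloo backend).

import torch
import torch.distributed as dist


def gather_uneven_lists(local_items, device="cuda"):
    """Return every rank's list concatenated in rank order, on all ranks."""
    world_size = dist.get_world_size()

    # Exchange per-rank lengths first so the padding can be removed later.
    local_count = torch.tensor([len(local_items)], dtype=torch.long, device=device)
    all_counts = [torch.zeros(1, dtype=torch.long, device=device) for _ in range(world_size)]
    dist.all_gather(all_counts, local_count)

    # Pad a copy to the maximum length (the patch pads the original list in place).
    max_count = max(c.item() for c in all_counts)
    padded = list(local_items) + [None] * (max_count - len(local_items))

    # all_gather_object needs a pre-sized output list on every rank, not only rank 0.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, padded)

    # Flatten and drop the padding using the exchanged counts.
    return [item for items, count in zip(gathered, all_counts) for item in items[: count.item()]]
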
- eval_logger.info(f"Building contexts for task {self._config.task} on rank {rank}...") + # used with caching + og_limit = limit + + cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}" + cache_key += "-chat_template" if apply_chat_template else "" + cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else "" + cache_key += f"-system_prompt_hash{utils.hash_string(system_instruction)}" if system_instruction is not None else "" + cache_key += f"-tokenizer{tokenizer_name}" + + cached_instances = load_from_cache(file_name=cache_key) + + if cache_requests and cached_instances and not rewrite_requests_cache: + cached_instances = cached_instances[:limit] + + flattened_instances = [instance for instance_group in cached_instances for instance in instance_group] + + self._instances = flattened_instances + return + + eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...") + instances = [] - doc_id_iterator = utils.create_iterator([i for i in range(len(docs))], rank, world_size, limit) - doc_id_iterator, doc_id_iterator_counting = itertools.tee(doc_id_iterator) - total_docs = sum(1 for _ in doc_id_iterator_counting) - pbar = tqdm(total=total_docs, desc=f"Building context", disable=(rank != 0)) - for doc_id in doc_id_iterator: + + # process all documents when caching is specified for simplicity + if cache_requests and (not cached_instances or rewrite_requests_cache) and limit is not None: + limit = None + + doc_id_docs = list(self.doc_iterator(rank=rank, limit=limit, world_size=world_size)) + + num_docs = len(doc_id_docs) + + for doc_id, doc in tqdm( + doc_id_docs, + total=num_docs, + ): # sample fewshot context #TODO: need to offset doc_id by rank now! fewshot_ctx = self.fewshot_context( - doc_id, 0 if self.config.num_fewshot is None else self.config.num_fewshot, split - ) # TODO: avoid doc_id inconsistency between test and train, but wondering why selecting docs from test set, not train set + doc, + 0 if self.config.num_fewshot is None else self.config.num_fewshot, + system_instruction, + apply_chat_template, + fewshot_as_multiturn, + chat_template, + ) # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute - per_task_metadata = {"task": self.config["task"], "doc_id": doc_id, "repeats": self.config.repeats} - - if self.config.metadata and type(self.config.metadata) == dict: # TODO: temporary fix for metadata loading, ignore the list of dict type. + per_task_metadata = {"task": self.config["task"], "doc_id": doc_id, "repeats": self.config.repeats, "split": split} + if self.config.metadata: per_task_metadata.update(self.config.metadata) - inst = self.construct_requests(doc_id=doc_id, ctx=fewshot_ctx, metadata=per_task_metadata, split=split) + inst = self.construct_requests(doc_id=doc_id, ctx=fewshot_ctx, metadata=per_task_metadata) if not isinstance(inst, list): inst = [inst] - instances.extend(inst) - pbar.update(1) + instances.append(inst) + + # now flatten, this is to allow slicing to work with pickles - pbar.close() - self._instances = instances - assert len(self._instances) != 0, "task.build_requests() did not find any docs!" 
+ sliced_instances = instances[:og_limit] + + flattened_instances = [instance for instance_group in sliced_instances for instance in instance_group] + + self._instances = flattened_instances + + if len(self._instances) == 0: + raise ValueError("task.build_requests() did not find any docs!") + + if cache_requests and (not cached_instances or rewrite_requests_cache): + save_to_cache(file_name=cache_key, obj=instances) @abc.abstractmethod def construct_requests(self, doc_id, ctx, **kwargs): @@ -1017,36 +1072,104 @@ def fewshot_docs(self): return super().fewshot_docs() @utils.positional_deprecated - def fewshot_context(self, doc_id, num_fewshot, split): + def fewshot_context( + self, + doc: str, + num_fewshot: int, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + ) -> str: """Returns a fewshot context string that is made up of a prepended description (if provided), the `num_fewshot` number of examples, and an appended prompt example. - :param doc_id: str - The document id as returned from training_docs, validation_docs, or test_docs. + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. :param num_fewshot: int The number of fewshot examples to provide in the returned context string. + :param system_instruction: str + System instruction to be applied to the prompt. + :param apply_chat_template: bool + Whether to apply the chat template to the fewshot context. + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param chat_template: + callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. :returns: str The fewshot context. 
""" - doc = self.dataset_no_image[split][doc_id] - if num_fewshot == 0: - # always prepend the (possibly empty) task description - labeled_examples = self.config.description + if apply_chat_template: + labeled_examples = [] else: - labeled_examples = self.config.description + self.sampler.get_context(doc, num_fewshot) + labeled_examples = "" - example = self.doc_to_text(doc) - if type(example) == str: - return labeled_examples + example - elif type(example) == list: - return [labeled_examples + ex for ex in example] - elif type(example) == int: - if self.config.doc_to_choice is not None: - choices = self.doc_to_choice(doc) - return labeled_examples + choices[example] + # get task description + if description := self.config.description: + description = utils.apply_template(self.config.description, doc) + + # create system prompt based on the provided system instruction and description + if system_instruction is not None and description: + system_prompt = f"{system_instruction}{self.sampler.fewshot_delimiter}{description}" + elif system_instruction is not None: + system_prompt = system_instruction + elif description: + system_prompt = description + else: + system_prompt = "" + + # add system prompt if specified + if system_prompt: + if apply_chat_template: + labeled_examples.append({"role": "system", "content": system_prompt}) + else: + labeled_examples = system_prompt + + # if few-shot - append examples after the system prompt + if num_fewshot > 0: + if apply_chat_template: + labeled_examples.extend(self.sampler.get_chat_context(doc, num_fewshot, fewshot_as_multiturn)) else: - return labeled_examples + str(example) + labeled_examples += self.sampler.get_context(doc, num_fewshot) + + example = self.doc_to_text(doc) + if apply_chat_template: + if self.multiple_input: + return chat_template(labeled_examples) + if isinstance(example, str): + self.append_target_question(labeled_examples, example, fewshot_as_multiturn) + # for loglikelihood create a list of questions with appended choices + elif isinstance(example, list): + labeled_examples_list = [] + # copy chat history for each example and append the answer + for ex in example: + chat = deepcopy(labeled_examples) + self.append_target_question(chat, ex, fewshot_as_multiturn) + labeled_examples_list.append(chat_template(chat)) + return labeled_examples_list + # if example is an integer, append the choice or convert to string + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + self.append_target_question(labeled_examples, choices[example], fewshot_as_multiturn) + else: + self.append_target_question(labeled_examples, str(example), fewshot_as_multiturn) + # return lm.apply_chat_template(labeled_examples) + return chat_template(labeled_examples) + else: + if self.multiple_input: + return labeled_examples + if isinstance(example, str): + return labeled_examples + example + elif isinstance(example, list): + return [labeled_examples + ex for ex in example] + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + return labeled_examples + choices[example] + else: + return labeled_examples + str(example) def apply_filters(self): if hasattr(self, "_filters"): @@ -1198,8 +1321,8 @@ def doc_to_choice(self, doc: Any) -> List[str]: raise TypeError def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Instance], Instance]: - split = kwargs.get("split") - kwargs.pop("split") + split = kwargs.get("metadata").get("split") 
+ # kwargs.pop("split") if self.OUTPUT_TYPE == "loglikelihood": arguments = (ctx, self.doc_to_target, self.doc_to_visual, doc_id, self.config.task, split) elif self.OUTPUT_TYPE == "multiple_choice": diff --git a/lmms_eval/caching/__init__.py b/lmms_eval/caching/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lmms_eval/caching/cache.py b/lmms_eval/caching/cache.py new file mode 100644 index 00000000..196333a6 --- /dev/null +++ b/lmms_eval/caching/cache.py @@ -0,0 +1,54 @@ +import hashlib +import os + +import dill + +from lmms_eval.utils import eval_logger + +MODULE_DIR = os.path.dirname(os.path.realpath(__file__)) + +OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH") + + +PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache" + +# This should be sufficient for uniqueness +HASH_INPUT = "EleutherAI-lm-evaluation-harness" + +HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest() + +FILE_SUFFIX = f".{HASH_PREFIX}.pickle" + + +def load_from_cache(file_name): + try: + path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + with open(path, "rb") as file: + cached_task_dict = dill.loads(file.read()) + return cached_task_dict + + except Exception: + eval_logger.debug(f"{file_name} is not cached, generating...") + pass + + +def save_to_cache(file_name, obj): + if not os.path.exists(PATH): + os.mkdir(PATH) + + file_path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + eval_logger.debug(f"Saving {file_path} to cache...") + with open(file_path, "wb") as file: + file.write(dill.dumps(obj)) + + +# NOTE the "key" param is to allow for flexibility +def delete_cache(key: str = ""): + files = os.listdir(PATH) + + for file in files: + if file.startswith(key) and file.endswith(FILE_SUFFIX): + file_path = f"{PATH}/{file}" + os.unlink(file_path) diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py index 12db7fb2..aa8dc032 100755 --- a/lmms_eval/evaluator.py +++ b/lmms_eval/evaluator.py @@ -255,6 +255,10 @@ def _adjust_config(task_dict): cli_args=cli_args, ) + if hasattr(lm, "_model"): + del lm._model + torch.cuda.empty_cache() + if lm.rank == 0: if isinstance(model, str): model_name = model @@ -524,50 +528,35 @@ def evaluate( pbar.close() if WORLD_SIZE > 1: + # if multigpu, then gather data across all ranks to rank 0 + # first gather logged samples across all ranks for task_output in eval_tasks: if log_samples: - # Gather logged samples - all_samples = [[] for _ in range(WORLD_SIZE)] if RANK == 0 else None - local_samples = task_output.logged_samples - - # Gather sample counts first - sample_counts = torch.tensor([len(local_samples)], dtype=torch.long, device="cuda") - all_counts = [torch.zeros(1, dtype=torch.long, device="cuda") for _ in range(WORLD_SIZE)] - dist.all_gather(all_counts, sample_counts) - - # Pad local samples to max count - max_count = max(count.item() for count in all_counts) - local_samples += [None] * (max_count - len(local_samples)) - - # Gather samples - dist.all_gather_object(all_samples, local_samples) + # for task_name, task_samples in list(samples.items()): + full_samples = [None] * WORLD_SIZE if RANK == 0 else None + per_rank_samples = [] + for sample in task_output.logged_samples: + per_rank_samples.append(sample) + + torch.distributed.gather_object( + obj=per_rank_samples, + object_gather_list=full_samples, + dst=0, + ) if RANK == 0: - # Flatten and remove padding - task_output.logged_samples = [sample for samples, count in zip(all_samples, all_counts) for sample in samples[: count.item()]] - - # Gather metrics - all_metrics = defaultdict(list) - 
for metric_key, local_metrics in task_output.sample_metrics.items(): - # Gather metric counts - metric_counts = torch.tensor([len(local_metrics)], dtype=torch.long, device="cuda") - all_counts = [torch.zeros(1, dtype=torch.long, device="cuda") for _ in range(WORLD_SIZE)] - dist.all_gather(all_counts, metric_counts) - - # Pad local metrics to max count - max_count = max(count.item() for count in all_counts) - local_metrics += [None] * (max_count - len(local_metrics)) - - # Gather metrics - gathered_metrics = [None] * WORLD_SIZE - dist.all_gather_object(gathered_metrics, local_metrics) - + task_output.logged_samples = list(itertools.chain.from_iterable(full_samples)) + + # then collect metrics across all ranks + for metrics in task_output.sample_metrics: + metric_list = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.sample_metrics[metrics], + object_gather_list=metric_list, + dst=0, + ) if RANK == 0: - # Flatten and remove padding - all_metrics[metric_key] = [metric for metrics, count in zip(gathered_metrics, all_counts) for metric in metrics[: count.item()]] - - if RANK == 0: - task_output.sample_metrics = dict(all_metrics) + task_output.sample_metrics[metrics] = list(itertools.chain.from_iterable(metric_list)) dist.barrier() # Ensure all processes are synced before proceeding
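
The second commit restores the simpler collective shown above: every rank contributes its own list via gather_object, and only rank 0 receives and flattens the result before aggregation. A condensed sketch of that pattern, assuming torch.distributed is already initialized (for example under torchrun); the helper name is hypothetical.

import itertools

import torch.distributed as dist


def gather_to_rank0(local_samples):
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # The destination buffer only needs to exist on the receiving rank.
    full_samples = [None] * world_size if rank == 0 else None
    dist.gather_object(obj=local_samples, object_gather_list=full_samples, dst=0)

    dist.barrier()  # keep all ranks in step before rank 0 starts aggregating
    if rank == 0:
        return list(itertools.chain.from_iterable(full_samples))
    return None
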