From 8d8d692d0baa44d13b65effea536114db6aedbd5 Mon Sep 17 00:00:00 2001
From: kcz358
Date: Sat, 14 Sep 2024 04:44:41 +0000
Subject: [PATCH] Bring back process result pbar

---
 lmms_eval/evaluator.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py
index ee229c50..4ce17d46 100755
--- a/lmms_eval/evaluator.py
+++ b/lmms_eval/evaluator.py
@@ -477,6 +477,9 @@ def evaluate(
         # iterate over different filters used
         for filter_key in task.instances[0].filtered_resps.keys():
             doc_iterator = task.doc_iterator(rank=RANK, limit=limit, world_size=WORLD_SIZE)
+            doc_iterator_for_counting = itertools.islice(range(len(task.test_docs())), RANK, limit, WORLD_SIZE) if task.has_test_docs() else itertools.islice(range(len(task.validation_docs())), RANK, limit, WORLD_SIZE)
+            total_docs = sum(1 for _ in doc_iterator_for_counting)
+            pbar = tqdm(total=total_docs, desc=f"Postprocessing", disable=(RANK != 0))
             for doc_id, doc in doc_iterator:
                 requests = instances_by_doc_id[doc_id]
                 metrics = task.process_results(doc, [req.filtered_resps[filter_key] for req in requests])
@@ -514,6 +517,9 @@ def evaluate(
                     task_output.logged_samples.append(example)
                 for metric, value in metrics.items():
                     task_output.sample_metrics[(metric, filter_key)].append(value)
+                pbar.update(1)
+
+            pbar.close()
 
         if WORLD_SIZE > 1:
             # if multigpu, then gather data across all ranks to rank 0
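
Note (not part of the commit): the added counting lines size the progress bar without
materializing any documents, by walking the same rank-strided index slice the doc
iterator will use. Below is a minimal standalone sketch of that itertools.islice trick;
the literal values for RANK, WORLD_SIZE, limit, and num_docs are hypothetical stand-ins
and do not come from the patch.

    import itertools
    from tqdm import tqdm

    RANK, WORLD_SIZE = 0, 2   # hypothetical: rank 0 of a 2-process run
    limit = None              # no cap on the number of docs
    num_docs = 10             # stand-in for len(task.test_docs())

    # Same slicing as the patch: start=RANK, stop=limit, step=WORLD_SIZE,
    # so each rank counts only the doc indices it will actually postprocess.
    doc_iterator_for_counting = itertools.islice(range(num_docs), RANK, limit, WORLD_SIZE)
    total_docs = sum(1 for _ in doc_iterator_for_counting)

    pbar = tqdm(total=total_docs, desc="Postprocessing", disable=(RANK != 0))
    for _ in range(total_docs):
        pbar.update(1)
    pbar.close()

For rank 0 of 2 with 10 docs and no limit, islice yields indices 0, 2, 4, 6, 8, so
total_docs is 5, which matches the number of docs this rank postprocesses; disable=(RANK != 0)
keeps the bar on rank 0 only, as in the patch.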