Skip to content

Commit

Permalink
[Feat] add regression test and change saving logic related to `output…
Browse files Browse the repository at this point in the history
…_path` (#259)

* feat: add new ouput_path saving logic and add evaluation tracker to manage samples saving process

* add: regression test

* add: regression test

* clean: unuseful code

* 🚫 Remove unused import for cleaner code

Eliminated the commented-out import statement for WandbLogger to tidy up the code and enhance readability. This helps maintain focus on active components and prevents confusion over unused code. A cleaner structure contributes to better maintainability in the long run.

No functional changes were made, just a step towards a more streamlined codebase.
  • Loading branch information
Luodian authored Sep 17, 2024
1 parent a2a881c commit e20d5d6
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 50 deletions.
22 changes: 22 additions & 0 deletions docs/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,26 @@ pip install httpx==0.23.3
pip install protobuf==3.20
```

## Regression Test

Now after each PR, we need to run the regression test to make sure the performance of the model is not degraded.

```bash
python3 tools/regression.py
```

```bash
Already on 'dev/fix_output_path'

|task|llava-onevision-qwen2-0.5b-ov|
|--|--|
|ocrbench (dev/fix_output_path)|0.70 ± 0.70|
|mmmu_val (dev/fix_output_path)|50.00 ± 50.00|
|ai2d (dev/fix_output_path)|50.00 ± 50.00|
|muirbench (dev/fix_output_path)|12.50 ± 12.50|
|videomme (dev/fix_output_path)|2500.00 ± 2500.00|

|branch|runtime|%|
|--|--|--|
|dev/fix_output_path|87.7s|100%|
```
64 changes: 30 additions & 34 deletions lmms_eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from lmms_eval.api.registry import ALL_TASKS
from lmms_eval.evaluator import request_caching_arg_to_dict
from lmms_eval.loggers import EvaluationTracker, WandbLogger

# from lmms_eval.logging_utils import WandbLogger
from lmms_eval.tasks import TaskManager
from lmms_eval.utils import (
handle_non_serializable,
Expand Down Expand Up @@ -230,7 +228,7 @@ def parse_eval_args() -> argparse.Namespace:
parser.add_argument(
"--timezone",
default="Asia/Singapore",
help="Timezone for datetime string, e.g. Asia/Singapore, America/New_York, America/Los_Angeles",
help="Timezone for datetime string, e.g. Asia/Singapore, America/New_York, America/Los_Angeles. You can check the full list via `import pytz; print(pytz.common_timezones)`",
)
parser.add_argument(
"--hf_hub_log_args",
Expand Down Expand Up @@ -349,7 +347,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
for args, results in zip(args_list, results_list):
# cli_evaluate will return none if the process is not the main process (rank 0)
if results is not None:
print_results(args, results)
print(f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"batch_size: {args.batch_size}")
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))

if args.wandb_args:
wandb_logger.run.finish()
Expand Down Expand Up @@ -462,22 +463,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:

eval_logger.info(f"Selected Tasks: {task_names}")
request_caching_args = request_caching_arg_to_dict(cache_requests=args.cache_requests)

# set datetime before evaluation
datetime_str = utils.get_datetime_str(timezone=args.timezone)
if args.output_path:
if args.log_samples_suffix and len(args.log_samples_suffix) > 15:
eval_logger.warning("The suffix for log_samples is too long. It is recommended to keep it under 15 characters.")
args.log_samples_suffix = args.log_samples_suffix[:5] + "..." + args.log_samples_suffix[-5:]

hash_input = f"{args.model_args}".encode("utf-8")
hash_output = hashlib.sha256(hash_input).hexdigest()[:6]
path = Path(args.output_path)
path = path.expanduser().resolve().joinpath(f"{datetime_str}_{args.log_samples_suffix}_{args.model}_model_args_{hash_output}")
args.output_path = path

elif args.log_samples and not args.output_path:
assert args.output_path, "Specify --output_path"

results = evaluator.simple_evaluate(
model=args.model,
Expand Down Expand Up @@ -505,6 +491,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
cli_args=args,
datetime_str=datetime_str,
**request_caching_args,
)

Expand All @@ -517,21 +504,30 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
if args.show_config:
print(dumped)

if args.output_path:
args.output_path.mkdir(parents=True, exist_ok=True)
result_file_path = path.joinpath("results.json")
if result_file_path.exists():
eval_logger.warning(f"Output file {result_file_path} already exists and will be overwritten.")

result_file_path.open("w").write(dumped)
if args.log_samples:
for task_name, config in results["configs"].items():
filename = args.output_path.joinpath(f"{task_name}.json")
# Structure the data with 'args' and 'logs' keys
data_to_dump = {"args": vars(args), "model_configs": config, "logs": sorted(samples[task_name], key=lambda x: x["doc_id"]), "time": datetime_str}
samples_dumped = json.dumps(data_to_dump, indent=4, default=_handle_non_serializable, ensure_ascii=False)
filename.open("w", encoding="utf-8").write(samples_dumped)
eval_logger.info(f"Saved samples to {filename}")
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

# Add W&B logging
if args.wandb_args:
try:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
if args.log_samples:
wandb_logger.log_eval_samples(samples)
except Exception as e:
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

evaluation_tracker.save_results_aggregated(results=results, samples=samples if args.log_samples else None, datetime_str=datetime_str)

if args.log_samples:
for task_name, config in results["configs"].items():
evaluation_tracker.save_results_samples(task_name=task_name, samples=samples[task_name])

if evaluation_tracker.push_results_to_hub or evaluation_tracker.push_samples_to_hub:
evaluation_tracker.recreate_metadata_card()

if args.wandb_args:
# Tear down wandb run once all the logging is done.
wandb_logger.run.finish()

return results, samples
return None, None
Expand Down
3 changes: 2 additions & 1 deletion lmms_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def simple_evaluate(
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
datetime_str: str = get_datetime_str(),
cli_args=None,
):
"""Instantiate and evaluate a model on a list of tasks.
Expand Down Expand Up @@ -292,7 +293,7 @@ def _adjust_config(task_dict):
}
)
results["git_hash"] = get_git_commit_hash()
results["date"] = get_datetime_str()
results["date"] = datetime_str
# add_env_info(results) # additional environment info to results
# add_tokenizer_info(results, lm) # additional info about tokenizer
return results
Expand Down
22 changes: 12 additions & 10 deletions lmms_eval/loggers/evaluation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from lmms_eval.utils import (
eval_logger,
get_datetime_str,
get_file_datetime,
get_file_task_name,
get_results_filenames,
Expand Down Expand Up @@ -154,7 +155,7 @@ def __init__(
eval_logger.warning(f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'.")

if hub_repo_name == "":
details_repo_name = details_repo_name if details_repo_name != "" else "lm-eval-results"
details_repo_name = details_repo_name if details_repo_name != "" else "lmms-eval-results"
results_repo_name = results_repo_name if results_repo_name != "" else details_repo_name
else:
details_repo_name = hub_repo_name
Expand All @@ -170,13 +171,15 @@ def save_results_aggregated(
self,
results: dict,
samples: dict,
datetime_str: str,
) -> None:
"""
Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.
Args:
results (dict): The aggregated results to save.
samples (dict): The samples results to save.
datetime_str (str): The datetime string to use for the results file.
"""
self.general_config_tracker.log_end_time()

Expand Down Expand Up @@ -205,8 +208,8 @@ def save_results_aggregated(
path = path.joinpath(self.general_config_tracker.model_name_sanitized)
path.mkdir(parents=True, exist_ok=True)

self.date_id = datetime.now().isoformat().replace(":", "-")
file_results_aggregated = path.joinpath(f"results_{self.date_id}.json")
self.date_id = datetime_str.replace(":", "-")
file_results_aggregated = path.joinpath(f"{self.date_id}_results.json")
file_results_aggregated.open("w", encoding="utf-8").write(dumped)

if self.api and self.push_results_to_hub:
Expand All @@ -219,10 +222,10 @@ def save_results_aggregated(
)
self.api.upload_file(
repo_id=repo_id,
path_or_fileobj=str(path.joinpath(f"results_{self.date_id}.json")),
path_or_fileobj=str(path.joinpath(f"{self.date_id}_results.json")),
path_in_repo=os.path.join(
self.general_config_tracker.model_name,
f"results_{self.date_id}.json",
f"{self.date_id}_results.json",
),
repo_type="dataset",
commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
Expand Down Expand Up @@ -255,18 +258,17 @@ def save_results_samples(
path = path.joinpath(self.general_config_tracker.model_name_sanitized)
path.mkdir(parents=True, exist_ok=True)

file_results_samples = path.joinpath(f"samples_{task_name}_{self.date_id}.jsonl")
file_results_samples = path.joinpath(f"{self.date_id}_samples_{task_name}.jsonl")

for sample in samples:
# we first need to sanitize arguments and resps
# otherwise we won't be able to load the dataset
# using the datasets library
arguments = {}
for i, arg in enumerate(sample["arguments"]):
arguments[f"gen_args_{i}"] = {}
for j, tmp in enumerate(arg):
arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
for key, value in enumerate(sample["arguments"][1]): # update metadata into args
arguments[key] = value

sample["input"] = sample["arguments"][0]
sample["resps"] = sanitize_list(sample["resps"])
sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
sample["arguments"] = arguments
Expand Down
14 changes: 9 additions & 5 deletions lmms_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import gc
from itertools import islice

import numpy as np
import pytz
import torch
import transformers
Expand Down Expand Up @@ -238,11 +239,14 @@ def get_file_datetime(filename: str) -> str:
return filename[filename.rfind("_") + 1 :].replace(".jsonl", "")


def sanitize_model_name(model_name: str) -> str:
def sanitize_model_name(model_name: str, full_path: bool = False) -> str:
"""
Given the model name, returns a sanitized version of it.
"""
return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
if full_path:
return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
else:
return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name.split("/")[-1])


def sanitize_task_name(task_name: str) -> str:
Expand All @@ -263,14 +267,14 @@ def get_results_filenames(filenames: List[str]) -> List[str]:
"""
Extracts filenames that correspond to aggregated results.
"""
return [f for f in filenames if "/results_" in f and ".json" in f]
return [f for f in filenames if "results" in f and ".json" in f]


def get_sample_results_filenames(filenames: List[str]) -> List[str]:
"""
Extracts filenames that correspond to sample results.
"""
return [f for f in filenames if "/samples_" in f and ".json" in f]
return [f for f in filenames if "samples" in f and ".json" in f]


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
Expand Down Expand Up @@ -588,7 +592,7 @@ def get_datetime_str(timezone="Asia/Singapore"):
tz = pytz.timezone(timezone)
utc_now = datetime.datetime.now(datetime.timezone.utc)
local_time = utc_now.astimezone(tz)
return local_time.strftime("%m%d_%H%M")
return local_time.strftime("%Y%m%d_%H%M%S")


def ignore_constructor(loader, node):
Expand Down
Loading

0 comments on commit e20d5d6

Please sign in to comment.