Delete redundant print, and lint again to reformat (EvolvingLMMs-Lab#126)
* Resolve conflict when merging kr_ego with internal_main_dev

* fix the bug of file overwrite

* Optimize the inference of videochatgpt dataset

* Resolve conflict

* delete repeated line

* reformat the code

* rename the file name for inference results

* group the same task together for cvrr and videochatgpt

* group the same task together for videochatgpt and cvrr

* reformat the code

* fix the bug of videochatgpt_consistency multiprocessing

* Rename the metric from submission to subtask

* fix the bug of consistency where different answers are generated in pred2

* add accuracy into the evaluation of cvrr

* add accuracy metric to cvrr dataset

* remove duplicate rows when merging from main branch

* Refactor videochatgpt_gen and videochatgpt_temporal for correct score parsing

* enable the webm video loader for llavavid as required in cvrr dataset

* Refactor process_results function to handle full_docs in videochatgpt task

* add tqdm to consistency gpt_eval

* Refactor the cvrr for correct aggregate logic

* change backend to decord for videochatgpt eval

* Fix for mkv video path

* add perceptiontest dataset test split

* double-check and optimize the code in egoschema

* rename metric name of perceptiontest

* add perceptiontest_validation dataset

* remove egoschema aggregate function name

* add tempcompass mc dataset

* remove redundant files

* add tempcompass yes_no, captioning, caption_matching subsets

* add all five aspects as metrics

* reformat the output dict for successful match

* remove redundant aggregation function in videochatgpt and rename some functions

* remove redundant aggregation function in activitynetqa and video_detail_description

* remove redundant aggregate functions in cvrr

* remove redundant rows in perception test

* use black ./ to reformat code

* debug: loading webm files is now successful

* Remove perceptiontest and perceptiontest_val default template YAML files

* put post prompt in yaml for tempcompass dataset

* align gpt eval model name for cvrr and debug tempcompass in case the match is unsuccessful

* debug tempcompass captioning for multi-choice, and optimize matching rule in egoschema

* add a period to each option in egoschema because mplugowl always ends with a period

* add readme for egoschema

* change readme citation for egoschema

* delete redundant print, lint the repo

---------

Co-authored-by: Bo Li <[email protected]>
Co-authored-by: kcz358 <[email protected]>
3 people authored Jun 12, 2024
1 parent 37f39b5 commit 6cf1ded
Showing 8 changed files with 21 additions and 25 deletions.
lmms_eval/__main__.py: 4 changes (1 addition & 3 deletions)

@@ -238,9 +238,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
     if args.predict_only:
         args.log_samples = True
     if (args.log_samples or args.predict_only) and not args.output_path:
-        raise ValueError(
-            "Specify --output_path if providing --log_samples or --predict_only"
-        )
+        raise ValueError("Specify --output_path if providing --log_samples or --predict_only")
     if args.limit:
         eval_logger.warning(" --limit SHOULD ONLY BE USED FOR TESTING." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
     if args.include_path is not None:
lmms_eval/api/metrics.py: 3 changes (3 additions & 0 deletions)

@@ -20,6 +20,7 @@
 def bypass_agg(arr):
     return 999

+
 @register_aggregation("mean")
 def mean(arr):
     return sum(arr) / len(arr)

@@ -229,6 +230,7 @@ def sample_stddev(arr):
 def mean_stderr(arr):
     return sample_stddev(arr) / math.sqrt(len(arr))

+
 @register_metric(
     metric="bypass",
     higher_is_better=True,

@@ -238,6 +240,7 @@ def mean_stderr(arr):
 def bypass(items):
     return items

+
 @register_metric(
     metric="mcc",
     higher_is_better=True,
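For context on the hunks above: metrics and aggregations in this codebase are looked up by name through decorator-filled registries. Below is a minimal, self-contained sketch of that pattern; the registry dictionary and example functions are illustrative, not the repository's exact definitions.

AGGREGATION_REGISTRY = {}


def register_aggregation(name):
    # Decorator that records an aggregation function under a string name.
    def decorate(fn):
        assert name not in AGGREGATION_REGISTRY, f"aggregation named '{name}' conflicts with an existing registration!"
        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate


@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)


@register_aggregation("bypass")
def bypass_agg(arr):
    # Placeholder aggregation for predict-only runs; the value is never a real score.
    return 999


if __name__ == "__main__":
    print(AGGREGATION_REGISTRY["mean"]([1, 2, 3]))  # 2.0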
lmms_eval/api/registry.py: 5 changes (2 additions & 3 deletions)

@@ -111,9 +111,7 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
     if name in METRIC_REGISTRY:
         return METRIC_REGISTRY[name]
     else:
-        eval_logger.warning(
-            f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
-        )
+        eval_logger.warning(f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library...")

     try:
         metric_object = hf_evaluate.load(name)

@@ -123,6 +121,7 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
             f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
         )

+
 def register_aggregation(name):
     def decorate(fn):
         assert name not in AGGREGATION_REGISTRY, f"aggregation named '{name}' conflicts with existing registered aggregation!"
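The get_metric hunk above keeps the existing lookup-then-fallback flow: prefer a locally registered metric, otherwise ask the Hugging Face Evaluate hub. A hedged sketch of that flow, assuming the `evaluate` package is installed; the registry contents here are made up for illustration.

import logging

import evaluate as hf_evaluate

eval_logger = logging.getLogger("lmms-eval")
METRIC_REGISTRY = {"identity": lambda items: items}  # illustrative local registry


def get_metric(name):
    if name in METRIC_REGISTRY:
        return METRIC_REGISTRY[name]
    eval_logger.warning(f"Could not find registered metric '{name}' locally, searching in HF Evaluate library...")
    try:
        metric_object = hf_evaluate.load(name)
        return metric_object.compute
    except Exception:
        eval_logger.error(f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric")
        raise


# Example: get_metric("exact_match") falls through to evaluate.load("exact_match").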
lmms_eval/api/task.py: 6 changes (2 additions & 4 deletions)

@@ -509,7 +509,7 @@ def dump_config(self) -> dict:
         # TODO: this should only return the overrides applied to a non-YAML task's configuration.
         # (num_fewshot)
         return self.config.to_dict()
-
+
     def override_metric(self, metric_name: str) -> None:
         """
         Override the default metrics used for evaluation with custom metrics.

@@ -529,9 +529,7 @@ def override_metric(self, metric_name: str) -> None:
         self._metric_fn_kwargs[metric_name] = {}
         if not isinstance(self, ConfigurableTask):
             self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
-            self.aggregation = lambda: {
-                metric_name: get_metric_aggregation(metric_name)
-            }
+            self.aggregation = lambda: {metric_name: get_metric_aggregation(metric_name)}
         setattr(self._config, "metric_list", [{"metric": metric_name}])
         setattr(self._config, "process_results", None)
lmms_eval/evaluator.py: 6 changes (2 additions & 4 deletions)

@@ -113,12 +113,10 @@ def simple_evaluate(
             config = task_obj._config
             if config["output_type"] == "generate_until" and gen_kwargs:
                 config["generation_kwargs"].update(gen_kwargs)
-
+
             if predict_only:
                 log_samples = True
-                eval_logger.info(
-                    f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
-                )
+                eval_logger.info(f"Processing {task_name} in output-only mode. Metrics will not be calculated!")
                 # we have to change the class properties post-hoc. This is pretty hacky.
                 task_obj.override_metric(metric_name="bypass")
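The two hunks above (lmms_eval/api/task.py and lmms_eval/evaluator.py) implement the predict-only path: when metrics are skipped, the task's process_results and aggregation hooks are swapped for a "bypass" placeholder after construction. The following is a simplified stand-in showing the idea, not the ConfigurableTask implementation.

def bypass(items):
    # Pass predictions through untouched; nothing is scored.
    return items


def bypass_agg(arr):
    # Sentinel aggregation value; it is never a real metric.
    return 999


class ToyTask:
    def __init__(self):
        self.process_results = lambda doc, results: {"accuracy": float(results[0] == doc["answer"])}
        self.aggregation = lambda: {"accuracy": lambda scores: sum(scores) / len(scores)}

    def override_metric(self, metric_name):
        # Post-hoc swap, mirroring the override done for predict-only runs.
        self.process_results = lambda doc, results: {metric_name: bypass(results)}
        self.aggregation = lambda: {metric_name: bypass_agg}


task = ToyTask()
task.override_metric(metric_name="bypass")
print(task.aggregation()["bypass"]([]))  # 999, i.e. no metric is actually computed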
lmms_eval/tasks/perceptiontest/test/utils.py: 2 changes (0 additions & 2 deletions)

@@ -67,8 +67,6 @@ def perceptiontest_doc_to_text(doc, model_specific_prompt_kwargs=None):
                 question += "\n" + "C. " + op
             index += 1
     post_prompt = "\nAnswer with the option's letter from the given choices directly."
-    print("question\n")
-    print(question)

     return f"{pre_prompt}{question}{post_prompt}"
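The deleted print calls above were leftover debugging in a standard doc_to_text builder: the question is concatenated with lettered options and a fixed post-prompt. A rough sketch of that shape, with assumed field names rather than the actual PerceptionTest schema:

def doc_to_text(doc, pre_prompt="", post_prompt="\nAnswer with the option's letter from the given choices directly."):
    # Build a multiple-choice prompt: question, lettered options, answer instruction.
    question = doc["question"]
    for letter, option in zip("ABC", doc["options"]):
        question += "\n" + letter + ". " + option
    return f"{pre_prompt}{question}{post_prompt}"


example = {"question": "Which object is moved first?", "options": ["the cup", "the ball", "the box"]}
print(doc_to_text(example))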
lmms_eval/tasks/worldqa/utils.py: 10 changes (6 additions & 4 deletions)

@@ -176,20 +176,23 @@ def worldqa_process_results(doc, result):
         "gpt_eval": {"pred": pred, "question_idx": doc["question_idx"], "object_description": doc["object_description"], "answer": doc["answer"], "eval_answer": eval_answer, "gpt_prompt": content},
     }

+
 def worldqa_process_results_mc(doc, result):
-    pred = result[0]
+    pred = result[0]
     data = {
-        "gpt_eval": {"pred": pred, "question_idx": doc["question_idx"], "object_description": doc["object_description"], "answer": doc["answer"], "option" : doc["option"], "question" : doc["question"] },
-    }
+        "gpt_eval": {"pred": pred, "question_idx": doc["question_idx"], "object_description": doc["object_description"], "answer": doc["answer"], "option": doc["option"], "question": doc["question"]},
+    }
     return data

+
 def worldqa_aggregate_mc_eval(results):
     score = 0
     evaluator = WorldQA_MC_Evaluator(API_KEY=API_KEY, API_URL=API_URL)
     for result in results:
         score += evaluator.evaluate(result)
     return score / len(results)

+
 def worldqa_aggregate_submissions(results, args, task):
     now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
     submission_file_name = f"worldqa-{task}-{now_date_time}.json"

@@ -226,7 +229,6 @@ def worldqa_aggregate_mc_ppl(results, args):
     worldqa_aggregate_submissions(results, args, "MC_PPL")


-
 def worldqa_doc_to_choice(doc):
     return [op.split(".")[1].strip() for op in doc["option"]]
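worldqa_aggregate_mc_eval in the hunk above is a plain mean over per-sample evaluator scores. Here is a small sketch of the same aggregation with a stub evaluator standing in for WorldQA_MC_Evaluator; the stub's matching rule is invented for illustration.

class StubEvaluator:
    def evaluate(self, result):
        # Returns 0 or 1 per sample, like the GPT-backed evaluator's output.
        return int(result["pred"].strip().upper().startswith(result["answer"].strip().upper()))


def aggregate_mc_eval(results, evaluator):
    score = 0
    for result in results:
        score += evaluator.evaluate(result)
    return score / len(results)


results = [{"pred": "B", "answer": "B"}, {"pred": "A. a dog", "answer": "C"}]
print(aggregate_mc_eval(results, StubEvaluator()))  # 0.5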
lmms_eval/tasks/worldqa/worldqa_mc_evaluator.py: 10 changes (5 additions & 5 deletions)

@@ -14,13 +14,14 @@

 eval_logger = logging.getLogger("lmms-eval")

+
 class WorldQA_MC_Evaluator:
     def __init__(self, sys_prompt="There are several options:", API_KEY="", API_URL="", model_version="gpt-3.5-turbo-0613"):
         self.sys_prompt = sys_prompt
         self.model_version = model_version
         self.API_KEY = API_KEY
         self.API_URL = API_URL
-
+
     def build_prompt(self, question, options, prediction):
         tmpl = (
             "You are an AI assistant who will help me to match an answer "

@@ -41,7 +42,7 @@ def build_prompt(self, question, options, prediction):
             "Question: {}?\nOptions: {}\nAnswer: {}\nYour output: "
         )
         return tmpl.format(question, options, prediction)
-
+
     # Prefetch Answers
     def can_infer_option(self, answer, num_choice=5):
         choices = string.ascii_uppercase[:num_choice]

@@ -70,7 +71,6 @@ def count(splits, choices="ABCD", prefix="", suffix=""):
                 if tup[0] + ch + tup[1] in splits:
                     return ch
         return False

-
     def _post_request(self, payload):
         headers = {

@@ -108,11 +108,11 @@ def get_chat_response(self, prompt, temperature=0, max_tokens=256, n=1, patience
         return "Failed to obtain answer via API"

     def evaluate(self, results):
-        answer = results["answer"].split(".")[0]
+        answer = results["answer"].split(".")[0]
         if self.can_infer_option(results["pred"], num_choice=4):
             choice = self.can_infer_option(results["pred"], num_choice=4)
             return int(choice.lower().strip() == answer.lower().strip())
         else:
             prompt = self.build_prompt(question=results["question"], options="\n".join(results["option"]), prediction=results["pred"])
             prediction = self.get_chat_response(prompt)
-            return int(prediction.lower().strip() == answer.lower().strip())
+            return int(prediction.lower().strip() == answer.lower().strip())
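The evaluate method above is a two-stage match: first try to read an option letter straight from the model output, and fall back to a GPT request only when that fails. A simplified sketch of that flow; the regex prefetch below is a stand-in for can_infer_option, and the fallback is a dummy callable rather than a real API call.

import re
import string


def can_infer_option(answer, num_choice=4):
    # Look for a bare option letter (A-D) in the prediction.
    choices = string.ascii_uppercase[:num_choice]
    match = re.search(rf"\b([{choices}])\b", answer)
    return match.group(1) if match else False


def evaluate(results, gpt_fallback=lambda prompt: "Failed to obtain answer via API"):
    answer = results["answer"].split(".")[0]
    choice = can_infer_option(results["pred"], num_choice=4)
    if choice:
        return int(choice.lower().strip() == answer.lower().strip())
    # The real code builds a matching prompt and queries the GPT endpoint here.
    prediction = gpt_fallback(results["pred"])
    return int(prediction.lower().strip() == answer.lower().strip())


print(evaluate({"pred": "The answer is B.", "answer": "B. a red car"}))  # 1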
