From 9d5ad74aae449ff3c289b59718b66c50051f4807 Mon Sep 17 00:00:00 2001 From: kcz358 Date: Tue, 3 Sep 2024 13:51:23 +0000 Subject: [PATCH 1/2] Bring back anls --- lmms_eval/api/metrics.py | 48 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/lmms_eval/api/metrics.py b/lmms_eval/api/metrics.py index 48380c79..6de0d771 100755 --- a/lmms_eval/api/metrics.py +++ b/lmms_eval/api/metrics.py @@ -275,6 +275,52 @@ def bits_per_byte_fn(items): # This is a passthrough function return items +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) + distances = distances_ + return distances[-1] + + +@register_metric( + metric="anls", + higher_is_better=True, + output_type="generate_until", + aggregation="mean", +) +def anls( + references, + predictions, + thresh_hold=0.5, +): # This is a passthrough function + """https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/infographicsvqa_eval.py""" + values = [] + for answer in references: + # preprocess both the answers - gt and prediction + gt_answer = " ".join(answer.strip().lower().split()) + det_answer = " ".join(predictions[0].strip().lower().split()) + + # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower()) + dist = levenshtein_distance(gt_answer, det_answer) + length = max(len(answer.upper()), len(predictions[0].upper())) + values.append(0.0 if length == 0 else float(dist) / float(length)) + + question_result = 1 - min(values) + + if question_result < thresh_hold: + question_result = 0 + return {"anls": question_result} + + def pop_stddev(arr): mu = mean(arr) return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr)) @@ -296,7 +342,7 @@ def mean_stderr(arr): aggregation="bypass", ) def bypass(items): - return None + return items @register_metric( From 012696898e20190ee8d42666da15ce15e7240533 Mon Sep 17 00:00:00 2001 From: kcz358 Date: Tue, 3 Sep 2024 13:52:21 +0000 Subject: [PATCH 2/2] Remove not used txt writer --- lmms_eval/evaluator.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py index 066fc0cd..91292370 100755 --- a/lmms_eval/evaluator.py +++ b/lmms_eval/evaluator.py @@ -607,19 +607,6 @@ def evaluate( else: results_dict = None - with open(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt", "w") as f: - f.write(f"rank {int(os.environ.get('RANK', 0))} eval done") - while len([file for file in os.listdir(cli_args.output_path) if file.endswith("metric_eval_done.txt")]) < lm._world_size: - time.sleep(1) - - else: - return None - - with open(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt", "w") as f: - f.write(f"rank {int(os.environ.get('RANK', 0))} eval done") - while len([file for file in os.listdir(cli_args.output_path) if file.endswith("metric_eval_done.txt")]) < lm._world_size: - time.sleep(1) - lm.accelerator.wait_for_everyone() return results_dict