From 9bb59099c9736db5deebe7be415499fbd2223b69 Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Wed, 3 Jan 2024 16:23:12 +0100 Subject: [PATCH] reformatted run_wikitext-2_benchmark.py --- .../WIKITEXT2/run_wikitext-2_benchmark.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index 1a185ce4bb..02a1d0e0c4 100644 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -8,18 +8,20 @@ from onmt.utils.parse import ArgumentParser from onmt.utils.misc import use_gpu, set_random_seed + def compute_file_ppl(output_filename): with open(output_filename, "r") as f: run_results = json.load(f) nlls = [] lengths = [] - for i, _res in enumerate(run_results['scored_results']): + for i, _res in enumerate(run_results["scored_results"]): print(_res) nlls.append(_res[0]) lengths.append(_res[1]) - file_ppl = np.exp(-np.sum(nlls)/np.sum(lengths)) + file_ppl = np.exp(-np.sum(nlls) / np.sum(lengths)) print("wikitext-2 ppl: %.4f" % file_ppl) - + + def evaluate(opt): ArgumentParser.validate_translate_opts(opt) ArgumentParser._get_all_transform_translate(opt) @@ -32,9 +34,11 @@ def evaluate(opt): dir_name = os.path.dirname(opt.models[0]) base_name = os.path.basename(opt.models[0]) - output_filename = os.path.join(dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3]) + output_filename = os.path.join( + dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3] + ) - opt.src = 'wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw' + opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw" # Build the translator (along with the model) engine = InferenceEnginePY(opt) @@ -45,7 +49,7 @@ def evaluate(opt): print(scored_results) engine.terminate() - run_results['scored_results'] = scored_results + run_results["scored_results"] = scored_results with open(output_filename, "w") as f: json.dump(run_results, f, ensure_ascii=False, indent=2) @@ -55,6 +59,7 @@ def evaluate(opt): end_time = time.time() logger.info("total run time %.2f" % (end_time - start_time)) + def _get_parser(): parser = ArgumentParser(description="run_wikitext-2_benchmark.py") opts.config_opts(parser) @@ -70,6 +75,3 @@ def main(): if __name__ == "__main__": main() - -# python3 score_file.py -inference_config_file llama2_inference.yaml - # -file data/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw.10 \ No newline at end of file