Skip to content

Commit

Permalink
reformatted run_wikitext-2_benchmark.py
Browse files Browse the repository at this point in the history
  • Loading branch information
l-k-11235 committed Jan 3, 2024
1 parent 32aa17f commit 9bb5909
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed


def compute_file_ppl(output_filename):
with open(output_filename, "r") as f:
run_results = json.load(f)
nlls = []
lengths = []
for i, _res in enumerate(run_results['scored_results']):
for i, _res in enumerate(run_results["scored_results"]):
print(_res)
nlls.append(_res[0])
lengths.append(_res[1])
file_ppl = np.exp(-np.sum(nlls)/np.sum(lengths))
file_ppl = np.exp(-np.sum(nlls) / np.sum(lengths))
print("wikitext-2 ppl: %.4f" % file_ppl)



def evaluate(opt):
ArgumentParser.validate_translate_opts(opt)
ArgumentParser._get_all_transform_translate(opt)
Expand All @@ -32,9 +34,11 @@ def evaluate(opt):
dir_name = os.path.dirname(opt.models[0])
base_name = os.path.basename(opt.models[0])

output_filename = os.path.join(dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3])
output_filename = os.path.join(
dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3]
)

opt.src = 'wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw'
opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"

# Build the translator (along with the model)
engine = InferenceEnginePY(opt)
Expand All @@ -45,7 +49,7 @@ def evaluate(opt):
print(scored_results)
engine.terminate()

run_results['scored_results'] = scored_results
run_results["scored_results"] = scored_results

with open(output_filename, "w") as f:
json.dump(run_results, f, ensure_ascii=False, indent=2)
Expand All @@ -55,6 +59,7 @@ def evaluate(opt):
end_time = time.time()
logger.info("total run time %.2f" % (end_time - start_time))


def _get_parser():
parser = ArgumentParser(description="run_wikitext-2_benchmark.py")
opts.config_opts(parser)
Expand All @@ -70,6 +75,3 @@ def main():

if __name__ == "__main__":
main()

# python3 score_file.py -inference_config_file llama2_inference.yaml
# -file data/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw.10

0 comments on commit 9bb5909

Please sign in to comment.