From a34ca3bf709645ede25d7ff711ea19e04c727c69 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Mon, 28 Feb 2022 11:59:39 -0800
Subject: [PATCH] predict: move loop outside of create_output_lines

---
 genienlp/predict.py | 63 ++++++++++++++++++++++++---------------------
 1 file changed, 34 insertions(+), 29 deletions(-)

diff --git a/genienlp/predict.py b/genienlp/predict.py
index faeaf919e..4325c6f23 100644
--- a/genienlp/predict.py
+++ b/genienlp/predict.py
@@ -383,31 +383,34 @@ def prepare_data_iterators(args, val_sets, numericalizer, device):
     return iters
 
 
-def create_output_line(args, generation_output):
-    lines = []
-    for i in range(len(generation_output.example_ids)):
-        predictions = generation_output.raw_predictions if args.translate_return_raw_outputs else generation_output.predictions
-        if args.one_output_per_line:
-            lines = [
-                '\t'.join(
-                    [generation_output.example_ids[i], prediction, generation_output.answers[i], generation_output.contexts[i]]
-                )
-                for prediction in predictions[i]
-            ]  # one line per generation output
-        else:
-            lines = [
-                '\t'.join(
-                    [
-                        generation_output.example_ids[i],
-                        *predictions[i],
-                        generation_output.answers[i],
-                        generation_output.contexts[i],
-                    ]
-                )
-            ]  # one line with all generation outputs separated by '\t'
-        if args.calibrator_paths is not None:
-            for score in generation_output.confidence_scores:
-                lines = [line + '\t' + str(score[i]) for line in lines]  # append score to all lines
+def create_output_lines(args, index, generation_output):
+    predictions = generation_output.raw_predictions if args.translate_return_raw_outputs else generation_output.predictions
+    if args.one_output_per_line:
+        lines = [
+            '\t'.join(
+                [
+                    generation_output.example_ids[index],
+                    prediction,
+                    generation_output.answers[index],
+                    generation_output.contexts[index],
+                ]
+            )
+            for prediction in predictions[index]
+        ]  # one line per generation output
+    else:
+        lines = [
+            '\t'.join(
+                [
+                    generation_output.example_ids[index],
+                    *predictions[index],
+                    generation_output.answers[index],
+                    generation_output.contexts[index],
+                ]
+            )
+        ]  # one line with all generation outputs separated by '\t'
+    if args.calibrator_paths is not None:
+        for score in generation_output.confidence_scores:
+            lines = [line + '\t' + str(score[index]) for line in lines]  # append score to all lines
     return lines
 
 
@@ -490,13 +493,15 @@ def run(args, device):
         # write into file
         # TODO change to jsonl format
         with open(prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file:
-            lines = create_output_line(args, generation_output)
-            prediction_file.write('\n'.join(lines) + '\n')
+            for i in range(len(generation_output.example_ids)):
+                lines = create_output_lines(args, i, generation_output)
+                prediction_file.write('\n'.join(lines) + '\n')
 
         if args.translate_return_raw_outputs:
             with open(raw_prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file:
-                lines = create_output_line(args, generation_output)
-                prediction_file.write('\n'.join(lines) + '\n')
+                for i in range(len(generation_output.example_ids)):
+                    lines = create_output_lines(args, i, generation_output)
+                    prediction_file.write('\n'.join(lines) + '\n')
 
         if len(generation_output.answers) > 0:
             metrics_to_compute = get_metrics_to_compute(args, task)
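
For context, a minimal sketch of the calling convention this patch introduces, assuming the patched genienlp is importable: the per-example loop now lives at the call site, and create_output_lines formats a single example identified by its index. The SimpleNamespace objects below are hypothetical stand-ins that carry only the attributes the function reads, not the real genienlp args and generation-output objects.

from types import SimpleNamespace

from genienlp.predict import create_output_lines

# Hypothetical stand-ins: only the attributes create_output_lines reads are filled in.
args = SimpleNamespace(
    translate_return_raw_outputs=False,
    one_output_per_line=False,
    calibrator_paths=None,
)
generation_output = SimpleNamespace(
    example_ids=['ex-0', 'ex-1'],
    predictions=[['pred-0a', 'pred-0b'], ['pred-1a']],
    raw_predictions=None,
    answers=['ans-0', 'ans-1'],
    contexts=['ctx-0', 'ctx-1'],
    confidence_scores=[],
)

# The loop is now owned by the caller, mirroring the change in run():
# one call per example index, each returning that example's tab-separated line(s).
for i in range(len(generation_output.example_ids)):
    lines = create_output_lines(args, i, generation_output)
    print('\n'.join(lines))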