From e7db2543270e343d0f248335e21b08fd24ec845a Mon Sep 17 00:00:00 2001
From: mehrad
Date: Fri, 25 Feb 2022 14:12:23 -0800
Subject: [PATCH] predict: remove code for file evaluation

---
 genienlp/predict.py | 108 ++++++++++++--------------------------------
 1 file changed, 30 insertions(+), 78 deletions(-)

diff --git a/genienlp/predict.py b/genienlp/predict.py
index 2471ca4f..c92a8213 100644
--- a/genienlp/predict.py
+++ b/genienlp/predict.py
@@ -43,7 +43,6 @@
 except RuntimeError:
     pass
 
-import sys
 
 import torch
 
@@ -63,17 +62,17 @@
     set_seed,
     split_folder_on_disk,
 )
-from .validate import GenerationOutput, generate_with_model
+from .validate import generate_with_model
 
 logger = logging.getLogger(__name__)
 
 
 def parse_argv(parser):
-    parser.add_argument('--path', type=str, required='--pred_file' not in sys.argv, help='Folder to load the model from')
+    parser.add_argument('--path', type=str, required=True, help='Folder to load the model from')
     parser.add_argument(
         '--evaluate',
         type=str,
-        required='--pred_file' not in sys.argv,
+        required=True,
         choices=['train', 'valid', 'test'],
         help='Which dataset to do predictions for (train, dev or test)',
     )
@@ -106,12 +105,6 @@ def parse_argv(parser):
     parser.add_argument('--cache', default='.cache', type=str, help='where to save cached files')
     parser.add_argument('--subsample', default=20000000, type=int, help='subsample the eval/test datasets')
 
-    parser.add_argument(
-        '--pred_file',
-        type=str,
-        help='If provided, we just compute evaluation metrics on this file and bypass model prediction. File should be in tsv format with id, pred, answer columns',
-    )
-
     parser.add_argument(
         '--pred_languages',
         type=str,
@@ -564,20 +557,39 @@ def run(args, device):
                 prediction_file.write('\n'.join(lines) + '\n')
 
             if len(generation_output.answers) > 0:
-                compute_metrics_on_file(
-                    task_scores,
-                    prediction_file_name,
-                    results_file_name,
-                    task,
+                metrics_to_compute = task.metrics
+                metrics_to_compute += args.extra_metrics
+                metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']]
+                if args.main_metric_only:
+                    metrics_to_compute = [metrics_to_compute[0]]
+                metrics = calculate_and_reduce_metrics(
+                    generation_output,
+                    metrics_to_compute,
                     args,
                     tgt_lang,
-                    confidence_scores=generation_output.confidence_scores,
                 )
-                log_final_results(args, task_scores)
 
+                with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file:
+                    results_file.write(json.dumps(metrics) + '\n')
+
+                if not args.silent:
+                    for i, (c, p, a) in enumerate(
+                        zip(generation_output.contexts, generation_output.predictions, generation_output.answers)
+                    ):
+                        log_string = (
+                            f'\nContext {i + 1}: {c}\nPrediction {i + 1} ({len(p)} outputs): {p}\nAnswer {i + 1}: {a}\n'
+                        )
+                        if args.calibrator_paths is not None:
+                            log_string += f'Confidence {i + 1} : '
+                            for score in generation_output.confidence_scores:
+                                log_string += f'{score[i]:.3f}, '
+                            log_string += '\n'
+                        logger.info(log_string)
+
+                logger.info(metrics)
+
+                task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]]))
 
 
-def log_final_results(args, task_scores):
     decaScore = []
     for task in task_scores.keys():
         decaScore.append(
@@ -592,55 +604,6 @@
     logger.info(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n')
 
 
-def compute_metrics_on_file(task_scores, pred_file, results_file_name, task, args, tgt_lang, confidence_scores=None):
-    generation_output = GenerationOutput()
-    ids, contexts, preds, targets = [], [], [], []
-    with open(pred_file) as fin:
-        for line in fin:
-            id_, *pred, target, context = line.strip('\n').split('\t')
-            ids.append(id_)
-            contexts.append(context)
-            preds.append(pred)
-            targets.append(target)
-
-    generation_output.example_ids = ids
-    generation_output.contexts = contexts
-    generation_output.predictions = preds
-    generation_output.answers = targets
-    generation_output.confidence_scores = confidence_scores
-
-    metrics_to_compute = task.metrics
-    metrics_to_compute += args.extra_metrics
-    metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']]
-    if args.main_metric_only:
-        metrics_to_compute = [metrics_to_compute[0]]
-    metrics = calculate_and_reduce_metrics(
-        generation_output,
-        metrics_to_compute,
-        args,
-        tgt_lang,
-    )
-
-    with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file:
-        results_file.write(json.dumps(metrics) + '\n')
-
-    if not args.silent:
-        for i, (c, p, a) in enumerate(
-            zip(generation_output.contexts, generation_output.predictions, generation_output.answers)
-        ):
-            log_string = f'\nContext {i + 1}: {c}\nPrediction {i + 1} ({len(p)} outputs): {p}\nAnswer {i + 1}: {a}\n'
-            if args.calibrator_paths is not None:
-                log_string += f'Confidence {i + 1} : '
-                for score in generation_output.confidence_scores:
-                    log_string += f'{score[i]:.3f}, '
-                log_string += '\n'
-            logger.info(log_string)
-
-    logger.info(metrics)
-
-    task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]]))
-
-
 def main(args):
     load_config_json(args)
     check_and_update_generation_args(args)
@@ -667,17 +630,6 @@ def main(args):
 
             task.metrics = new_metrics
 
-    if args.pred_file and os.path.exists(args.pred_file):
-        task_scores = defaultdict(list)
-        eval_dir = os.path.join(args.eval_dir, args.evaluate)
-        os.makedirs(eval_dir, exist_ok=True)
-        tgt_lang = args.pred_tgt_languages[0]
-        for task in args.tasks:
-            results_file_name = os.path.join(eval_dir, task.name + '.results.json')
-            compute_metrics_on_file(task_scores, args.pred_file, results_file_name, task, args, tgt_lang)
-        log_final_results(args, task_scores)
-        return
-
     logger.info(f'Loading from {args.best_checkpoint}')
 
     devices = get_devices(args.devices)