From 91c30319ad125b90f55ab62aba0bc10d19cd8eea Mon Sep 17 00:00:00 2001
From: mehrad
Date: Wed, 2 Mar 2022 12:52:49 -0800
Subject: [PATCH] generation_output --> validation_output

---
 genienlp/metrics.py     |  8 ++++----
 genienlp/models/base.py | 35 +++++++++++++++++++++++++++++++----
 genienlp/predict.py     | 40 ++++++++++++++++++++--------------------
 genienlp/train.py       | 12 ++++++------
 genienlp/util.py        | 28 +---------------------------
 5 files changed, 62 insertions(+), 61 deletions(-)

diff --git a/genienlp/metrics.py b/genienlp/metrics.py
index 17fba9bb..b8680e78 100644
--- a/genienlp/metrics.py
+++ b/genienlp/metrics.py
@@ -378,11 +378,11 @@ def compute_metrics(
     return metric_dict
 
 
-def calculate_and_reduce_metrics(args, generation_output, metrics_to_compute, lang):
+def calculate_and_reduce_metrics(args, validation_output, metrics_to_compute, lang):
     metrics = OrderedDict()
-    example_ids = generation_output.example_ids
-    predictions = generation_output.predictions
-    answers = generation_output.answers
+    example_ids = validation_output.example_ids
+    predictions = validation_output.predictions
+    answers = validation_output.answers
 
     if args.reduce_metrics == 'max':
         for i in range(len(predictions[0])):  # for each output (in case of multiple outputs)
diff --git a/genienlp/models/base.py b/genienlp/models/base.py
index d97f662f..e91c0e16 100644
--- a/genienlp/models/base.py
+++ b/genienlp/models/base.py
@@ -31,6 +31,7 @@
 import logging
 import os
 from collections import defaultdict
+from typing import List, Optional
 
 import torch
 import ujson
@@ -41,7 +42,7 @@
 from ..data_utils.example import NumericalizedExamples, SequentialField
 from ..data_utils.numericalizer import TransformerNumericalizer
 from ..data_utils.progbar import progress_bar
-from ..util import GenerationOutput, adjust_language_code, merge_translated_sentences, replace_capturing_group
+from ..util import adjust_language_code, merge_translated_sentences, replace_capturing_group
 
 logger = logging.getLogger(__name__)
 
@@ -96,6 +97,32 @@ def set_generation_output_options(self, tasks):
         self._output_hidden_states = False
 
 
+class ValidationOutput(object):
+    """
+    Contains all the information that model's validate() method may output
+    """
+
+    def __init__(
+        self,
+        loss: Optional[float] = None,
+        example_ids: Optional[List] = None,
+        predictions: Optional[List] = None,
+        raw_predictions: Optional[List] = None,
+        answers: Optional[List] = None,
+        contexts: Optional[List] = None,
+        confidence_features: Optional[List] = None,
+        confidence_scores: Optional[List] = None,
+    ):
+        self.loss = loss
+        self.example_ids = example_ids
+        self.predictions = predictions
+        self.raw_predictions = raw_predictions
+        self.answers = answers
+        self.contexts = contexts
+        self.confidence_features = confidence_features
+        self.confidence_scores = confidence_scores
+
+
 # TransformerSeq2Seq and TransformerLSTM will inherit from this model
 class GenieModelForGeneration(GenieModel):
     def validate(
@@ -324,7 +351,7 @@ def get_example_index(i):
                 self.numericalizer._tokenizer.tgt_lang,
             )
 
-        output = GenerationOutput(loss=total_loss)
+        output = ValidationOutput(loss=total_loss)
 
         if output_predictions_only:
             output.predictions = predictions
@@ -576,7 +603,7 @@ def validate_e2e_dialogues(
         # TODO calculate and return loss
         loss = None
 
-        output = GenerationOutput(loss=loss)
+        output = ValidationOutput(loss=loss)
 
         if output_predictions_only:
             output.predictions = predictions
@@ -696,7 +723,7 @@ def validate(self, data_iterator, task, original_order=None, disable_progbar=Tru
             )
         ]
 
-        output = GenerationOutput(
+        output = ValidationOutput(
             loss=total_loss,
             example_ids=all_example_ids,
             contexts=all_contexts,
diff --git a/genienlp/predict.py b/genienlp/predict.py
index 5d8e7254..0793da34 100644
--- a/genienlp/predict.py
+++ b/genienlp/predict.py
@@ -379,16 +379,16 @@ def prepare_data_iterators(args, val_sets, numericalizer, device):
     return iters
 
 
-def create_output_lines(args, index, generation_output):
-    predictions = generation_output.raw_predictions if args.translate_return_raw_outputs else generation_output.predictions
+def create_output_lines(args, index, validation_output):
+    predictions = validation_output.raw_predictions if args.translate_return_raw_outputs else validation_output.predictions
     if args.one_output_per_line:
         lines = [
             '\t'.join(
                 [
-                    generation_output.example_ids[index],
+                    validation_output.example_ids[index],
                     prediction,
-                    generation_output.answers[index],
-                    generation_output.contexts[index],
+                    validation_output.answers[index],
+                    validation_output.contexts[index],
                 ]
             )
             for prediction in predictions[index]
@@ -397,15 +397,15 @@ def create_output_lines(args, index, generation_output):
         lines = [
             '\t'.join(
                 [
-                    generation_output.example_ids[index],
+                    validation_output.example_ids[index],
                     *predictions[index],
-                    generation_output.answers[index],
-                    generation_output.contexts[index],
+                    validation_output.answers[index],
+                    validation_output.contexts[index],
                 ]
             )
         ]  # one line with all generation outputs separated by '\t'
     if args.calibrator_paths is not None:
-        for score in generation_output.confidence_scores:
+        for score in validation_output.confidence_scores:
             lines = [line + '\t' + str(score[index]) for line in lines]  # append score to all lines
     return lines
 
@@ -471,7 +471,7 @@ def run(args, device):
             confidence_estimators = None
 
         with torch.no_grad(), torch.cuda.amp.autocast(enabled=args.mixed_precision):
-            generation_output = model.validate(
+            validation_output = model.validate(
                 it,
                 task,
                 eval_dir=eval_dir,
@@ -482,45 +482,45 @@ def run(args, device):
             )
 
         if args.save_confidence_features:
-            torch.save(generation_output.confidence_features, args.confidence_feature_path)
+            torch.save(validation_output.confidence_features, args.confidence_feature_path)
 
         # write into file
         # TODO change to jsonl format
         with open(prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file:
-            for i in range(len(generation_output.example_ids)):
-                lines = create_output_lines(args, i, generation_output)
+            for i in range(len(validation_output.example_ids)):
+                lines = create_output_lines(args, i, validation_output)
                 prediction_file.write('\n'.join(lines) + '\n')
 
         if args.translate_return_raw_outputs:
             with open(raw_prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file:
-                for i in range(len(generation_output.example_ids)):
-                    lines = create_output_lines(args, i, generation_output)
+                for i in range(len(validation_output.example_ids)):
+                    lines = create_output_lines(args, i, validation_output)
                     prediction_file.write('\n'.join(lines) + '\n')
 
-        if len(generation_output.answers) > 0:
+        if len(validation_output.answers) > 0:
             metrics_to_compute = get_metrics_to_compute(args, task)
-            metrics = calculate_and_reduce_metrics(args, generation_output, metrics_to_compute, tgt_lang)
+            metrics = calculate_and_reduce_metrics(args, validation_output, metrics_to_compute, tgt_lang)
 
             with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file:
                 results_file.write(json.dumps(metrics) + '\n')
 
             if not args.silent:
                 for i, (c, p, a) in enumerate(
-                    zip(generation_output.contexts, generation_output.predictions, generation_output.answers)
+                    zip(validation_output.contexts, validation_output.predictions, validation_output.answers)
                 ):
                     log_string = '\n'.join(
                         [f'Context {i + 1}: {c}', f'Prediction {i + 1} ({len(p)} outputs): {p}', f'Answer {i + 1}: {a}']
                     )
                     if args.calibrator_paths is not None:
                         log_string += f'Confidence {i + 1} : '
-                        for score in generation_output.confidence_scores:
+                        for score in validation_output.confidence_scores:
                             log_string += f'{score[i]:.3f}, '
                         log_string += '\n'
                     logger.info(log_string)
 
             logger.info(metrics)
 
-            task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]]))
+            task_scores[task].append((len(validation_output.answers), metrics[task.metrics[0]]))
 
     decaScore = []
     for task in task_scores.keys():
diff --git a/genienlp/train.py b/genienlp/train.py
index a465f9b2..d0c6f5ac 100644
--- a/genienlp/train.py
+++ b/genienlp/train.py
@@ -229,22 +229,22 @@ def validate_while_training(task, val_iter, model, args, num_print=10):
         # get rid of the DataParallel wrapper
         model = model.module
 
-    generation_output = model.validate(val_iter, task)
+    validation_output = model.validate(val_iter, task)
 
     # loss is already calculated
     metrics_to_return = [metric for metric in task.metrics if metric != 'loss']
 
-    metrics = calculate_and_reduce_metrics(args, generation_output, metrics_to_return, model.tgt_lang)
+    metrics = calculate_and_reduce_metrics(args, validation_output, metrics_to_return, model.tgt_lang)
 
     results = {
-        'model prediction': generation_output.predictions,
-        'gold answer': generation_output.answers,
-        'context': generation_output.contexts,
+        'model prediction': validation_output.predictions,
+        'gold answer': validation_output.answers,
+        'context': validation_output.contexts,
     }
 
     print_results(results, num_print)
 
-    return generation_output, metrics
+    return validation_output, metrics
 
 
 def do_validate_while_training(
diff --git a/genienlp/util.py b/genienlp/util.py
index 551ae0eb..9dafe694 100644
--- a/genienlp/util.py
+++ b/genienlp/util.py
@@ -37,7 +37,7 @@
 import sys
 import time
 from json.decoder import JSONDecodeError
-from typing import List, Optional
+from typing import List
 
 import numpy as np
 import torch
@@ -233,32 +233,6 @@ def __repr__(self) -> str:
         )
 
 
-class GenerationOutput(object):
-    """
-    Contains all the information that the generation function may need to output
-    """
-
-    def __init__(
-        self,
-        loss: Optional[float] = None,
-        example_ids: Optional[List] = None,
-        predictions: Optional[List] = None,
-        raw_predictions: Optional[List] = None,
-        answers: Optional[List] = None,
-        contexts: Optional[List] = None,
-        confidence_features: Optional[List] = None,
-        confidence_scores: Optional[List] = None,
-    ):
-        self.loss = loss
-        self.example_ids = example_ids
-        self.predictions = predictions
-        self.raw_predictions = raw_predictions
-        self.answers = answers
-        self.contexts = contexts
-        self.confidence_features = confidence_features
-        self.confidence_scores = confidence_scores
-
-
 def remove_thingtalk_quotes(thingtalk):
     quote_values = []
     while True:
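Note (appended for reference, not part of the commit): the patch moves the output container from genienlp/util.py to genienlp/models/base.py and renames it, but its shape and the construct-then-attach pattern in every validate() implementation are unchanged: the object is built with just the loss, and the optional fields are assigned afterwards. The standalone sketch below illustrates that pattern; the class body is copied verbatim from the genienlp/models/base.py hunk above, while the usage values are invented for illustration.

# Standalone sketch: the ValidationOutput class as added above, plus the
# construct-then-attach usage pattern that validate() keeps using.
from typing import List, Optional


class ValidationOutput(object):
    """
    Contains all the information that model's validate() method may output
    """

    def __init__(
        self,
        loss: Optional[float] = None,
        example_ids: Optional[List] = None,
        predictions: Optional[List] = None,
        raw_predictions: Optional[List] = None,
        answers: Optional[List] = None,
        contexts: Optional[List] = None,
        confidence_features: Optional[List] = None,
        confidence_scores: Optional[List] = None,
    ):
        self.loss = loss
        self.example_ids = example_ids
        self.predictions = predictions
        self.raw_predictions = raw_predictions
        self.answers = answers
        self.contexts = contexts
        self.confidence_features = confidence_features
        self.confidence_scores = confidence_scores


# Mirrors validate(): construct with the loss, then attach optional fields.
output = ValidationOutput(loss=0.42)          # invented loss value
output.example_ids = ['dev/0']                # invented example id
output.predictions = [['predicted output']]   # one list of outputs per example
output.answers = ['gold answer']
print(output.loss, output.predictions[0])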