Skip to content

Commit

Permalink
generation_output --> validation_output
Browse files Browse the repository at this point in the history
  • Loading branch information
Mehrad0711 committed Mar 2, 2022
1 parent e2cad6b commit 91c3031
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 61 deletions.
8 changes: 4 additions & 4 deletions genienlp/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,11 @@ def compute_metrics(
return metric_dict


def calculate_and_reduce_metrics(args, generation_output, metrics_to_compute, lang):
def calculate_and_reduce_metrics(args, validation_output, metrics_to_compute, lang):
metrics = OrderedDict()
example_ids = generation_output.example_ids
predictions = generation_output.predictions
answers = generation_output.answers
example_ids = validation_output.example_ids
predictions = validation_output.predictions
answers = validation_output.answers

if args.reduce_metrics == 'max':
for i in range(len(predictions[0])): # for each output (in case of multiple outputs)
Expand Down
35 changes: 31 additions & 4 deletions genienlp/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import logging
import os
from collections import defaultdict
from typing import List, Optional

import torch
import ujson
Expand All @@ -41,7 +42,7 @@
from ..data_utils.example import NumericalizedExamples, SequentialField
from ..data_utils.numericalizer import TransformerNumericalizer
from ..data_utils.progbar import progress_bar
from ..util import GenerationOutput, adjust_language_code, merge_translated_sentences, replace_capturing_group
from ..util import adjust_language_code, merge_translated_sentences, replace_capturing_group

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,6 +97,32 @@ def set_generation_output_options(self, tasks):
self._output_hidden_states = False


class ValidationOutput(object):
"""
Contains all the information that model's validate() method may output
"""

def __init__(
self,
loss: Optional[float] = None,
example_ids: Optional[List] = None,
predictions: Optional[List] = None,
raw_predictions: Optional[List] = None,
answers: Optional[List] = None,
contexts: Optional[List] = None,
confidence_features: Optional[List] = None,
confidence_scores: Optional[List] = None,
):
self.loss = loss
self.example_ids = example_ids
self.predictions = predictions
self.raw_predictions = raw_predictions
self.answers = answers
self.contexts = contexts
self.confidence_features = confidence_features
self.confidence_scores = confidence_scores


# TransformerSeq2Seq and TransformerLSTM will inherit from this model
class GenieModelForGeneration(GenieModel):
def validate(
Expand Down Expand Up @@ -324,7 +351,7 @@ def get_example_index(i):
self.numericalizer._tokenizer.tgt_lang,
)

output = GenerationOutput(loss=total_loss)
output = ValidationOutput(loss=total_loss)

if output_predictions_only:
output.predictions = predictions
Expand Down Expand Up @@ -576,7 +603,7 @@ def validate_e2e_dialogues(

# TODO calculate and return loss
loss = None
output = GenerationOutput(loss=loss)
output = ValidationOutput(loss=loss)

if output_predictions_only:
output.predictions = predictions
Expand Down Expand Up @@ -696,7 +723,7 @@ def validate(self, data_iterator, task, original_order=None, disable_progbar=Tru
)
]

output = GenerationOutput(
output = ValidationOutput(
loss=total_loss,
example_ids=all_example_ids,
contexts=all_contexts,
Expand Down
40 changes: 20 additions & 20 deletions genienlp/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,16 +379,16 @@ def prepare_data_iterators(args, val_sets, numericalizer, device):
return iters


def create_output_lines(args, index, generation_output):
predictions = generation_output.raw_predictions if args.translate_return_raw_outputs else generation_output.predictions
def create_output_lines(args, index, validation_output):
predictions = validation_output.raw_predictions if args.translate_return_raw_outputs else validation_output.predictions
if args.one_output_per_line:
lines = [
'\t'.join(
[
generation_output.example_ids[index],
validation_output.example_ids[index],
prediction,
generation_output.answers[index],
generation_output.contexts[index],
validation_output.answers[index],
validation_output.contexts[index],
]
)
for prediction in predictions[index]
Expand All @@ -397,15 +397,15 @@ def create_output_lines(args, index, generation_output):
lines = [
'\t'.join(
[
generation_output.example_ids[index],
validation_output.example_ids[index],
*predictions[index],
generation_output.answers[index],
generation_output.contexts[index],
validation_output.answers[index],
validation_output.contexts[index],
]
)
] # one line with all generation outputs separated by '\t'
if args.calibrator_paths is not None:
for score in generation_output.confidence_scores:
for score in validation_output.confidence_scores:
lines = [line + '\t' + str(score[index]) for line in lines] # append score to all lines
return lines

Expand Down Expand Up @@ -471,7 +471,7 @@ def run(args, device):
confidence_estimators = None

with torch.no_grad(), torch.cuda.amp.autocast(enabled=args.mixed_precision):
generation_output = model.validate(
validation_output = model.validate(
it,
task,
eval_dir=eval_dir,
Expand All @@ -482,45 +482,45 @@ def run(args, device):
)

if args.save_confidence_features:
torch.save(generation_output.confidence_features, args.confidence_feature_path)
torch.save(validation_output.confidence_features, args.confidence_feature_path)

# write into file
# TODO change to jsonl format
with open(prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file:
for i in range(len(generation_output.example_ids)):
lines = create_output_lines(args, i, generation_output)
for i in range(len(validation_output.example_ids)):
lines = create_output_lines(args, i, validation_output)
prediction_file.write('\n'.join(lines) + '\n')

if args.translate_return_raw_outputs:
with open(raw_prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file:
for i in range(len(generation_output.example_ids)):
lines = create_output_lines(args, i, generation_output)
for i in range(len(validation_output.example_ids)):
lines = create_output_lines(args, i, validation_output)
prediction_file.write('\n'.join(lines) + '\n')

if len(generation_output.answers) > 0:
if len(validation_output.answers) > 0:
metrics_to_compute = get_metrics_to_compute(args, task)
metrics = calculate_and_reduce_metrics(args, generation_output, metrics_to_compute, tgt_lang)
metrics = calculate_and_reduce_metrics(args, validation_output, metrics_to_compute, tgt_lang)

with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file:
results_file.write(json.dumps(metrics) + '\n')

if not args.silent:
for i, (c, p, a) in enumerate(
zip(generation_output.contexts, generation_output.predictions, generation_output.answers)
zip(validation_output.contexts, validation_output.predictions, validation_output.answers)
):
log_string = '\n'.join(
[f'Context {i + 1}: {c}', f'Prediction {i + 1} ({len(p)} outputs): {p}', f'Answer {i + 1}: {a}']
)
if args.calibrator_paths is not None:
log_string += f'Confidence {i + 1} : '
for score in generation_output.confidence_scores:
for score in validation_output.confidence_scores:
log_string += f'{score[i]:.3f}, '
log_string += '\n'
logger.info(log_string)

logger.info(metrics)

task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]]))
task_scores[task].append((len(validation_output.answers), metrics[task.metrics[0]]))

decaScore = []
for task in task_scores.keys():
Expand Down
12 changes: 6 additions & 6 deletions genienlp/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,22 @@ def validate_while_training(task, val_iter, model, args, num_print=10):
# get rid of the DataParallel wrapper
model = model.module

generation_output = model.validate(val_iter, task)
validation_output = model.validate(val_iter, task)

# loss is already calculated
metrics_to_return = [metric for metric in task.metrics if metric != 'loss']

metrics = calculate_and_reduce_metrics(args, generation_output, metrics_to_return, model.tgt_lang)
metrics = calculate_and_reduce_metrics(args, validation_output, metrics_to_return, model.tgt_lang)

results = {
'model prediction': generation_output.predictions,
'gold answer': generation_output.answers,
'context': generation_output.contexts,
'model prediction': validation_output.predictions,
'gold answer': validation_output.answers,
'context': validation_output.contexts,
}

print_results(results, num_print)

return generation_output, metrics
return validation_output, metrics


def do_validate_while_training(
Expand Down
28 changes: 1 addition & 27 deletions genienlp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import sys
import time
from json.decoder import JSONDecodeError
from typing import List, Optional
from typing import List

import numpy as np
import torch
Expand Down Expand Up @@ -233,32 +233,6 @@ def __repr__(self) -> str:
)


class GenerationOutput(object):
"""
Contains all the information that the generation function may need to output
"""

def __init__(
self,
loss: Optional[float] = None,
example_ids: Optional[List] = None,
predictions: Optional[List] = None,
raw_predictions: Optional[List] = None,
answers: Optional[List] = None,
contexts: Optional[List] = None,
confidence_features: Optional[List] = None,
confidence_scores: Optional[List] = None,
):
self.loss = loss
self.example_ids = example_ids
self.predictions = predictions
self.raw_predictions = raw_predictions
self.answers = answers
self.contexts = contexts
self.confidence_features = confidence_features
self.confidence_scores = confidence_scores


def remove_thingtalk_quotes(thingtalk):
quote_values = []
while True:
Expand Down

0 comments on commit 91c3031

Please sign in to comment.