diff --git a/genienlp/metrics.py b/genienlp/metrics.py
index a186c8b3..17fba9bb 100644
--- a/genienlp/metrics.py
+++ b/genienlp/metrics.py
@@ -287,7 +287,7 @@ def compute_metrics(
         requested_metrics: contains a subset of the following metrics
             em (exact match)
             sm (structure match): valid if the output is ThingTalk code. Whether the gold answer and prediction are identical if we ignore parameter values of ThingTalk programs
-            #TODO add all
+            # TODO add all
         lang: the language of the predictions and answers. Used for BERTScore.
         args: arguments
         example_ids: used to calculate some of e2e dialogue metrics that need to know span of each dialogue such as JGA
diff --git a/genienlp/models/base.py b/genienlp/models/base.py
index 2f4ee597..d97f662f 100644
--- a/genienlp/models/base.py
+++ b/genienlp/models/base.py
@@ -99,6 +99,33 @@ def set_generation_output_options(self, tasks):
 
 # TransformerSeq2Seq and TransformerLSTM will inherit from this model
 class GenieModelForGeneration(GenieModel):
     def validate(
+        self,
+        data_iterator,
+        task,
+        eval_dir=None,
+        output_predictions_only=False,
+        output_confidence_features=False,
+        original_order=None,
+        confidence_estimators=None,
+        disable_progbar=True,
+        **kwargs,
+    ):
+        if self.args.e2e_dialogue_evaluation:
+            return self.validate_e2e_dialogues(
+                data_iterator, task, eval_dir, output_predictions_only, original_order, disable_progbar
+            )
+        else:
+            return self.validate_batch(
+                data_iterator,
+                task,
+                output_predictions_only,
+                output_confidence_features,
+                original_order,
+                confidence_estimators,
+                disable_progbar,
+            )
+
+    def validate_batch(
         self,
         data_iterator,
         task,
diff --git a/genienlp/predict.py b/genienlp/predict.py
index c372770b..5d8e7254 100644
--- a/genienlp/predict.py
+++ b/genienlp/predict.py
@@ -62,7 +62,6 @@
     set_seed,
     split_folder_on_disk,
 )
-from .validate import generate_with_model
 
 logger = logging.getLogger(__name__)
 
@@ -472,16 +471,14 @@ def run(args, device):
         confidence_estimators = None
 
     with torch.no_grad(), torch.cuda.amp.autocast(enabled=args.mixed_precision):
-        generation_output = generate_with_model(
-            model,
+        generation_output = model.validate(
             it,
             task,
-            args,
-            original_order=original_order,
+            eval_dir=eval_dir,
             output_confidence_features=args.save_confidence_features,
+            original_order=original_order,
             confidence_estimators=confidence_estimators,
             disable_progbar=False,
-            eval_dir=eval_dir,
         )
 
     if args.save_confidence_features:
diff --git a/genienlp/server.py b/genienlp/server.py
index a348a0e7..d1c679fc 100644
--- a/genienlp/server.py
+++ b/genienlp/server.py
@@ -46,7 +46,6 @@
 from .ned.ned_utils import init_ned_model
 from .tasks.registry import get_tasks
 from .util import adjust_language_code, get_devices, load_config_json, log_model_size, set_seed
-from .validate import generate_with_model
 
 logger = logging.getLogger(__name__)
 
@@ -213,11 +212,9 @@ def _numericalize_request(self, request, task, args):
 
     def _predict_batch(self, batch, task, args):
         if args.calibrator_paths is not None:
-            output = generate_with_model(
-                self.model,
+            output = self.model.validate(
                 [batch],
                 task,
-                args,
                 output_predictions_only=True,
                 confidence_estimators=self.confidence_estimators,
             )
@@ -238,7 +235,11 @@
                     instance['score'][self.estimator_filenames[e_idx]] = float(estimator_scores[idx])
                 response.append(instance)
         else:
-            output = generate_with_model(self.model, [batch], task, args, output_predictions_only=True)
+            output = self.model.validate(
+                [batch],
+                task,
+                output_predictions_only=True,
+            )
             if sum(args.num_outputs) > 1:
                 response = []
                 for idx, predictions in enumerate(output.predictions):
diff --git a/genienlp/train.py b/genienlp/train.py
index c71ac23b..2fbfe9f1 100644
--- a/genienlp/train.py
+++ b/genienlp/train.py
@@ -43,6 +43,7 @@
 
 from . import arguments, models
 from .arguments import save_args
+from .metrics import calculate_and_reduce_metrics
 from .model_utils.optimizer import init_opt
 from .model_utils.parallel_utils import NamedTupleCompatibleDataParallel
 from .model_utils.saver import Saver
@@ -54,9 +55,9 @@
     log_model_size,
     make_data_loader,
     ned_dump_entity_type_pairs,
+    print_results,
     set_seed,
 )
-from .validate import print_results, validate
 
 
 def initialize_logger(args):
@@ -221,6 +222,31 @@ def should_log(iteration, log_every):
     return iteration % log_every == 0
 
 
+def validate(task, val_iter, model, args, num_print=10):
+    with torch.no_grad():
+        model.eval()
+        if isinstance(model, torch.nn.DataParallel):
+            # get rid of the DataParallel wrapper
+            model = model.module
+
+        generation_output = model.validate(val_iter, task)
+
+        # loss is already calculated
+        metrics_to_return = [metric for metric in task.metrics if metric != 'loss']
+
+        metrics = calculate_and_reduce_metrics(args, generation_output, metrics_to_return, model.tgt_lang)
+
+        results = {
+            'model prediction': generation_output.predictions,
+            'gold answer': generation_output.answers,
+            'context': generation_output.contexts,
+        }
+
+        print_results(results, num_print)
+
+        return generation_output, metrics
+
+
 def do_validate(iteration, args, model, val_iters, *, train_task, round_progress, task_progress, writer, logger):
     deca_score = 0
     for val_task_idx, (val_task, val_iter) in enumerate(val_iters):
diff --git a/genienlp/util.py b/genienlp/util.py
index bce36453..551ae0eb 100644
--- a/genienlp/util.py
+++ b/genienlp/util.py
@@ -34,6 +34,7 @@
 import random
 import re
 import shutil
+import sys
 import time
 from json.decoder import JSONDecodeError
 from typing import List, Optional
@@ -1013,3 +1014,33 @@ def replace_capturing_group(input, re_pattern, replacement):
     else:
         new_input = input
     return new_input
+
+
+def print_results(results, num_print):
+    print()
+
+    values = list(results.values())
+    num_examples = len(values[0])
+
+    # examples are sorted by length
+    # to get good diversity, get half of examples from second quartile
+    start = int(num_examples / 4)
+    end = start + int(num_print / 2)
+    first_list = [val[start:end] for val in values]
+
+    # and the other half from fourth quartile
+    start = int(3 * num_examples / 4)
+    end = start + num_print - int(num_print / 2)
+    second_list = [val[start:end] for val in values]
+
+    # join examples
+    processed_values = [first + second for first, second in zip(first_list, second_list)]
+
+    for ex_idx in range(len(processed_values[0])):
+        for key_idx, key in enumerate(results.keys()):
+            value = processed_values[key_idx][ex_idx]
+            v = value[0] if isinstance(value, list) else value
+            key_width = max(len(key) for key in results)
+            print(f'{key:>{key_width}}: {repr(v)}')
+        print()
+    sys.stdout.flush()
diff --git a/genienlp/validate.py b/genienlp/validate.py
deleted file mode 100644
index 1ced3b73..00000000
--- a/genienlp/validate.py
+++ /dev/null
@@ -1,125 +0,0 @@
-#
-# Copyright (c) 2018, Salesforce, Inc.
-# The Board of Trustees of the Leland Stanford Junior University
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-#   list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-#   this list of conditions and the following disclaimer in the documentation
-#   and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-#   contributors may be used to endorse or promote products derived from
-#   this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-import logging
-import sys
-
-import torch
-
-from .metrics import calculate_and_reduce_metrics
-
-logger = logging.getLogger(__name__)
-
-
-def generate_with_model(
-    model,
-    data_iterator,
-    task,
-    args,
-    output_predictions_only=False,
-    output_confidence_features=False,
-    original_order=None,
-    confidence_estimators=None,
-    disable_progbar=True,
-    eval_dir=None,
-):
-    if args.e2e_dialogue_evaluation:
-        return model.validate_e2e_dialogues(
-            data_iterator,
-            task,
-            eval_dir=eval_dir,
-            output_predictions_only=output_predictions_only,
-            original_order=original_order,
-            disable_progbar=disable_progbar,
-        )
-    else:
-        return model.validate(
-            data_iterator,
-            task,
-            output_predictions_only=output_predictions_only,
-            output_confidence_features=output_confidence_features,
-            original_order=original_order,
-            confidence_estimators=confidence_estimators,
-            disable_progbar=disable_progbar,
-        )
-
-
-def print_results(results, num_print):
-    print()
-
-    values = list(results.values())
-    num_examples = len(values[0])
-
-    # examples are sorted by length
-    # to get good diversity, get half of examples from second quartile
-    start = int(num_examples / 4)
-    end = start + int(num_print / 2)
-    first_list = [val[start:end] for val in values]
-
-    # and the other half from fourth quartile
-    start = int(3 * num_examples / 4)
-    end = start + num_print - int(num_print / 2)
-    second_list = [val[start:end] for val in values]
-
-    # join examples
-    processed_values = [first + second for first, second in zip(first_list, second_list)]
-
-    for ex_idx in range(len(processed_values[0])):
-        for key_idx, key in enumerate(results.keys()):
-            value = processed_values[key_idx][ex_idx]
-            v = value[0] if isinstance(value, list) else value
-            key_width = max(len(key) for key in results)
-            print(f'{key:>{key_width}}: {repr(v)}')
-        print()
-    sys.stdout.flush()
-
-
-def validate(task, val_iter, model, args, num_print=10):
-    with torch.no_grad():
-        model.eval()
-        if isinstance(model, torch.nn.DataParallel):
-            # get rid of the DataParallel wrapper
-            model = model.module
-
-        generation_output = generate_with_model(model, val_iter, task, args)
-
-        # loss is already calculated
-        metrics_to_return = [metric for metric in task.metrics if metric != 'loss']
-
-        metrics = calculate_and_reduce_metrics(args, generation_output, metrics_to_return, model.tgt_lang)
-
-        results = {
-            'model prediction': generation_output.predictions,
-            'gold answer': generation_output.answers,
-            'context': generation_output.contexts,
-        }
-
-        print_results(results, num_print)
-
-        return generation_output, metrics
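
Reviewer note: a minimal sketch of how call sites change with this refactor, based on the predict.py hunk above. The free function `generate_with_model` is removed; callers now invoke `validate` on the model, which dispatches internally to `validate_e2e_dialogues` or `validate_batch` based on `self.args.e2e_dialogue_evaluation`. Variable names (`model`, `it`, `task`, `args`, `eval_dir`, and so on) follow the surrounding code; the snippet is illustrative, not part of the diff.

```python
# Before this PR (via the now-deleted genienlp/validate.py):
# generation_output = generate_with_model(
#     model, it, task, args,
#     original_order=original_order,
#     output_confidence_features=args.save_confidence_features,
#     confidence_estimators=confidence_estimators,
#     disable_progbar=False,
#     eval_dir=eval_dir,
# )

# After this PR: the model owns the dispatch, so `args` no longer needs to be
# threaded through every call site (the model reads self.args instead).
generation_output = model.validate(
    it,
    task,
    eval_dir=eval_dir,
    output_confidence_features=args.save_confidence_features,
    original_order=original_order,
    confidence_estimators=confidence_estimators,
    disable_progbar=False,
)
```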
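A second sketch, this one of the quartile sampling in `print_results` (moved verbatim from validate.py to genienlp/util.py). The toy data below is fabricated for illustration: with 20 length-sorted examples and `num_print=10`, the function prints indices 5-9 (second quartile) and 15-19 (fourth quartile) rather than the ten shortest examples, giving a more diverse sample.

```python
from genienlp.util import print_results  # new location after this PR

# Fabricated toy data; the lists are assumed to be sorted by example length.
results = {
    'model prediction': [f'pred_{i}' for i in range(20)],
    'gold answer': [f'gold_{i}' for i in range(20)],
    'context': [f'ctx_{i}' for i in range(20)],
}

# Prints examples 5..9 and 15..19, each key right-aligned to the longest key.
print_results(results, num_print=10)
```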