Skip to content

Commit

Permalink
Simplify validation code
Browse files Browse the repository at this point in the history
  • Loading branch information
Mehrad0711 committed Mar 2, 2022
1 parent 939f9d0 commit 8833a1b
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 138 deletions.
2 changes: 1 addition & 1 deletion genienlp/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def compute_metrics(
requested_metrics: contains a subset of the following metrics
em (exact match)
sm (structure match): valid if the output is ThingTalk code. Whether the gold answer and prediction are identical if we ignore parameter values of ThingTalk programs
#TODO add all
# TODO add all
lang: the language of the predictions and answers. Used for BERTScore.
args: arguments
example_ids: used to calculate some of e2e dialogue metrics that need to know span of each dialogue such as JGA
Expand Down
27 changes: 27 additions & 0 deletions genienlp/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,33 @@ def set_generation_output_options(self, tasks):
# TransformerSeq2Seq and TransformerLSTM will inherit from this model
class GenieModelForGeneration(GenieModel):
def validate(
self,
data_iterator,
task,
eval_dir=None,
output_predictions_only=False,
output_confidence_features=False,
original_order=None,
confidence_estimators=None,
disable_progbar=True,
**kwargs,
):
if self.args.e2e_dialogue_evaluation:
return self.validate_e2e_dialogues(
data_iterator, task, eval_dir, output_predictions_only, original_order, disable_progbar
)
else:
return self.validate_batch(
data_iterator,
task,
output_predictions_only,
output_confidence_features,
original_order,
confidence_estimators,
disable_progbar,
)

def validate_batch(
self,
data_iterator,
task,
Expand Down
9 changes: 3 additions & 6 deletions genienlp/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
set_seed,
split_folder_on_disk,
)
from .validate import generate_with_model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -472,16 +471,14 @@ def run(args, device):
confidence_estimators = None

with torch.no_grad(), torch.cuda.amp.autocast(enabled=args.mixed_precision):
generation_output = generate_with_model(
model,
generation_output = model.validate(
it,
task,
args,
original_order=original_order,
eval_dir=eval_dir,
output_confidence_features=args.save_confidence_features,
original_order=original_order,
confidence_estimators=confidence_estimators,
disable_progbar=False,
eval_dir=eval_dir,
)

if args.save_confidence_features:
Expand Down
11 changes: 6 additions & 5 deletions genienlp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from .ned.ned_utils import init_ned_model
from .tasks.registry import get_tasks
from .util import adjust_language_code, get_devices, load_config_json, log_model_size, set_seed
from .validate import generate_with_model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -213,11 +212,9 @@ def _numericalize_request(self, request, task, args):

def _predict_batch(self, batch, task, args):
if args.calibrator_paths is not None:
output = generate_with_model(
self.model,
output = self.model.validate(
[batch],
task,
args,
output_predictions_only=True,
confidence_estimators=self.confidence_estimators,
)
Expand All @@ -238,7 +235,11 @@ def _predict_batch(self, batch, task, args):
instance['score'][self.estimator_filenames[e_idx]] = float(estimator_scores[idx])
response.append(instance)
else:
output = generate_with_model(self.model, [batch], task, args, output_predictions_only=True)
output = self.model.validate(
[batch],
task,
output_predictions_only=True,
)
if sum(args.num_outputs) > 1:
response = []
for idx, predictions in enumerate(output.predictions):
Expand Down
28 changes: 27 additions & 1 deletion genienlp/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

from . import arguments, models
from .arguments import save_args
from .metrics import calculate_and_reduce_metrics
from .model_utils.optimizer import init_opt
from .model_utils.parallel_utils import NamedTupleCompatibleDataParallel
from .model_utils.saver import Saver
Expand All @@ -54,9 +55,9 @@
log_model_size,
make_data_loader,
ned_dump_entity_type_pairs,
print_results,
set_seed,
)
from .validate import print_results, validate


def initialize_logger(args):
Expand Down Expand Up @@ -221,6 +222,31 @@ def should_log(iteration, log_every):
return iteration % log_every == 0


def validate(task, val_iter, model, args, num_print=10):
with torch.no_grad():
model.eval()
if isinstance(model, torch.nn.DataParallel):
# get rid of the DataParallel wrapper
model = model.module

generation_output = model.validate(val_iter, task)

# loss is already calculated
metrics_to_return = [metric for metric in task.metrics if metric != 'loss']

metrics = calculate_and_reduce_metrics(args, generation_output, metrics_to_return, model.tgt_lang)

results = {
'model prediction': generation_output.predictions,
'gold answer': generation_output.answers,
'context': generation_output.contexts,
}

print_results(results, num_print)

return generation_output, metrics


def do_validate(iteration, args, model, val_iters, *, train_task, round_progress, task_progress, writer, logger):
deca_score = 0
for val_task_idx, (val_task, val_iter) in enumerate(val_iters):
Expand Down
31 changes: 31 additions & 0 deletions genienlp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import random
import re
import shutil
import sys
import time
from json.decoder import JSONDecodeError
from typing import List, Optional
Expand Down Expand Up @@ -1013,3 +1014,33 @@ def replace_capturing_group(input, re_pattern, replacement):
else:
new_input = input
return new_input


def print_results(results, num_print):
print()

values = list(results.values())
num_examples = len(values[0])

# examples are sorted by length
# to get good diversity, get half of examples from second quartile
start = int(num_examples / 4)
end = start + int(num_print / 2)
first_list = [val[start:end] for val in values]

# and the other half from fourth quartile
start = int(3 * num_examples / 4)
end = start + num_print - int(num_print / 2)
second_list = [val[start:end] for val in values]

# join examples
processed_values = [first + second for first, second in zip(first_list, second_list)]

for ex_idx in range(len(processed_values[0])):
for key_idx, key in enumerate(results.keys()):
value = processed_values[key_idx][ex_idx]
v = value[0] if isinstance(value, list) else value
key_width = max(len(key) for key in results)
print(f'{key:>{key_width}}: {repr(v)}')
print()
sys.stdout.flush()
125 changes: 0 additions & 125 deletions genienlp/validate.py

This file was deleted.

0 comments on commit 8833a1b

Please sign in to comment.