# SPDX-FileCopyrightText: 2022 Idiap Research Institute
#
# SPDX-License-Identifier: MIT
"""Loads a BART model and saves its outputs on validation/test data, grouped by refdoc."""
import glob
import os
from collections import defaultdict

import torch
from pytorch_lightning import seed_everything
from pytorch_lightning.core.saving import load_hparams_from_yaml

from bart import BartSummarizer
from dataloader import SummarizationDataModule
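

# Generates one summary per batch with beam search and decodes the source,
# reference, and top candidate back to text.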
def evaluation_step(model, hparams, batch):
    # generate with beam search, using the generation hyperparameters
    beam_output = model.model.generate(
        input_ids=batch.src,
        attention_mask=batch.mask_src,
        max_length=hparams.max_length,
        min_length=hparams.min_length,
        no_repeat_ngram_size=hparams.ngram_blocking,
        length_penalty=hparams.length_penalty,
        num_beams=5,
        use_cache=True,
        early_stopping=True,
    )
    # decode the source, reference, and top beam candidate to text
    source = model.decode(batch.src[0].tolist())
    reference = model.decode(batch.tgt[0].tolist())
    candidate = model.decode(beam_output[0].tolist())
    return source, reference, candidate
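

# Outputs are grouped by reference document (refdoc): the first batch seen for a
# refdoc yields its source and candidate; later batches with the same refdoc only
# contribute additional references. Only the first example of each batch is
# decoded, which effectively assumes the data loader yields batches of size 1.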
def main(args):
    seed_everything(args.seed)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # expect exactly one checkpoint in the model directory
    checkpoints = glob.glob(os.path.join(args.model_dir, '*.ckpt'))
    assert len(checkpoints) == 1
    checkpoint_path = checkpoints[0]
    # restore training args from the hparams file, without overriding CLI args
    hparams = load_hparams_from_yaml(os.path.join(args.model_dir, 'version_0', 'hparams.yaml'))
    for key, value in hparams.items():
        if not hasattr(args, key):
            setattr(args, key, value)
    # init model
    model = BartSummarizer.load_from_checkpoint(args=args, checkpoint_path=checkpoint_path)
    model.eval()
    data_module = SummarizationDataModule(args)
    data_loader = data_module.val_dataloader() if args.data_split == 'valid' else data_module.test_dataloader()
    outputs = defaultdict(list)
    refdocs_processed = set()
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            # stop early if a positive --num_batches limit has been reached
            if batch_idx >= args.num_batches > 0:
                break
            refdoc = batch.refdoc[0]
            if refdoc in refdocs_processed:
                # refdoc already seen: only collect the additional reference
                index = outputs['refdocs'].index(refdoc)
                reference = model.decode(batch.tgt[0].tolist())
                outputs['references'][index].append(reference)
                continue
            source, reference, candidate = evaluation_step(model, model.hparams, batch)
            outputs['sources'].append(source)
            outputs['references'].append([reference])
            outputs['candidates'].append(candidate)
            outputs['refdocs'].append(refdoc)
            refdocs_processed.add(refdoc)
    # write results to files, one line per refdoc
    for name in ('sources', 'candidates', 'refdocs'):
        with open(os.path.join(args.output_dir, f'{name}.txt'), 'w') as f:
            for output in outputs[name]:
                f.write(output + '\n')
    with open(os.path.join(args.output_dir, 'references.txt'), 'w') as f:
        for refdoc_references in outputs['references']:
            f.write(args.reference_join.join(refdoc_references) + '\n')
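
# The output directory then contains sources.txt, candidates.txt, refdocs.txt,
# and references.txt, aligned line by line (one line per refdoc); multiple
# references for the same refdoc are joined with the --reference_join marker.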

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Save model outputs on validation/test data.')
    parser.add_argument('--model_dir', default='models', help='Path to model directory')
    parser.add_argument('--seed', type=int, default=1, help='Random seed')
    # override training args
    parser.add_argument('--num_workers', type=int, default=4, help='Num workers for data loading')
    parser.add_argument('--gpus', type=int, default=0, help='How many GPUs to use.')
    parser.add_argument('--limit_val_batches', type=float, default=1.0, help='Fraction of validation batches to sample')
    parser.add_argument('--limit_test_batches', type=float, default=1.0, help='Fraction of test batches to sample')
    # override generation args
    parser.add_argument('--max_length', type=int, default=500, help='Max generation length')
    parser.add_argument('--min_length', type=int, default=50, help='Min generation length')
    parser.add_argument('--length_penalty', type=float, default=1.0, help='Alpha for length penalty')
    parser.add_argument('--ngram_blocking', type=int, default=0, help='Block repetition of n-grams (0: off)')
    # task args
    parser.add_argument('--data_split', default='valid', choices=['valid', 'test'], help='Data split to use')
    parser.add_argument('--output_dir', default='results', help='Path to output directory')
    parser.add_argument('--num_batches', type=int, default=0, help='Number of batches (0: all)')
    parser.add_argument('--reference_join', default='<ref>', help='Marker to join multiple references.')
    main(parser.parse_args())
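
# Example invocation (paths are illustrative; expects exactly one *.ckpt file
# and a version_0/hparams.yaml inside --model_dir):
#   python save_bart_outputs.py --model_dir models --data_split test --output_dir results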