-
Notifications
You must be signed in to change notification settings - Fork 0
/
comet_mbr.py
238 lines (207 loc) · 9.01 KB
/
comet_mbr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# -*- coding: utf-8 -*-
# Copyright (C) 2020 Unbabel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##### Changes were made according to our implementation. This version reads its input from a JSON Lines file and writes one result per line to the output file. #####
##### Run command : comet-mbr -st [filename].jsonl -o [OUTFILE].jsonl #####
"""
Command for Minimum Bayes Risk Decoding.
========================================
This script is inspired by Chantal Amrhein's script used in:
Title: Identifying Weaknesses in Machine Translation Metrics Through Minimum Bayes Risk Decoding: A Case Study for COMET
URL: https://arxiv.org/abs/2202.05148

Input: a JSON Lines file in which every line is an object with the keys
"id", "en_sent" (the source sentence) and "te_sent" (a list of candidate
translations).
Output: one line per input record with "id", "en_sent" and "te_sent", where
"te_sent" is the candidate (with surrounding quote characters stripped) that
obtained the highest MBR score.

optional arguments:
  -h, --help            Show this help message and exit.
  -st SOURCE, --source SOURCE
                        Input JSON Lines file. (required, type: Path_fr)
  --batch_size BATCH_SIZE
                        (type: int, default: 1)
  --model MODEL         COMET model to be used. (type: str, default: wmt20-comet-da)
  --model_storage_path MODEL_STORAGE_PATH
                        Path to the directory where models will be stored. By default they are saved in ~/.cache/torch/unbabel_comet/ (default: null)
  -o OUTPUT, --output OUTPUT
                        Best candidates after running MBR decoding. (required, type: str)
"""
import os, json
from typing import List, Tuple
import torch
from comet.download_utils import download_model
from comet.models import RegressionMetric, available_metrics, load_from_checkpoint
from jsonargparse import ArgumentParser
from jsonargparse.typing import Path_fr
from tqdm import tqdm
def _embed_in_batches(
    sentences: List[str],
    model: RegressionMetric,
    batch_size: int,
    desc: str = "",
) -> torch.Tensor:
    """Encode ``sentences`` in chunks of ``batch_size`` and stack the embeddings row-wise.

    :param sentences: Sentences to encode (must be non-empty).
    :param model: RegressionMetric whose encoder tokenizes and embeds the sentences.
    :param batch_size: Number of sentences per encoder call.
    :param desc: If non-empty, wrap the batch loop in a tqdm progress bar with this label.
    :return: Embeddings of all sentences, stacked with ``torch.vstack`` in input order.
    :raises ValueError: If ``sentences`` is empty (``torch.vstack`` cannot stack nothing).
    """
    if not sentences:
        raise ValueError("cannot build embeddings for an empty list of sentences")
    prepared = [
        model.encoder.prepare_sample(sentences[start : start + batch_size])
        for start in range(0, len(sentences), batch_size)
    ]
    batches = tqdm(prepared, desc=desc, dynamic_ncols=True) if desc else prepared
    embeddings = []
    with torch.no_grad():
        for batch in batches:
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            embeddings.append(model.get_sentence_embedding(input_ids, attention_mask))
    return torch.vstack(embeddings)


def build_embeddings(
    sources: List[str],
    translations: List[str],
    model: RegressionMetric,
    batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenize and encode source and translation sentences with a RegressionMetric model.

    :param sources: List of source sentences (must be non-empty).
    :param translations: List of translation sentences (must be non-empty).
    :param model: RegressionMetric model that will be used to embed sentences.
    :param batch_size: batch size used during encoding.
    :return: ``(src_embeddings, mt_embeddings)``, each the row-wise stack of the
        per-sentence embeddings returned by ``model.get_sentence_embedding``.
    :raises ValueError: If ``sources`` or ``translations`` is empty.
    """
    # TODO: Optimize this function to have faster MBR decoding!
    src_embeddings = _embed_in_batches(sources, model, batch_size)
    # Only the (usually much longer) translation list gets a progress bar.
    mt_embeddings = _embed_in_batches(
        translations, model, batch_size, desc="Encoding sentences..."
    )
    return src_embeddings, mt_embeddings
def mbr_decoding(
    src_embeddings: torch.Tensor, mt_embeddings: torch.Tensor, model: RegressionMetric
) -> torch.Tensor:
    """Perform MBR decoding for each translation of every source sentence.

    Every hypothesis is scored with ``model.estimate`` against all the other
    hypotheses of the same source, which are used as pseudo-references. Its MBR
    score is the mean of those scores; the self-comparison is excluded.

    :param src_embeddings: Source sentence embeddings, one row per source sentence.
    :param mt_embeddings: MT embeddings of shape ``(n_sent, num_samples, dim)``.
    :param model: RegressionMetric Model.
    :return:
        A ``[n_sent x num_samples]`` matrix M where each row is a source sentence
        and each column a given sample; ``M[i][j]`` is the MBR score of sample j
        for source i. When ``num_samples < 2`` there are no other samples to
        compare against, so the matrix is returned as all zeros instead of NaN.
    """
    n_sent, num_samples, _ = mt_embeddings.shape
    mbr_matrix = torch.zeros(n_sent, num_samples)
    if num_samples < 2:
        # Leave-one-out averaging over zero pseudo-references would be NaN.
        return mbr_matrix
    with torch.no_grad():
        # Loop over all source sentences
        for i in tqdm(range(n_sent), desc="MBR Scores...", dynamic_ncols=True):
            source = src_embeddings[i, :].repeat(num_samples, 1)
            # Every hypothesis of source i serves as a pseudo-reference.
            pseudo_refs = mt_embeddings[i, :]
            # Loop over all hypotheses
            for j in range(num_samples):
                translation = mt_embeddings[i, j, :].repeat(num_samples, 1)
                # Score hypothesis j against every pseudo-reference at once;
                # presumably "score" is (num_samples, 1) given the squeeze(1).
                scores = model.estimate(source, translation, pseudo_refs)[
                    "score"
                ].squeeze(1)
                # Drop the self-comparison (hypothesis j vs. itself).
                scores = torch.cat([scores[:j], scores[j + 1 :]])
                mbr_matrix[i, j] = scores.mean()
    return mbr_matrix
def _select_best_candidate(
    source: str, candidates: List[str], model: RegressionMetric, batch_size: int
) -> str:
    """Return the candidate translation of ``source`` with the highest MBR score.

    :param source: Source sentence.
    :param candidates: Non-empty list of candidate translations.
    :param model: Reference-based RegressionMetric used for scoring.
    :param batch_size: Batch size used when encoding sentences.
    :return: The selected candidate (the first one on ties, as ``torch.argmax``).
    """
    if len(candidates) == 1:
        # MBR needs other samples as pseudo-references; with one candidate the
        # choice is trivial, so the model is not run at all.
        return candidates[0]
    src_embeddings, mt_embeddings = build_embeddings(
        [source], candidates, model, batch_size
    )
    # One source sentence -> (1, num_candidates, dim), the layout mbr_decoding expects.
    mt_embeddings = mt_embeddings.reshape(1, len(candidates), -1)
    mbr_matrix = mbr_decoding(src_embeddings, mt_embeddings, model)
    best_idx = int(torch.argmax(mbr_matrix[0, :]))
    return candidates[best_idx]


def mbr_command() -> None:
    """CLI entry point: choose the MBR-best translation for every record of a JSONL file.

    Each non-blank input line must be a JSON object with the keys ``"id"``,
    ``"en_sent"`` (source sentence) and ``"te_sent"`` (list of candidate
    translations). For every record one JSON object with ``"id"``,
    ``"en_sent"`` and the selected ``"te_sent"`` is written per line to
    ``--output``.

    :raises Exception: If the chosen model is not a reference-based RegressionMetric.
    :raises ValueError: If a record has an empty ``"te_sent"`` list.
    """
    parser = ArgumentParser(description="Command for Minimum Bayes Risk Decoding.")
    parser.add_argument("-st", "--source", type=Path_fr, required=True)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument(
        "--model",
        type=str,
        required=False,
        default="wmt20-comet-da",
        help="COMET model to be used.",
    )
    parser.add_argument(
        "--model_storage_path",
        help=(
            "Path to the directory where models will be stored. "
            + "By default its saved in ~/.cache/torch/unbabel_comet/"
        ),
        default=None,
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="Best candidates after running MBR decoding.",
    )
    cfg = parser.parse_args()

    # Accept either a local checkpoint file or a named COMET model to download.
    if cfg.model.endswith(".ckpt") and os.path.exists(cfg.model):
        model_path = cfg.model
    elif cfg.model in available_metrics:
        model_path = download_model(cfg.model, saving_directory=cfg.model_storage_path)
    else:
        parser.error(
            "{} is not a valid checkpoint path or model choice. Choose from {}".format(
                cfg.model, available_metrics.keys()
            )
        )

    model = load_from_checkpoint(model_path)
    if not isinstance(model, RegressionMetric) or model.is_referenceless():
        raise Exception(
            "Incorrect model ({}). MBR command only works with Reference-based Regression models!".format(
                model.__class__.__name__
            )
        )
    model.eval()
    # NOTE(review): GPU placement (model.cuda()) was disabled here; re-enable if desired.

    # Open the input first so a bad input path does not truncate an existing output file.
    with open(cfg.source(), "r", encoding="utf-8") as fin, open(
        cfg.output, "w", encoding="utf-8"
    ) as fout:
        for line in fin:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline at EOF
            record = json.loads(line)
            # Candidates may carry stray surrounding quote characters.
            candidates = [sent.strip("'\"") for sent in record["te_sent"]]
            if not candidates:
                raise ValueError(
                    "Record {!r} has no candidate translations in 'te_sent'.".format(
                        record.get("id")
                    )
                )
            best = _select_best_candidate(
                record["en_sent"], candidates, model, cfg.batch_size
            )
            result = {
                "id": record["id"],
                "en_sent": record["en_sent"],
                "te_sent": best,
            }
            # json.dumps (not str()) so every line is valid JSON; keep non-ASCII text readable.
            fout.write(json.dumps(result, ensure_ascii=False) + "\n")
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    mbr_command()