Commit

Enable DocCOMET version of COMET
sweta20 committed May 1, 2024
1 parent 74ef715 commit f27b547
Showing 4 changed files with 96 additions and 3 deletions.
7 changes: 7 additions & 0 deletions comet/cli/score.py
@@ -30,6 +30,7 @@
--batch_size BATCH_SIZE
(type: int, default: 16)
--gpus GPUS (type: int, default: 1)
--enable-context Enables contextual extension of COMET. (default: False)
--quiet Sets all loggers to ERROR level. (default: False)
--only_system Prints only the final system score. (default: False)
--to_json TO_JSON Exports results to a json file. (type: str, default: "")
@@ -75,6 +76,9 @@ def score_command() -> None:
parser.add_argument(
"--quiet", action="store_true", help="Sets all loggers to ERROR level."
)
parser.add_argument(
"--enable-context", action="store_true", help="Enables contextual extension of COMET on inputs preprocessed with context information."
)
parser.add_argument(
"--only_system", action="store_true", help="Prints only the final system score."
)
@@ -159,6 +163,9 @@ def score_command() -> None:
model = load_from_checkpoint(model_path)
model.eval()

if cfg.enable_context:
model.enable_context()

if model.requires_references() and (cfg.references is None):
parser.error(
"{} requires -r/--references or -d/--sacrebleu_dataset.".format(cfg.model)
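The new flag only changes how embeddings are pooled; the input files still need to carry the preceding context in every line. A minimal sketch of that preprocessing, mirroring the joining done in the new unit test (the helper name is made up for illustration, and the "</s>" default only matches the XLM-R based encoders; in practice the separator should be read from the model's tokenizer):

def prepend_context(context: str, sentence: str, sep: str = "</s>") -> str:
    # Hypothetical helper, not part of this commit: join the preceding context
    # and the current sentence with the encoder's separator token, producing the
    # "inputs preprocessed with context information" that --enable-context expects.
    return " {} ".format(sep).join([context, sentence])

# e.g. 'She left early. </s> He followed her.'
print(prepend_context("She left early.", "He followed her."))
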
8 changes: 8 additions & 0 deletions comet/models/base.py
@@ -143,6 +143,7 @@ def __init__(
self.nr_frozen_epochs = self.hparams.nr_frozen_epochs
self.mc_dropout = False # Flag used to control usage of MC Dropout
self.caching = False # Flag used to control Embedding Caching
self.use_context = False # Flag used to control usage of preceding context (DocCOMET)

# If not defined here, metrics will not live in the same device as our model.
self.init_metrics()
@@ -155,6 +156,11 @@ def set_mc_dropout(self, value: int):
"""
self.mc_dropout = value

def enable_context(self):
"""Function that extends COMET to use preceding context as described in
https://statmt.org/wmt22/pdf/2022.wmt-1.6.pdf."""
self.use_context = True

@abc.abstractmethod
def read_training_data(self) -> List[dict]:
"""Abstract method that reads the training data.
@@ -343,6 +349,8 @@ def compute_sentence_embedding(
embeddings,
attention_mask,
self.encoder.tokenizer.pad_token_id,
self.encoder.tokenizer.sep_token_id,
self.use_context,
)

elif self.hparams.pool == "cls":
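For programmatic use, the new method is called on a loaded model before predict(). A minimal sketch of that flow, following the new test_context_predict test (the checkpoint name and the sentences are placeholders, not part of this commit):

from comet import download_model, load_from_checkpoint

model = load_from_checkpoint(download_model("Unbabel/wmt22-comet-da"))  # assumed checkpoint
model.enable_context()  # switch on the context-aware pooling added in this commit

sep = model.encoder.tokenizer.sep_token
data = [{
    # preceding context first, separated from the current sentence by the sep token
    "src": "Elle est partie tôt. {} Il l'a suivie.".format(sep),
    "mt": "She left early. {} He followed her.".format(sep),
    "ref": "She left early. {} He went after her.".format(sep),
}]
print(model.predict(data, batch_size=8, gpus=0).scores)
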
69 changes: 66 additions & 3 deletions comet/models/pooling_utils.py
@@ -13,13 +13,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import List, Tuple

# From https://github.com/amazon-science/doc-mt-metrics/blob/5385cc28930aae9924edcb3201645dd3810b12c0/COMET/comet/models/pooling_utils.py#L18
def find_start_inds_and_mask_tokens(
mask: torch.Tensor,
tokens: torch.Tensor,
separator_index: int,
) -> Tuple[List[int], torch.Tensor]:
"""Finds the starting indices of each sentence for multi-sentence sequences and
creates a new mask to omit all context sentences from the pooling function.
Args:
mask: Padding mask [batch_size x seq_length]
tokens: Word ids [batch_size x seq_length]
separator_index: Separator token index.
Return:
List of start indices (one per sequence) and a mask with context tokens zeroed out.
"""
start_inds = []
ctx_mask = mask  # note: context positions are zeroed in this tensor in place
for i, sent in enumerate(tokens):
# find all separator tokens in the sequence
separators = (sent == separator_index).nonzero()
if len(separators) > 1:
# if there is more than one, the last sentence starts after the second-to-last separator
ind = separators[-2].cpu().numpy().item()
start_inds.append(ind)
ctx_mask[i, 1:ind+1] = 0
else:
start_inds.append(0)
return start_inds, ctx_mask

def average_pooling(
tokens: torch.Tensor,
embeddings: torch.Tensor,
mask: torch.Tensor,
padding_index: int,
separator_index: int,
enable_context: bool = False
) -> torch.Tensor:
"""Average pooling method.
@@ -33,9 +63,15 @@ def average_pooling(
Return:
torch.Tensor: Sentence embedding
"""
wordemb = mask_fill(0.0, tokens, embeddings, padding_index)
sentemb = torch.sum(wordemb, 1)
sum_mask = mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1)
if enable_context:
start_inds, ctx_mask = find_start_inds_and_mask_tokens(mask, tokens, separator_index)
wordemb = mask_fill_index(0.0, tokens, embeddings, start_inds, padding_index)
sentemb = torch.sum(wordemb, 1)
sum_mask = ctx_mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1)
else:
wordemb = mask_fill(0.0, tokens, embeddings, padding_index)
sentemb = torch.sum(wordemb, 1)
sum_mask = mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1)
return sentemb / sum_mask


@@ -55,6 +91,33 @@ def max_pooling(
"""
return mask_fill(float("-inf"), tokens, embeddings, padding_index).max(dim=1)[0]

# From https://github.com/amazon-science/doc-mt-metrics/blob/5385cc28930aae9924edcb3201645dd3810b12c0/COMET/comet/models/pooling_utils.py#L18
def mask_fill_index(
fill_value: float,
tokens: torch.Tensor,
embeddings: torch.Tensor,
start_inds: list,
padding_index: int,
) -> torch.Tensor:
"""
Masks embeddings representing padded elements and context sentences for multi-sentence sequences.
Args:
fill_value: the value to fill the embeddings belonging to padded tokens.
tokens: The input sequences [bsz x seq_len].
embeddings: word embeddings [bsz x seq_len x hiddens].
start_inds: Start of sentence indices.
padding_index: Index of the padding token.
Return:
torch.Tensor: Word embeddings with padded and context positions filled.
"""
padding_mask = tokens.eq(padding_index).unsqueeze(-1)
padding_mask2 = torch.zeros(tokens.shape, dtype=torch.bool, device=padding_mask.device)
for i, start in enumerate(start_inds):
padding_mask2[i, 1: start+1] = True
padding_mask = torch.logical_or(padding_mask, padding_mask2.unsqueeze(-1))
return embeddings.float().masked_fill_(padding_mask, fill_value).type_as(embeddings)

def mask_fill(
fill_value: float,
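A small, self-contained sketch of what the new pooling path computes (assumes a checkout with this commit installed; the token ids are made up for illustration, with 1 standing for padding and 2 for the separator):

import torch

from comet.models.pooling_utils import average_pooling, find_start_inds_and_mask_tokens

PAD, SEP = 1, 2
# one sequence: <s> ctx ctx </s> tok tok </s> <pad>
tokens = torch.tensor([[0, 10, 11, 2, 20, 21, 2, 1]])
mask = tokens.ne(PAD).long()
# toy embeddings: position p is mapped to the vector [p, p, p, p]
embeddings = torch.arange(8, dtype=torch.float32).view(1, 8, 1).repeat(1, 1, 4)

start_inds, ctx_mask = find_start_inds_and_mask_tokens(mask.clone(), tokens, SEP)
print(start_inds)  # [3] -> the scored sentence starts after the first </s>
print(ctx_mask)    # tensor([[1, 0, 0, 0, 1, 1, 1, 0]]) -> context and padding dropped

# with context enabled only positions 0, 4, 5, 6 are averaged: (0+4+5+6)/4 = 3.75
with_ctx = average_pooling(tokens, embeddings.clone(), mask.clone(), PAD, SEP, True)
# without it every non-padding position is averaged: (0+1+...+6)/7 = 3.0
without_ctx = average_pooling(tokens, embeddings.clone(), mask.clone(), PAD, SEP, False)
print(with_ctx, without_ctx)
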
15 changes: 15 additions & 0 deletions tests/unit/test_models_predict.py
@@ -25,6 +25,10 @@
{"lp": "it-en", "src": "Nel suo uso popolare, il termine safari fa riferimento a viaggi svolti via terra, in particolare nella savana, per vedere la bellissima fauna selvatica africana.", "mt": "In its popular usage, the term safari refers to trips made by land, particularly in the savannah, to see beautiful African wildlife.", "ref": "The term safari in popular use refers to overland travel to view the stunning African wildlife, particularly on savanna.", "annotations": [], "score": 1.0}
]

CONTEXT_TEST_SAMPLES = [
{"lp": "it-en", "context_src": "", "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano, al largo della costa est dell'Africa.", "context_mt": "", "mt": "The East African islands are located in the Indian Ocean, off the east coast of Africa.", "context_ref": "", "ref": "The East African Islands are in the Indian Ocean off the eastern coast of Africa.", "annotations": [], "score": 1.1697086095809937},
]

class TestUnifiedMetricPredict(unittest.TestCase):

model = load_from_checkpoint(download_model("Unbabel/test-model-whimsical-whisper", saving_directory=DATA_PATH))
@@ -49,6 +53,17 @@ def test_predict(self):
np.testing.assert_almost_equal(expected_scores, np.array(model_output.scores), decimal=5)
np.testing.assert_almost_equal(model_output.system_score, expected_scores.mean(), 5)

def test_context_predict(self):
self.model.enable_context()
self.assertTrue(self.model.use_context)
for sample in CONTEXT_TEST_SAMPLES:
for key in ["src", "mt", "ref"]:
sample[key] = " {} ".format(self.model.encoder.tokenizer.sep_token).join([sample[f"context_{key}"], sample[key]])

model_output = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=12, gpus=self.gpus)

np.testing.assert_almost_equal([sample['score'] for sample in CONTEXT_TEST_SAMPLES], np.array(model_output.scores), decimal=5)

def test_length_batching(self):
output_without_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=False)
output_with_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=True)
