From f27b5473e4301abb9d6e623408bf58ae293e517f Mon Sep 17 00:00:00 2001 From: Sweta Agrawal Date: Wed, 1 May 2024 16:55:49 +0000 Subject: [PATCH 1/2] Enable DocCOMET version of COMET --- comet/cli/score.py | 7 ++++ comet/models/base.py | 8 ++++ comet/models/pooling_utils.py | 69 +++++++++++++++++++++++++++++-- tests/unit/test_models_predict.py | 15 +++++++ 4 files changed, 96 insertions(+), 3 deletions(-) diff --git a/comet/cli/score.py b/comet/cli/score.py index a18f7d2..d8146d2 100644 --- a/comet/cli/score.py +++ b/comet/cli/score.py @@ -30,6 +30,7 @@ --batch_size BATCH_SIZE (type: int, default: 16) --gpus GPUS (type: int, default: 1) + --enable-context Enables contextual extension of COMET. (default: False) --quiet Sets all loggers to ERROR level. (default: False) --only_system Prints only the final system score. (default: False) --to_json TO_JSON Exports results to a json file. (type: str, default: "") @@ -75,6 +76,9 @@ def score_command() -> None: parser.add_argument( "--quiet", action="store_true", help="Sets all loggers to ERROR level." ) + parser.add_argument( + "--enable-context", action="store_true", help="Enables contextual extension of COMET on inputs preprocessed with context information." + ) parser.add_argument( "--only_system", action="store_true", help="Prints only the final system score." 
) @@ -159,6 +163,9 @@ def score_command() -> None: model = load_from_checkpoint(model_path) model.eval() + if cfg.enable_context: + model.enable_context() + if model.requires_references() and (cfg.references is None): parser.error( "{} requires -r/--references or -d/--sacrebleu_dataset.".format(cfg.model) diff --git a/comet/models/base.py b/comet/models/base.py index a730704..5dc9266 100644 --- a/comet/models/base.py +++ b/comet/models/base.py @@ -143,6 +143,7 @@ def __init__( self.nr_frozen_epochs = self.hparams.nr_frozen_epochs self.mc_dropout = False # Flag used to control usage of MC Dropout self.caching = False # Flag used to control Embedding Caching + self.use_context = False # If not defined here, metrics will not live in the same device as our model. self.init_metrics() @@ -155,6 +156,11 @@ def set_mc_dropout(self, value: int): """ self.mc_dropout = value + def enable_context(self): + """Function that extends COMET to use preceding context as described in + https://statmt.org/wmt22/pdf/2022.wmt-1.6.pdf.""" + self.use_context = True + @abc.abstractmethod def read_training_data(self) -> List[dict]: """Abstract method that reads the training data. @@ -343,6 +349,8 @@ def compute_sentence_embedding( embeddings, attention_mask, self.encoder.tokenizer.pad_token_id, + self.encoder.tokenizer.sep_token_id, + self.use_context, ) elif self.hparams.pool == "cls": diff --git a/comet/models/pooling_utils.py b/comet/models/pooling_utils.py index 9084283..7c3d8ed 100644 --- a/comet/models/pooling_utils.py +++ b/comet/models/pooling_utils.py @@ -13,13 +13,43 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch +from typing import List, Union +# From https://github.com/amazon-science/doc-mt-metrics/blob/5385cc28930aae9924edcb3201645dd3810b12c0/COMET/comet/models/pooling_utils.py#L18 +def find_start_inds_and_mask_tokens( + mask: torch.Tensor, + tokens: torch.Tensor, + separator_index: int, +) -> Union[List[int], torch.Tensor]: + """Finds where the last sentence starts in each multi-sentence sequence and + zeroes the context-sentence positions of `mask` in place so that pooling ignores them. + + Args: + mask: Padding mask [batch_size x seq_length] + tokens: Word ids [batch_size x seq_length] + separator_index: Separator token index. + """ + start_inds = [] + ctx_mask = mask + for i, sent in enumerate(tokens): + # find all separator tokens in the sequence + separators = (sent == separator_index).nonzero() + if len(separators) > 1: + # if there is more than one, find where the last sentence starts + ind = separators[-2].cpu().numpy().item() + start_inds.append(ind) + ctx_mask[i, 1:ind+1] = 0 + else: + start_inds.append(0) + return start_inds, ctx_mask def average_pooling( tokens: torch.Tensor, embeddings: torch.Tensor, mask: torch.Tensor, padding_index: int, + separator_index: int, + enable_context: bool = False ) -> torch.Tensor: """Average pooling method. 
@@ -33,9 +63,15 @@ def average_pooling( Return: torch.Tensor: Sentence embedding """ - wordemb = mask_fill(0.0, tokens, embeddings, padding_index) - sentemb = torch.sum(wordemb, 1) - sum_mask = mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1) + if enable_context: + start_inds, ctx_mask = find_start_inds_and_mask_tokens(mask, tokens, separator_index) + wordemb = mask_fill_index(0.0, tokens, embeddings, start_inds, padding_index) + sentemb = torch.sum(wordemb, 1) + sum_mask = ctx_mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1) + else: + wordemb = mask_fill(0.0, tokens, embeddings, padding_index) + sentemb = torch.sum(wordemb, 1) + sum_mask = mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1) return sentemb / sum_mask @@ -55,6 +91,33 @@ def max_pooling( """ return mask_fill(float("-inf"), tokens, embeddings, padding_index).max(dim=1)[0] +# From https://github.com/amazon-science/doc-mt-metrics/blob/5385cc28930aae9924edcb3201645dd3810b12c0/COMET/comet/models/pooling_utils.py#L18 +def mask_fill_index( + fill_value: float, + tokens: torch.Tensor, + embeddings: torch.Tensor, + start_inds: list, + padding_index: int, +) -> torch.Tensor: + """ + Masks embeddings representing padded elements and context sentences for multi-sentence sequences. + + Args: + fill_value: the value to fill the embeddings belonging to padded tokens. + tokens: The input sequences [bsz x seq_len]. + embeddings: word embeddings [bsz x seq_len x hiddens]. + start_inds: Start of sentence indices. + padding_index: Index of the padding token. 
+ + Return: + torch.Tensor: Word embeddings with padded and context positions set to fill_value [bsz x seq_len x hiddens] + """ + padding_mask = tokens.eq(padding_index).unsqueeze(-1) + padding_maks2 = torch.zeros(tokens.shape, dtype=torch.bool, device=padding_mask.device) + for i, start in enumerate(start_inds): + padding_maks2[i, 1: start+1] = True + padding_mask = torch.logical_or(padding_mask, padding_maks2.unsqueeze(-1)) + return embeddings.float().masked_fill_(padding_mask, fill_value).type_as(embeddings) def mask_fill( fill_value: float, diff --git a/tests/unit/test_models_predict.py b/tests/unit/test_models_predict.py index b4acb1f..fc35f5f 100644 --- a/tests/unit/test_models_predict.py +++ b/tests/unit/test_models_predict.py @@ -25,6 +25,10 @@ {"lp": "it-en", "src": "Nel suo uso popolare, il termine safari fa riferimento a viaggi svolti via terra, in particolare nella savana, per vedere la bellissima fauna selvatica africana.", "mt": "In its popular usage, the term safari refers to trips made by land, particularly in the savannah, to see beautiful African wildlife.", "ref": "The term safari in popular use refers to overland travel to view the stunning African wildlife, particularly on savanna.", "annotations": [], "score": 1.0} ] +CONTEXT_TEST_SAMPLES = [ + {"lp": "it-en", "context_src": "", "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano, al largo della costa est dell'Africa.", "context_mt": "", "mt": "The East African islands are located in the Indian Ocean, off the east coast of Africa.", "context_ref": "", "ref": "The East African Islands are in the Indian Ocean off the eastern coast of Africa.", "annotations": [], "score": 1.1697086095809937}, +] + class TestUnifiedMetricPredict(unittest.TestCase): model = load_from_checkpoint(download_model("Unbabel/test-model-whimsical-whisper", saving_directory=DATA_PATH)) @@ -49,6 +53,17 @@ def test_predict(self): np.testing.assert_almost_equal(expected_scores, np.array(model_output.scores), decimal=5) np.testing.assert_almost_equal(model_output.system_score, 
expected_scores.mean(), 5) + def test_context_predict(self): + self.model.enable_context() + assert self.model.use_context == True + for sample in CONTEXT_TEST_SAMPLES: + for key in ["src", "mt", "ref"]: + sample[key] = " {} ".format(self.model.encoder.tokenizer.sep_token).join([sample[f"context_{key}"], sample[key]]) + + model_output = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=12, gpus=self.gpus) + + np.testing.assert_almost_equal([sample['score'] for sample in CONTEXT_TEST_SAMPLES], np.array(model_output.scores), decimal=5) + def test_length_batching(self): output_without_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=False) output_with_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=True) From c9e99e31c54a117f17049aee00095ef17df91388 Mon Sep 17 00:00:00 2001 From: Sweta Date: Sat, 4 May 2024 14:58:32 +0000 Subject: [PATCH 2/2] Enable context only for RegressionMetric and with average pooling --- comet/models/base.py | 3 ++- comet/models/regression/referenceless.py | 4 +++ comet/models/regression/regression_metric.py | 4 +++ tests/unit/test_models_predict.py | 26 ++++++++++++-------- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/comet/models/base.py b/comet/models/base.py index 5dc9266..003a5c5 100644 --- a/comet/models/base.py +++ b/comet/models/base.py @@ -144,6 +144,7 @@ def __init__( self.mc_dropout = False # Flag used to control usage of MC Dropout self.caching = False # Flag used to control Embedding Caching self.use_context = False + self.pool = pool # If not defined here, metrics will not live in the same device as our model. 
self.init_metrics() @@ -159,7 +160,7 @@ def set_mc_dropout(self, value: int): def enable_context(self): """Function that extends COMET to use preceding context as described in https://statmt.org/wmt22/pdf/2022.wmt-1.6.pdf.""" - self.use_context = True + logger.warning("Context can only be enabled for RegressionMetric with Average Pooling.") @abc.abstractmethod def read_training_data(self) -> List[dict]: diff --git a/comet/models/regression/referenceless.py b/comet/models/regression/referenceless.py index 778f4d8..2a22951 100644 --- a/comet/models/regression/referenceless.py +++ b/comet/models/regression/referenceless.py @@ -126,6 +126,10 @@ def __init__( def requires_references(self) -> bool: return False + + def enable_context(self): + if self.pool == "avg": + self.use_context = True def prepare_sample( self, sample: List[Dict[str, Union[str, float]]], stage: str = "train" diff --git a/comet/models/regression/regression_metric.py b/comet/models/regression/regression_metric.py index 30a4dc7..8cf4903 100644 --- a/comet/models/regression/regression_metric.py +++ b/comet/models/regression/regression_metric.py @@ -215,6 +215,10 @@ def prepare_sample( return model_inputs, targets + def enable_context(self): + if self.pool == "avg": + self.use_context = True + def estimate( self, src_sentemb: torch.Tensor, diff --git a/tests/unit/test_models_predict.py b/tests/unit/test_models_predict.py index fc35f5f..a1d0d1b 100644 --- a/tests/unit/test_models_predict.py +++ b/tests/unit/test_models_predict.py @@ -26,7 +26,7 @@ ] CONTEXT_TEST_SAMPLES = [ - {"lp": "it-en", "context_src": "", "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano, al largo della costa est dell'Africa.", "context_mt": "", "mt": "The East African islands are located in the Indian Ocean, off the east coast of Africa.", "context_ref": "", "ref": "The East African Islands are in the Indian Ocean off the eastern coast of Africa.", "annotations": [], "score": 1.1697086095809937}, + {"lp": 
"it-en", "context_src": None, "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano, al largo della costa est dell'Africa.", "context_mt": None, "mt": "The East African islands are located in the Indian Ocean, off the east coast of Africa.", "context_ref": None, "ref": "The East African Islands are in the Indian Ocean off the eastern coast of Africa.", "annotations": [], "score": 1.0}, ] class TestUnifiedMetricPredict(unittest.TestCase): @@ -55,15 +55,8 @@ def test_predict(self): def test_context_predict(self): self.model.enable_context() - assert self.model.use_context == True - for sample in CONTEXT_TEST_SAMPLES: - for key in ["src", "mt", "ref"]: - sample[key] = " {} ".format(self.model.encoder.tokenizer.sep_token).join([sample[f"context_{key}"], sample[key]]) - - model_output = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=12, gpus=self.gpus) + assert self.model.use_context == False - np.testing.assert_almost_equal([sample['score'] for sample in CONTEXT_TEST_SAMPLES], np.array(model_output.scores), decimal=5) - def test_length_batching(self): output_without_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=False) output_with_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=True) @@ -93,4 +86,17 @@ def test_xcomet_predict(self): model.score_weights = [0, 0, 0, 1] model_output = model.predict(TEST_SAMPLES, batch_size=12, gpus=self.gpus) self.assertListEqual(model_output.scores, model_output.metadata.mqm_scores) - \ No newline at end of file + + +class TestRegressionMetricPredict(unittest.TestCase): + + model = load_from_checkpoint(download_model("Unbabel/eamt22-cometinho-da", saving_directory=DATA_PATH)) + gpus = 1 if torch.cuda.device_count() > 0 else 0 + + def test_context_predict(self): + # Enabling context should not change scores + model_output_context_disabled = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=2, gpus=self.gpus) + 
self.model.enable_context() + assert self.model.use_context == True + model_output_context_enabled = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=2, gpus=self.gpus) + np.testing.assert_almost_equal(np.array(model_output_context_disabled.scores), np.array(model_output_context_enabled.scores), decimal=5)