diff --git a/comet/models/base.py b/comet/models/base.py index 5dc9266..003a5c5 100644 --- a/comet/models/base.py +++ b/comet/models/base.py @@ -144,6 +144,7 @@ def __init__( self.mc_dropout = False # Flag used to control usage of MC Dropout self.caching = False # Flag used to control Embedding Caching self.use_context = False + self.pool = pool # If not defined here, metrics will not live in the same device as our model. self.init_metrics() @@ -159,7 +160,7 @@ def set_mc_dropout(self, value: int): def enable_context(self): """Function that extends COMET to use preceding context as described in https://statmt.org/wmt22/pdf/2022.wmt-1.6.pdf.""" - self.use_context = True + logger.warning("Context can only be enabled for RegressionMetric with Average Pooling.") @abc.abstractmethod def read_training_data(self) -> List[dict]: diff --git a/comet/models/regression/referenceless.py b/comet/models/regression/referenceless.py index 778f4d8..2a22951 100644 --- a/comet/models/regression/referenceless.py +++ b/comet/models/regression/referenceless.py @@ -126,6 +126,10 @@ def __init__( def requires_references(self) -> bool: return False + + def enable_context(self): + if self.pool == "avg": + self.use_context = True def prepare_sample( self, sample: List[Dict[str, Union[str, float]]], stage: str = "train" diff --git a/comet/models/regression/regression_metric.py b/comet/models/regression/regression_metric.py index 30a4dc7..8cf4903 100644 --- a/comet/models/regression/regression_metric.py +++ b/comet/models/regression/regression_metric.py @@ -215,6 +215,10 @@ def prepare_sample( return model_inputs, targets + def enable_context(self): + if self.pool == "avg": + self.use_context = True + def estimate( self, src_sentemb: torch.Tensor, diff --git a/tests/unit/test_models_predict.py b/tests/unit/test_models_predict.py index fc35f5f..a1d0d1b 100644 --- a/tests/unit/test_models_predict.py +++ b/tests/unit/test_models_predict.py @@ -26,7 +26,7 @@ ] CONTEXT_TEST_SAMPLES = [ - {"lp": "it-en", "context_src": "", "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano, al largo della costa est dell'Africa.", "context_mt": "", "mt": "The East African islands are located in the Indian Ocean, off the east coast of Africa.", "context_ref": "", "ref": "The East African Islands are in the Indian Ocean off the eastern coast of Africa.", "annotations": [], "score": 1.1697086095809937}, + {"lp": "it-en", "context_src": None, "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano, al largo della costa est dell'Africa.", "context_mt": None, "mt": "The East African islands are located in the Indian Ocean, off the east coast of Africa.", "context_ref": None, "ref": "The East African Islands are in the Indian Ocean off the eastern coast of Africa.", "annotations": [], "score": 1.0}, ] class TestUnifiedMetricPredict(unittest.TestCase): @@ -55,15 +55,8 @@ def test_predict(self): def test_context_predict(self): self.model.enable_context() - assert self.model.use_context == True - for sample in CONTEXT_TEST_SAMPLES: - for key in ["src", "mt", "ref"]: - sample[key] = " {} ".format(self.model.encoder.tokenizer.sep_token).join([sample[f"context_{key}"], sample[key]]) - - model_output = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=12, gpus=self.gpus) + assert self.model.use_context == False - np.testing.assert_almost_equal([sample['score'] for sample in CONTEXT_TEST_SAMPLES], np.array(model_output.scores), decimal=5) - def test_length_batching(self): output_without_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=False) output_with_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=True) @@ -93,4 +86,17 @@ def test_xcomet_predict(self): model.score_weights = [0, 0, 0, 1] model_output = model.predict(TEST_SAMPLES, batch_size=12, gpus=self.gpus) self.assertListEqual(model_output.scores, model_output.metadata.mqm_scores) - \ No newline at end of file + + +class TestRegressionMetricPredict(unittest.TestCase): + + model = load_from_checkpoint(download_model("Unbabel/eamt22-cometinho-da", saving_directory=DATA_PATH)) + gpus = 1 if torch.cuda.device_count() > 0 else 0 + + def test_context_predict(self): + # Enabling context should not change scores" + model_output_context_disabled = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=2, gpus=self.gpus) + self.model.enable_context() + assert self.model.use_context == True + model_output_context_enabled = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=2, gpus=self.gpus) + np.testing.assert_almost_equal(np.array(model_output_context_disabled.scores), np.array(model_output_context_enabled.scores), decimal=5)