Enable context only for RegressionMetric and with average pooling
sweta20 committed May 4, 2024
1 parent f27b547 commit c9e99e3
Showing 4 changed files with 26 additions and 11 deletions.
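For orientation, here is a minimal usage sketch of the behaviour this commit introduces. It is assembled from the tests changed below (the model name, sample fields, and predict() call come from tests/unit/test_models_predict.py) and should be read as illustrative rather than as documented API:

```python
from comet import download_model, load_from_checkpoint

# Regression metric used by the new TestRegressionMetricPredict test case.
model = load_from_checkpoint(download_model("Unbabel/eamt22-cometinho-da"))

# After this commit, enable_context() only takes effect for regression metrics
# configured with average pooling (pool="avg"); otherwise the base class just
# logs a warning and use_context stays False.
model.enable_context()

samples = [{
    "context_src": None,  # preceding source context; None means no context
    "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano.",
    "context_mt": None,
    "mt": "The East African islands are located in the Indian Ocean.",
    "context_ref": None,
    "ref": "The East African Islands are in the Indian Ocean.",
}]
output = model.predict(samples, batch_size=2, gpus=0)
print(output.scores)
```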
3 changes: 2 additions & 1 deletion comet/models/base.py
@@ -144,6 +144,7 @@ def __init__(
self.mc_dropout = False # Flag used to control usage of MC Dropout
self.caching = False # Flag used to control Embedding Caching
self.use_context = False
self.pool = pool

# If not defined here, metrics will not live in the same device as our model.
self.init_metrics()
@@ -159,7 +160,7 @@ def set_mc_dropout(self, value: int):
def enable_context(self):
"""Function that extends COMET to use preceding context as described in
https://statmt.org/wmt22/pdf/2022.wmt-1.6.pdf."""
self.use_context = True
logger.warning("Context can only be enabled for RegressionMetric with Average Pooling.")

@abc.abstractmethod
def read_training_data(self) -> List[dict]:
4 changes: 4 additions & 0 deletions comet/models/regression/referenceless.py
@@ -126,6 +126,10 @@ def __init__(

def requires_references(self) -> bool:
return False

def enable_context(self):
if self.pool == "avg":
self.use_context = True

def prepare_sample(
self, sample: List[Dict[str, Union[str, float]]], stage: str = "train"
4 changes: 4 additions & 0 deletions comet/models/regression/regression_metric.py
@@ -215,6 +215,10 @@ def prepare_sample(

return model_inputs, targets

def enable_context(self):
if self.pool == "avg":
self.use_context = True

def estimate(
self,
src_sentemb: torch.Tensor,
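Both regression variants gate the feature on the same condition: the base class only logs the warning above, while these overrides set use_context only when the sentence embedding is built with average pooling. As a point of reference, the sketch below shows generic masked average pooling over token embeddings; it illustrates the technique the pool == "avg" check refers to, not necessarily COMET's exact implementation:

```python
import torch


def masked_average_pooling(
    token_embeddings: torch.Tensor,  # [batch, seq_len, hidden]
    attention_mask: torch.Tensor,    # [batch, seq_len]; 1 = real token, 0 = padding
) -> torch.Tensor:
    """Mean of the token embeddings over non-padded positions."""
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)   # [batch, hidden]
    counts = mask.sum(dim=1).clamp(min=1e-9)        # avoid division by zero
    return summed / counts


# Because the pooled embedding is a mean over all non-padded tokens, context tokens
# included in the segment contribute directly to the sentence representation, which
# is the pooling mode this commit restricts context to.
```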
26 changes: 16 additions & 10 deletions tests/unit/test_models_predict.py
@@ -26,7 +26,7 @@
]

CONTEXT_TEST_SAMPLES = [
{"lp": "it-en", "context_src": "", "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano, al largo della costa est dell'Africa.", "context_mt": "", "mt": "The East African islands are located in the Indian Ocean, off the east coast of Africa.", "context_ref": "", "ref": "The East African Islands are in the Indian Ocean off the eastern coast of Africa.", "annotations": [], "score": 1.1697086095809937},
{"lp": "it-en", "context_src": None, "src": "Le isole dell'Africa orientale sono situate nell'Oceano Indiano, al largo della costa est dell'Africa.", "context_mt": None, "mt": "The East African islands are located in the Indian Ocean, off the east coast of Africa.", "context_ref": None, "ref": "The East African Islands are in the Indian Ocean off the eastern coast of Africa.", "annotations": [], "score": 1.0},
]

class TestUnifiedMetricPredict(unittest.TestCase):
@@ -55,15 +55,8 @@ def test_predict(self):

def test_context_predict(self):
self.model.enable_context()
assert self.model.use_context == True
for sample in CONTEXT_TEST_SAMPLES:
for key in ["src", "mt", "ref"]:
sample[key] = " {} ".format(self.model.encoder.tokenizer.sep_token).join([sample[f"context_{key}"], sample[key]])

model_output = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=12, gpus=self.gpus)
assert self.model.use_context == False

np.testing.assert_almost_equal([sample['score'] for sample in CONTEXT_TEST_SAMPLES], np.array(model_output.scores), decimal=5)

def test_length_batching(self):
output_without_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=False)
output_with_length_batching = self.model.predict(TEST_SAMPLES, batch_size=1, gpus=self.gpus, length_batching=True)
@@ -93,4 +86,17 @@ def test_xcomet_predict(self):
model.score_weights = [0, 0, 0, 1]
model_output = model.predict(TEST_SAMPLES, batch_size=12, gpus=self.gpus)
self.assertListEqual(model_output.scores, model_output.metadata.mqm_scores)



class TestRegressionMetricPredict(unittest.TestCase):

model = load_from_checkpoint(download_model("Unbabel/eamt22-cometinho-da", saving_directory=DATA_PATH))
gpus = 1 if torch.cuda.device_count() > 0 else 0

def test_context_predict(self):
# Enabling context should not change scores
model_output_context_disabled = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=2, gpus=self.gpus)
self.model.enable_context()
assert self.model.use_context == True
model_output_context_enabled = self.model.predict(CONTEXT_TEST_SAMPLES, batch_size=2, gpus=self.gpus)
np.testing.assert_almost_equal(np.array(model_output_context_disabled.scores), np.array(model_output_context_enabled.scores), decimal=5)
