Commit 180bfe3
Merge branch 'pr-226'
ricardorei committed Nov 27, 2024
2 parents 332dfb0 + 130e541 commit 180bfe3
Showing 12 changed files with 122 additions and 33 deletions.
21 changes: 16 additions & 5 deletions comet/encoders/bert.py
@@ -32,22 +32,29 @@ class BERTEncoder(Encoder):
         pretrained_model (str): Pretrained model from hugging face.
         load_pretrained_weights (bool): If set to True loads the pretrained weights
             from Hugging Face
+        local_files_only (bool): Whether or not to only look at local files.
     """
 
     def __init__(
-        self, pretrained_model: str, load_pretrained_weights: bool = True
+        self,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> None:
         super().__init__()
         self.tokenizer = BertTokenizerFast.from_pretrained(
-            pretrained_model, use_fast=True
+            pretrained_model, use_fast=True, local_files_only=local_files_only
         )
         if load_pretrained_weights:
             self.model = BertModel.from_pretrained(
                 pretrained_model, add_pooling_layer=False
             )
         else:
             self.model = BertModel(
-                BertConfig.from_pretrained(pretrained_model), add_pooling_layer=False
+                BertConfig.from_pretrained(
+                    pretrained_model, local_files_only=local_files_only
+                ),
+                add_pooling_layer=False,
             )
         self.model.encoder.output_hidden_states = True

@@ -87,17 +94,21 @@ def uses_token_type_ids(self) -> bool:
 
     @classmethod
     def from_pretrained(
-        cls, pretrained_model: str, load_pretrained_weights: bool = True
+        cls,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> Encoder:
         """Function that loads a pretrained encoder from Hugging Face.
         Args:
             pretrained_model (str):Name of the pretrain model to be loaded.
             load_pretrained_weights (bool): If set to True loads the pretrained weights
                 from Hugging Face
+            local_files_only (bool): Whether or not to only look at local files.
         Returns:
             Encoder: XLMREncoder object.
         """
-        return BERTEncoder(pretrained_model, load_pretrained_weights)
+        return BERTEncoder(pretrained_model, load_pretrained_weights, local_files_only)
 
     def freeze_embeddings(self) -> None:
         """Frezees the embedding layer."""
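Every encoder gets the same treatment below: the new flag is threaded through each from_pretrained call on the tokenizer, the config, and the model. A minimal sketch of the new entry point, assuming bert-base-uncased is already in the local Hugging Face cache (with local_files_only=True, transformers errors out instead of reaching the network on a cache miss):

    # Sketch: exercising the new flag on BERTEncoder. Assumes the model was
    # downloaded in an earlier online session; otherwise loading fails fast.
    from comet.encoders.bert import BERTEncoder

    encoder = BERTEncoder.from_pretrained(
        "bert-base-uncased",
        load_pretrained_weights=True,
        local_files_only=True,
    )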
24 changes: 19 additions & 5 deletions comet/encoders/minilm.py
@@ -30,34 +30,48 @@ class MiniLMEncoder(XLMREncoder):
         pretrained_model (str): Pretrained model from hugging face.
         load_pretrained_weights (bool): If set to True loads the pretrained weights
             from Hugging Face
+        local_files_only (bool): Whether or not to only look at local files.
     """
 
     def __init__(
-        self, pretrained_model: str, load_pretrained_weights: bool = True
+        self,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> None:
         super(Encoder, self).__init__()
         self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(
-            "xlm-roberta-base", use_fast=True
+            "xlm-roberta-base", use_fast=True, local_files_only=local_files_only
         )
         if load_pretrained_weights:
             self.model = BertModel.from_pretrained(pretrained_model)
         else:
-            self.model = BertModel(BertConfig.from_pretrained(pretrained_model))
+            self.model = BertModel(
+                BertConfig.from_pretrained(
+                    pretrained_model, local_files_only=local_files_only
+                )
+            )
 
         self.model.encoder.output_hidden_states = True
 
     @classmethod
     def from_pretrained(
-        cls, pretrained_model: str, load_pretrained_weights: bool = True
+        cls,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> Encoder:
         """Function that loads a pretrained encoder from Hugging Face.
         Args:
             pretrained_model (str):Name of the pretrain model to be loaded.
             load_pretrained_weights (bool): If set to True loads the pretrained weights
                 from Hugging Face
+            local_files_only (bool): Whether or not to only look at local files.
         Returns:
             Encoder: XLMREncoder object.
         """
-        return MiniLMEncoder(pretrained_model, load_pretrained_weights)
+        return MiniLMEncoder(
+            pretrained_model, load_pretrained_weights, local_files_only
+        )
24 changes: 19 additions & 5 deletions comet/encoders/rembert.py
@@ -30,19 +30,27 @@ class RemBERTEncoder(XLMREncoder):
         pretrained_model (str): Pretrained model from hugging face.
         load_pretrained_weights (bool): If set to True loads the pretrained weights
             from Hugging Face
+        local_files_only (bool): Whether or not to only look at local files.
     """
 
     def __init__(
-        self, pretrained_model: str, load_pretrained_weights: bool = True
+        self,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> None:
         super(Encoder, self).__init__()
         self.tokenizer = RemBertTokenizerFast.from_pretrained(
-            pretrained_model, use_fast=True
+            pretrained_model, use_fast=True, local_files_only=local_files_only
         )
         if load_pretrained_weights:
             self.model = RemBertModel.from_pretrained(pretrained_model)
         else:
-            self.model = RemBertModel(RemBertConfig.from_pretrained(pretrained_model))
+            self.model = RemBertModel(
+                RemBertConfig.from_pretrained(
+                    pretrained_model, local_files_only=local_files_only
+                )
+            )
 
         self.model.encoder.output_hidden_states = True

@@ -57,16 +65,22 @@ def uses_token_type_ids(self):
 
     @classmethod
     def from_pretrained(
-        cls, pretrained_model: str, load_pretrained_weights: bool = True
+        cls,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> Encoder:
         """Function that loads a pretrained encoder from Hugging Face.
         Args:
             pretrained_model (str): Name of the pretrain model to be loaded.
             load_pretrained_weights (bool): If set to True loads the pretrained weights
                 from Hugging Face
+            local_files_only (bool): Whether or not to only look at local files.
         Returns:
             Encoder: XLMRXLEncoder object.
         """
-        return RemBERTEncoder(pretrained_model, load_pretrained_weights)
+        return RemBERTEncoder(
+            pretrained_model, load_pretrained_weights, local_files_only
+        )
22 changes: 17 additions & 5 deletions comet/encoders/xlmr.py
@@ -33,20 +33,28 @@ class XLMREncoder(BERTEncoder):
         pretrained_model (str): Pretrained model from hugging face.
         load_pretrained_weights (bool): If set to True loads the pretrained weights
             from Hugging Face
+        local_files_only (bool): Whether or not to only look at local files.
     """
 
     def __init__(
-        self, pretrained_model: str, load_pretrained_weights: bool = True
+        self,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> None:
         super(Encoder, self).__init__()
-        self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(pretrained_model)
+        self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(
+            pretrained_model, local_files_only=local_files_only
+        )
         if load_pretrained_weights:
             self.model = XLMRobertaModel.from_pretrained(
                 pretrained_model, add_pooling_layer=False
             )
         else:
             self.model = XLMRobertaModel(
-                XLMRobertaConfig.from_pretrained(pretrained_model),
+                XLMRobertaConfig.from_pretrained(
+                    pretrained_model, local_files_only=local_files_only
+                ),
                 add_pooling_layer=False,
             )
         self.model.encoder.output_hidden_states = True

@@ -63,19 +71,23 @@ def uses_token_type_ids(self):
 
     @classmethod
     def from_pretrained(
-        cls, pretrained_model: str, load_pretrained_weights: bool = True
+        cls,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> Encoder:
         """Function that loads a pretrained encoder from Hugging Face.
         Args:
             pretrained_model (str):Name of the pretrain model to be loaded.
             load_pretrained_weights (bool): If set to True loads the pretrained weights
                 from Hugging Face
+            local_files_only (bool): Whether or not to only look at local files.
         Returns:
             Encoder: XLMREncoder object.
         """
-        return XLMREncoder(pretrained_model, load_pretrained_weights)
+        return XLMREncoder(pretrained_model, load_pretrained_weights, local_files_only)
 
     def forward(
         self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
24 changes: 19 additions & 5 deletions comet/encoders/xlmr_xl.py
@@ -30,36 +30,50 @@ class XLMRXLEncoder(XLMREncoder):
         pretrained_model (str): Pretrained model from hugging face.
         load_pretrained_weights (bool): If set to True loads the pretrained weights
             from Hugging Face
+        local_files_only (bool): Whether or not to only look at local files.
     """
 
     def __init__(
-        self, pretrained_model: str, load_pretrained_weights: bool = True
+        self,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
    ) -> None:
         super(Encoder, self).__init__()
-        self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(pretrained_model)
+        self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(
+            pretrained_model, local_files_only=local_files_only
+        )
         if load_pretrained_weights:
             self.model = XLMRobertaXLModel.from_pretrained(
                 pretrained_model, add_pooling_layer=False
             )
         else:
             self.model = XLMRobertaXLModel(
-                XLMRobertaXLConfig.from_pretrained(pretrained_model),
+                XLMRobertaXLConfig.from_pretrained(
+                    pretrained_model, local_files_only=local_files_only
+                ),
                 add_pooling_layer=False,
             )
         self.model.encoder.output_hidden_states = True
 
     @classmethod
     def from_pretrained(
-        cls, pretrained_model: str, load_pretrained_weights: bool = True
+        cls,
+        pretrained_model: str,
+        load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> Encoder:
         """Function that loads a pretrained encoder from Hugging Face.
         Args:
             pretrained_model (str): Name of the pretrain model to be loaded.
             load_pretrained_weights (bool): If set to True loads the pretrained weights
                 from Hugging Face
+            local_files_only (bool): Whether or not to only look at local files.
         Returns:
             Encoder: XLMRXLEncoder object.
         """
-        return XLMRXLEncoder(pretrained_model, load_pretrained_weights)
+        return XLMRXLEncoder(
+            pretrained_model, load_pretrained_weights, local_files_only
+        )
10 changes: 9 additions & 1 deletion comet/models/__init__.py
@@ -59,7 +59,10 @@ def download_model(
 
 
 def load_from_checkpoint(
-    checkpoint_path: str, reload_hparams: bool = False, strict: bool = False
+    checkpoint_path: str,
+    reload_hparams: bool = False,
+    strict: bool = False,
+    local_files_only: bool = False,
 ) -> CometModel:
     """Loads models from a checkpoint path.

@@ -70,6 +73,10 @@ def load_from_checkpoint(
             to True all hparams will be reloaded.
         strict (bool): Strictly enforce that the keys in checkpoint_path match the
             keys returned by this module's state dict. Defaults to False
+        local_files_only (bool): Whether or not to only look at local files.
+            Make sure `pretrained_model` in checkpoint `hparams.yaml` is
+            downloaded beforehand. (e.g. `xlm-roberta-large` for
+            `Unbabel/wmt22-cometkiwi-da`)
     Return:
         COMET model.
     """

@@ -91,6 +98,7 @@ def load_from_checkpoint(
             hparams_file=hparams_file if reload_hparams else None,
             map_location=torch.device("cpu"),
             strict=strict,
+            local_files_only=local_files_only,
         )
         return model
     else:
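For end users the flag surfaces through load_from_checkpoint. A usage sketch, assuming the COMET checkpoint and its underlying pretrained_model (e.g. xlm-roberta-large for Unbabel/wmt22-cometkiwi-da, as the docstring notes) were fetched during an earlier online session; snapshot_download is the standard huggingface_hub helper for pre-filling the cache:

    # Run once while online:
    from huggingface_hub import snapshot_download
    from comet import download_model, load_from_checkpoint

    snapshot_download("xlm-roberta-large")                     # encoder weights + tokenizer
    model_path = download_model("Unbabel/wmt22-cometkiwi-da")  # COMET checkpoint

    # Later, fully offline:
    model = load_from_checkpoint(model_path, local_files_only=True)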
4 changes: 3 additions & 1 deletion comet/models/base.py
@@ -89,6 +89,7 @@ class CometModel(ptl.LightningModule, metaclass=abc.ABCMeta):
             Validation results are averaged across validation set. Defaults to None.
         load_pretrained_weights (Bool): If set to False it avoids loading the weights
             of the pretrained model (e.g. XLM-R) before it loads the COMET checkpoint
+        local_files_only (bool): Whether or not to only look at local files.
     """
 
     def __init__(

@@ -113,11 +114,12 @@ def __init__(
         validation_data: Optional[List[str]] = None,
         class_identifier: Optional[str] = None,
         load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> None:
         super().__init__()
         self.save_hyperparameters()
         self.encoder = str2encoder[self.hparams.encoder_model].from_pretrained(
-            self.hparams.pretrained_model, load_pretrained_weights
+            self.hparams.pretrained_model, load_pretrained_weights, local_files_only
         )
 
         self.epoch_nr = 0
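CometModel forwards the flag straight into the encoder lookup, so every subclass (regression metrics, UnifiedMetric, xCOMET) inherits offline loading for free. A rough sketch of that resolution step; the "XLM-RoBERTa" key is an assumption about how the str2encoder mapping names its encoder classes:

    # Approximately what CometModel.__init__ does with its hparams; the key
    # "XLM-RoBERTa" is an assumed entry in str2encoder, not confirmed here.
    from comet.encoders import str2encoder

    encoder = str2encoder["XLM-RoBERTa"].from_pretrained(
        "xlm-roberta-large",
        load_pretrained_weights=True,
        local_files_only=True,
    )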
3 changes: 3 additions & 0 deletions comet/models/multitask/unified_metric.py
@@ -89,6 +89,7 @@ class UnifiedMetric(CometModel):
             error_labels + weight for the default 'O' label. Defaults to None.
         load_pretrained_weights (Bool): If set to False it avoids loading the weights
             of the pretrained model (e.g. XLM-R) before it loads the COMET checkpoint
+        local_files_only (bool): Whether or not to only look at local files.
     """
 
     def __init__(

@@ -120,6 +121,7 @@ def __init__(
         error_labels: List[str] = ["minor", "major"],
         cross_entropy_weights: Optional[List[float]] = None,
         load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> None:
         super().__init__(
             nr_frozen_epochs=nr_frozen_epochs,

@@ -139,6 +141,7 @@ def __init__(
             validation_data=validation_data,
             class_identifier="unified_metric",
             load_pretrained_weights=load_pretrained_weights,
+            local_files_only=local_files_only,
         )
         self.save_hyperparameters()
         self.estimator = FeedForward(
2 changes: 2 additions & 0 deletions comet/models/multitask/xcomet_metric.py
@@ -67,6 +67,7 @@ def __init__(
         loss_lambda: float = 0.055,
         cross_entropy_weights: Optional[List[float]] = [0.08, 0.486, 0.505, 0.533],
         load_pretrained_weights: bool = True,
+        local_files_only: bool = False,
     ) -> None:
         super(UnifiedMetric, self).__init__(
             nr_frozen_epochs=nr_frozen_epochs,

@@ -86,6 +87,7 @@ def __init__(
             validation_data=validation_data,
             class_identifier="xcomet_metric",
             load_pretrained_weights=load_pretrained_weights,
+            local_files_only=local_files_only,
         )
         self.estimator = FeedForward(
             in_dim=self.encoder.output_units,