Remove ORT from transformers #66

Merged: 13 commits (Feb 19, 2025)
examples/pytorch/translation/run_translation.py (1 change: 0 additions & 1 deletion)
@@ -381,7 +381,6 @@ def main():
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
-        ort=True if training_args.ort else None,
         token=model_args.token,
         trust_remote_code=model_args.trust_remote_code,
     )
src/transformers/configuration_utils.py (5 changes: 0 additions & 5 deletions)
@@ -189,10 +189,6 @@ class PretrainedConfig(PushToHubMixin):
         loss_type (`str`, *optional*):
             The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will
             be automatically infered from the model architecture.
-
-        Onnxruntime specific parameters
-
-        - **ort** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use ORT.
     """

     model_type: str = ""
@@ -222,7 +218,6 @@ def __init__(self, **kwargs):
         self.torch_dtype = kwargs.pop("torch_dtype", None)  # Only used by PyTorch models
         self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
         self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False)  # Only used by TensorFlow models
-        self.ort = kwargs.pop("ort", False)
         self.pruned_heads = kwargs.pop("pruned_heads", {})
         self.tie_word_embeddings = kwargs.pop(
             "tie_word_embeddings", True
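For context on what this deletion changes: `PretrainedConfig.__init__` consumes each recognized option with `kwargs.pop(name, default)`, so removing the pop means `ort` no longer gets a default attribute on every config. A simplified toy sketch of that pattern (not the real class, which handles many more attributes plus serialization):

```python
class ToyConfig:
    """Toy illustration of the kwargs-pop pattern used by PretrainedConfig.__init__."""

    def __init__(self, **kwargs):
        # Recognized options are consumed with a default value.
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
        # With the `self.ort = kwargs.pop("ort", False)` line removed, "ort" is no
        # longer consumed here; it only exists if a caller explicitly passes it.
        for key, value in kwargs.items():
            setattr(self, key, value)


config = ToyConfig(use_bfloat16=True)
print(config.use_bfloat16)       # True
print(hasattr(config, "ort"))    # False: no default is set for it anymore
```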
src/transformers/models/roberta/modeling_roberta.py (3 changes: 1 addition & 2 deletions)
@@ -1576,7 +1576,6 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.ort = config.ort

         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
@@ -1645,7 +1644,7 @@ def forward(
             if len(end_positions.size()) > 1:
                 end_positions = end_positions.squeeze(-1)
             # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1) if not self.ort else 344
+            ignored_index = start_logits.size(1)
             start_positions = start_positions.clamp(0, ignored_index)
             end_positions = end_positions.clamp(0, ignored_index)

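The replacement line restores the upstream behavior: `ignored_index` is taken from the logits' sequence length rather than the hard-coded 344 used by the ORT path, out-of-range positions are clamped to it, and `CrossEntropyLoss(ignore_index=ignored_index)` then skips them. A minimal standalone sketch of that pattern (toy tensors, not the actual model code):

```python
import torch
from torch.nn import CrossEntropyLoss

seq_len = 8
start_logits = torch.randn(2, seq_len)
# One label falls outside the model inputs (e.g. the answer was truncated away).
start_positions = torch.tensor([3, 50])

ignored_index = start_logits.size(1)                        # 8, instead of the hard-coded 344
start_positions = start_positions.clamp(0, ignored_index)   # tensor([3, 8])

# Positions clamped to `ignored_index` contribute nothing to the loss.
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
loss = loss_fct(start_logits, start_positions)
print(loss)
```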
src/transformers/models/t5/modeling_t5.py (16 changes: 3 additions & 13 deletions)
@@ -274,23 +274,14 @@ def forward(self, hidden_states):
 class T5ClampedDropout(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.ort = config.ort
         self.dropout = nn.Dropout(config.dropout_rate)
-        self.dropout_rate = config.dropout_rate

     def forward(self, hidden_states):
         # clamp inf values to enable fp16 training
-        if self.ort:
-            # Remove data-based control flow for static graph
-            if hidden_states.dtype == torch.float16:
-                clamp_value = torch.where(torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max)
-                clamp_value = (1.0-self.dropout_rate)*clamp_value
-                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
-        else:
-            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
-                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
-                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

         hidden_states = self.dropout(hidden_states)
         return hidden_states
@@ -671,7 +662,6 @@ class T5Block(nn.Module):
     def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
         super().__init__()
         self.is_decoder = config.is_decoder
-        self.ort = config.ort
         self.layer = nn.ModuleList()
         self.layer.append(
             T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
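With the ORT-only branch removed, `T5ClampedDropout` keeps only the data-dependent path: in fp16, activations that contain inf are clamped to just under the fp16 maximum so subsequent ops do not produce NaNs. A small standalone sketch of the retained logic:

```python
import torch


def clamp_fp16_infs(hidden_states: torch.Tensor) -> torch.Tensor:
    # Same check the retained T5ClampedDropout path performs: only act on fp16
    # tensors that actually contain inf values (data-dependent control flow).
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states


x = torch.tensor([1.0, float("inf"), -float("inf")], dtype=torch.float16)
print(clamp_fp16_infs(x))  # infs replaced by roughly +/-64504
```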
src/transformers/trainer.py (7 changes: 1 addition & 6 deletions)
@@ -1916,12 +1916,7 @@ def _wrap_model(self, model, training=True, dataloader=None):

         # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
         if self.accelerator.unwrap_model(model) is not model:
-            if self.args.ort:
-                from torch_ort import ORTModule
-                if type(model) is not ORTModule:
-                    return model
-            else:
-                return model
+            return model

         # Mixed precision training with apex (torch < 1.6)
         if self.use_apex and training:
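With the `--ort` flag gone, the Trainer no longer imports `torch_ort` or checks for `ORTModule` when deciding whether a model is already wrapped. Users who still want ONNX Runtime training would wrap the model themselves before training, roughly as sketched below (assumes the `torch-ort` package is installed; the toy model and shapes are illustrative):

```python
import torch
from torch_ort import ORTModule  # pip install torch-ort

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 2),
)

# Wrap once, outside the Trainer; forward/backward then run through ONNX Runtime.
model = ORTModule(model)

x = torch.randn(4, 16)
loss = model(x).sum()
loss.backward()
```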
src/transformers/training_args.py (6 changes: 0 additions & 6 deletions)
@@ -606,8 +606,6 @@ class TrainingArguments:
             Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
             If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
             with hyperparameter tuning.
-        ort (:obj:`bool`, `optional`):
-            Use `ORTModule <https://github.com/microsoft/onnxruntime>`__.
         label_smoothing_factor (`float`, *optional*, defaults to 0.0):
             The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
             labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -1270,10 +1268,6 @@ class TrainingArguments:
             )
         },
     )
-    ort: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Enable Ort"},
-    )
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
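Because `TrainingArguments` is a dataclass, deleting the field removes `ort` from the accepted constructor arguments entirely; passing it now fails fast instead of toggling anything. A quick sketch of the expected behavior after this change (the `output_dir` value is illustrative):

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")  # fine: no ORT-related option exists
print(hasattr(args, "ort"))                 # False

try:
    TrainingArguments(output_dir="out", ort=True)  # previously accepted, now rejected
except TypeError as err:
    print(err)  # unexpected keyword argument 'ort'
```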