diff --git a/training/run_distillation.py b/training/run_distillation.py index df592e3..84f96d6 100644 --- a/training/run_distillation.py +++ b/training/run_distillation.py @@ -133,7 +133,7 @@ class ModelArguments: "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n" "1. `eager` or `None`: default Transformers attention implementation.\n" "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n" - "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)." + "3. `flash_attention_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)." ) }, ) @@ -144,7 +144,7 @@ def __post_init__(self): f"Got `--attn_implementation={self.attn_implementation}`, which is an invalid attention type. Should be one of:\n" "1. `eager` or `None`: default Transformers attention implementation.\n" "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n" - "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)." + "3. `flash_attention_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)." )