[SFT] support for ring_attn in SFTTrainer #3262

Open · wants to merge 5 commits into base: main
1 change: 1 addition & 0 deletions setup.py
@@ -85,6 +85,7 @@
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
"ring_attn": ["ring-flash-attn"],
"scikit": ["scikit-learn"],
"bco": ["scikit-learn", "joblib"],
"test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
4 changes: 4 additions & 0 deletions trl/__init__.py
@@ -44,6 +44,7 @@
"is_rich_available",
"is_unsloth_available",
"is_vllm_available",
"is_ring_attn_available",
],
"models": [
"SUPPORTED_ARCHITECTURES",
@@ -52,6 +53,7 @@
"PreTrainedModelWrapper",
"create_reference_model",
"setup_chat_format",
"register_ring_attn",
],
"trainer": [
"AlignPropConfig",
@@ -145,6 +147,7 @@
is_llm_blender_available,
is_mergekit_available,
is_rich_available,
is_ring_attn_available,
is_unsloth_available,
is_vllm_available,
)
@@ -154,6 +157,7 @@
AutoModelForSeq2SeqLMWithValueHead,
PreTrainedModelWrapper,
create_reference_model,
register_ring_attn,
setup_chat_format,
)
from .scripts import ScriptArguments, TrlParser, init_zero_verbose
5 changes: 5 additions & 0 deletions trl/import_utils.py
@@ -34,6 +34,7 @@
_uvicorn_available = _is_package_available("uvicorn")
_vllm_available = _is_package_available("vllm")
_joblib_available = _is_package_available("joblib")
_ring_attn_available = _is_package_available("ring_flash_attn")


def is_deepspeed_available() -> bool:
@@ -84,6 +85,10 @@ def is_joblib_available() -> bool:
return _joblib_available


def is_ring_attn_available() -> bool:
return _ring_attn_available


class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
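For reference, a minimal sketch (not part of this diff) of how the new `is_ring_attn_available()` helper and the `ring_attn` extra added in `setup.py` are meant to be used together; the error-message wording is illustrative:

```python
from trl import is_ring_attn_available

if is_ring_attn_available():
    # Safe to import the optional dependency now that we know it is installed.
    from ring_flash_attn import substitute_hf_flash_attn  # noqa: F401
else:
    raise ImportError(
        "ring-flash-attn is not installed. Install it with `pip install trl[ring_attn]`."
    )
```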
2 changes: 2 additions & 0 deletions trl/models/__init__.py
@@ -27,6 +27,7 @@
"setup_chat_format",
"unwrap_model_for_generation",
],
"ring_attn": ["register_ring_attn"],
}

try:
@@ -45,6 +46,7 @@
if TYPE_CHECKING:
from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .ring_attn import register_ring_attn
from .utils import (
SUPPORTED_ARCHITECTURES,
prepare_deepspeed,
196 changes: 196 additions & 0 deletions trl/models/ring_attn.py
@@ -0,0 +1,196 @@
from __future__ import annotations

import logging

import torch
import torch.distributed as dist
import torch.nn.functional as F

from ..import_utils import is_ring_attn_available


logger = logging.getLogger(__name__)

RING_ATTN_GROUP = None


def get_ring_attn_group() -> dist.ProcessGroup | None:
"""
Getter for ring attention group on this rank.

Returns:
The process group for ring attention for this rank, or None if not initialized.
"""
return RING_ATTN_GROUP


def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
"""
Setter for ring attention group on this rank.

Args:
ring_attn_group: Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group


def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None = None):
"""
Create ring attention group and substitute flash attn with ring flash attn.

Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`. Defaults to 1.
"""
if get_ring_attn_group() is not None:
logger.info("Ring attention already registered, exiting early...")
return

if not dist.is_initialized():
logger.error("Distributed process group is not initialized. Cannot register ring attention.")
return

logger.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)

world_size = dist.get_world_size()
if sequence_parallel_degree > world_size:
raise ValueError(
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
if world_size % sequence_parallel_degree != 0:
raise ValueError(
f"sequence_parallel_degree ({sequence_parallel_degree}) must evenly divide world_size ({world_size})"
)

rank = dist.get_rank()
num_groups = world_size // sequence_parallel_degree
group_assignments = {}
local_group = None

# Create sequence parallel groups
for i in range(num_groups):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
# NCCL backend is assumed for GPU communication
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")

# Track which GPUs are in which groups for logging
for r in ring_attn_ranks:
group_assignments[r] = i

# Assign the group to the current rank if it belongs to this group
if rank in ring_attn_ranks:
local_group = group

if local_group is None:
# This should theoretically not happen if ranks cover 0 to world_size-1
# and checks above pass.
raise RuntimeError(f"Rank {rank} was not assigned to any ring attention group.")

set_ring_attn_group(local_group)

# Log the GPU group assignments from rank 0 for clarity
if rank == 0:
logger.info(f"Sequence parallel group assignments (GPU Rank -> Group Index): {group_assignments}")

if heads_k_stride is None:
heads_k_stride = 1

if is_ring_attn_available():
from ring_flash_attn import substitute_hf_flash_attn

substitute_hf_flash_attn(process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride)
logger.info("Successfully substituted HF flash attention with ring flash attention.")
else:
logger.error(
"Could not import `substitute_hf_flash_attn` from `ring_flash_attn`. "
"Please ensure the 'ring-flash-attn' package is installed."
)
# Reset the group if substitution fails to avoid inconsistent state
set_ring_attn_group(None)
raise ImportError("Could not import `substitute_hf_flash_attn` from `ring_flash_attn`.")


def get_cu_seqlens_from_pos_ids(
position_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)

device = position_ids.device
results = []
max_seq_lens = []

for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()

# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()

# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)])
# Append the padding length to the cumulative sequence lengths
if padding_length:
cu_seqlens = torch.cat([cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)])
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)

# Find the maximum value across all tensors
max_value = max(t.max() for t in results)

# Find the length of the longest tensor
max_length = max(t.size(0) for t in results)

# Pad each tensor to the same length and collect them in a list
padded_results = [F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results]

return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)


def update_ring_attn_params(batch: dict[str, torch.Tensor]):
"""
Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted `ring_flash_attn`.

Args:
batch: A dictionary with a batch of data. May or may not contain `position_ids`
data; if not, we compute it.
"""
from ring_flash_attn import update_ring_flash_attn_params

input_ids = batch["input_ids"]
position_ids = batch.get("position_ids")
if position_ids is None:
seq_len = input_ids.shape[1]
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0)

cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
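To make the `cu_seqlens` logic above concrete, here is a small self-contained example (not part of the diff) of calling `get_cu_seqlens_from_pos_ids` on a toy packed batch; the expected values in the comments follow from the function as written:

```python
import torch

from trl.models.ring_attn import get_cu_seqlens_from_pos_ids

position_ids = torch.tensor(
    [
        # Two packed sequences (lengths 4 and 3) followed by 3 right-padding tokens.
        [0, 1, 2, 3, 0, 1, 2, 0, 0, 0],
        # One unpadded sequence of length 10.
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    ]
)

cu_seqlens, max_seqlens = get_cu_seqlens_from_pos_ids(position_ids)
# cu_seqlens  -> [[0, 4, 7, 10], [0, 10, 10, 10]]  (int32; shorter rows are right-padded)
# max_seqlens -> [4, 10]
```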
29 changes: 22 additions & 7 deletions trl/trainer/sft_config.py
@@ -37,6 +37,12 @@ class SFTConfig(TrainingArguments):
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
argument of the [`SFTTrainer`] is provided as a string.
use_liger (`bool` or `None`, *optional*, defaults to `None`):
This parameter is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` instead.
sequence_parallel_degree (`int` or `None`, *optional*, defaults to `None`):
Degree of sequence parallelism for ring attention.
heads_k_stride (`int` or `None`, *optional*, defaults to `None`):
Sequence parallelism K head stride size for ring attention. If `None`, a stride of 1 is used when sequence parallelism is enabled.

> Parameters that control the data preprocessing

@@ -78,6 +84,22 @@
"the `SFTTrainer` is provided as a string."
},
)
use_liger: Optional[bool] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` "
"instead."
},
)
sequence_parallel_degree: Optional[int] = field(
default=None, metadata={"help": "Degree of sequence parallelism for ring attention."}
)
heads_k_stride: Optional[int] = field(
default=None,
metadata={
"help": "Sequence parallelism K head stride size for ring attention. Defaults to 1 if sequence parallelism is enabled."
},
)

# Parameters that control the data preprocessing
dataset_text_field: str = field(
@@ -169,13 +191,6 @@ class SFTConfig(TrainingArguments):
"help": "This parameter is deprecated and will be removed in version 0.20.0. Use `max_length` instead."
},
)
use_liger: Optional[bool] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` "
"instead."
},
)

def __post_init__(self):
super().__post_init__()
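For readers skimming the diff, a hedged end-to-end sketch of how the new config fields could be used; the `SFTTrainer` wiring itself is not shown in this excerpt, so the explicit `register_ring_attn` call below merely illustrates the intended flow:

```python
import torch.distributed as dist

from trl import SFTConfig, register_ring_attn

training_args = SFTConfig(
    output_dir="out",
    sequence_parallel_degree=2,  # shard each sequence across 2 GPUs
    heads_k_stride=1,            # forwarded to ring_flash_attn.substitute_hf_flash_attn
)

# Ring attention requires an initialized process group, e.g. when launched via torchrun.
if dist.is_initialized() and training_args.sequence_parallel_degree is not None:
    register_ring_attn(
        sequence_parallel_degree=training_args.sequence_parallel_degree,
        heads_k_stride=training_args.heads_k_stride,
    )
```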