Support DDP with rank-dependent dataloader lengths #3415

Closed
JAEarly opened this issue Jun 19, 2024 · 2 comments · Fixed by #3416
Labels
enhancement: New (engineering) enhancements, such as features or API changes.

Comments

JAEarly (Contributor) commented Jun 19, 2024

🚀 Feature Request

Training and evaluation with DDP currently do not support rank-dependent dataloader lengths. Using 2 ranks as an example, it is not possible for rank 0 to have 100 batches while rank 1 has 120; the number of batches must currently be the same on every rank.

Motivation

I have a dataloader workflow that produces a variable number of samples depending on the node rank (the details of which are not important for this issue IMO). When using DDP, this means the number of batches on each rank is different. Due to the implementation of trainer._train_loop and trainer._eval_loop (and possibly other parts of the trainer), using rank-dependent dataloader lengths leads to NCCL timeouts originating from various dist.all_reduce calls.

Returning to the example above and focusing on trainer._eval_loop: once rank 0 has completed its iteration, it stops issuing collectives, so rank 1, which is still iterating, hangs on dist.all_reduce waiting for a matching call that never arrives. Two causes of this (there may be others) that I have identified in trainer._eval_loop are:

1. Last batch tracking:
# If using a distributed sampler, keep track of last_batch for metrics update
if dist_sampler is not None and drop_last == False and dataset_len is not None:
    batch_num_samples_tensor = self.state.device.tensor_to_device(torch.tensor(rank_num_samples))
    dist.all_reduce(batch_num_samples_tensor, reduce_operation='SUM')
2. _accumulate_time_across_ranks:
if isinstance(num_samples, float):
    sample_token_tensor = self.state.device.tensor_to_device(
        torch.tensor([num_samples, num_tokens], dtype=torch.float32),
    )
else:
    sample_token_tensor = self.state.device.tensor_to_device(
        torch.tensor([num_samples, num_tokens], dtype=torch.int),
    )
dist.all_reduce(sample_token_tensor, reduce_operation='SUM')

I have confirmed that avoiding these calls means trainer._eval_loop works with rank-dependent dataloader lengths.
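
To illustrate the mechanism outside of Composer, here is a minimal, self-contained sketch of the hang (the backend, loop counts, and port are illustrative only): the rank with more batches issues a per-batch all_reduce that the finished rank never joins, so it blocks until it times out or errors.

import datetime
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(
        "gloo",
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=10),
    )
    # Rank-dependent "dataloader length": rank 0 has 3 batches, rank 1 has 5.
    num_batches = 3 if rank == 0 else 5
    for i in range(num_batches):
        t = torch.tensor([1.0])
        # Per-batch sync, analogous to the all_reduce calls quoted above: once
        # rank 0 exits its loop, rank 1's fourth call has no matching collective
        # on rank 0 and hangs until it times out or errors.
        dist.all_reduce(t)
        print(f"rank {rank} synced batch {i}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)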

(Potential) Implementation

I believe the solution could be to track whether each rank has finished iterating and, once only one rank is still iterating, skip the dist.all_reduce calls; at that point there is nothing left to sync, so the collective is unnecessary.
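
A rough, Composer-independent sketch of that idea (the function name is hypothetical, and batches are assumed to be dicts with a "labels" key as in the MWE below):

import torch
import torch.distributed as dist


def count_samples_with_uneven_lengths(dataloader, device) -> int:
    """Accumulate per-batch sample counts across ranks without assuming equal lengths."""
    num_samples = 0
    it = iter(dataloader)
    exhausted = False
    while True:
        batch = None
        if not exhausted:
            try:
                batch = next(it)
            except StopIteration:
                exhausted = True
        # Lightweight control sync that every rank always joins: how many ranks
        # still have a batch this step?
        active = torch.tensor([0 if exhausted else 1], device=device)
        dist.all_reduce(active)
        if int(active.item()) == 0:
            break  # every rank has finished iterating
        local = torch.tensor([0 if exhausted else len(batch["labels"])], device=device)
        if int(active.item()) > 1:
            # At least two ranks still have data: all ranks (exhausted ones
            # contribute zero) join the aggregation, so collectives stay matched.
            dist.all_reduce(local)
        # With only one active rank the data all_reduce is skipped entirely:
        # there is nothing to sync, and the finished ranks no longer issue it.
        num_samples += int(local.item())
    # A full implementation would still need one final sync so that samples
    # processed while only one rank was active are reflected on every rank.
    return num_samples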

NOTE: the easy fix in my use case is to cap training and evaluation at a fixed number of batches (see the sketch below), but I would ideally like to avoid that (it is fine for now though).
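
For completeness, that workaround would look roughly like this for the MWE below, assuming the Trainer's train_subset_num_batches / eval_subset_num_batches arguments (the counts are illustrative and must not exceed the shortest per-rank dataloader length):

trainer = Trainer(
    model=model,
    device=get_best_accelerator(),
    optimizers=optimizer,
    max_duration="2ep",
    train_subset_num_batches=16,  # 512 samples / 2 ranks / batch size 16
    eval_subset_num_batches=8,    # capped at the shortest rank's eval length
    dist_timeout=30,
)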

Please let me know if this feature is of interest; I'd be happy to look into a solution in more detail. Please also see the MWE below, which highlights the issue.

Additional context

MWE for reproducibility, focusing on trainer._eval_loop (rank-dependent val dataloader lengths):

from typing import Any

import torch
import torch.nn as nn
from composer import Callback, Event, Logger, State, Trainer
from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS
from composer.models import ComposerModel
from composer.optim import DecoupledAdamW
from composer.utils import dist
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric
from torchmetrics.classification import BinaryAccuracy

# Flag to toggle rank-dependent dataloader lengths
DIFF_EVAL_DATALOADER_LENS = True


# Synthetic binary dataset
class BinaryDataset(Dataset[dict[str, Tensor]]):
    def __init__(self, x: Tensor, y: Tensor) -> None:
        self.x = x
        self.y = y

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        return {"features": self.x[idx], "labels": self.y[idx]}


# Single layer NN
class SimpleLinearModel(ComposerModel):

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, 2)
        self.loss_fn = nn.CrossEntropyLoss()
        self.acc_metric = BinaryAccuracy(sync_on_compute=True, dist_sync_on_step=False)

    def forward(self, x: dict[str, Tensor]) -> Tensor:
        out: Tensor = self.linear(x["features"])
        return out

    def loss(self, outputs: Tensor, batch: dict[str, Tensor], *args: Any, **kwargs: Any) -> Tensor:
        loss: Tensor = self.loss_fn(outputs, batch["labels"])
        return loss

    def get_metrics(self, is_train: bool = False) -> dict[str, Metric]:
        return {} if is_train else {"accuracy": self.acc_metric}

    def update_metric(self, batch: dict[str, Tensor], outputs: Tensor, metric: Metric) -> None:
        metric.update(outputs.argmax(dim=1), batch["labels"])


# Callback to print all key events to help debugging
class DebugCallback(Callback):

    events = [
        Event.INIT,
        Event.BEFORE_LOAD,
        Event.AFTER_LOAD,
        Event.FIT_START,
        Event.ITERATION_START,
        Event.EPOCH_START,
        Event.EVAL_BEFORE_ALL,
        Event.EVAL_START,
        Event.EVAL_BATCH_START,
        Event.EVAL_BEFORE_FORWARD,
        Event.EVAL_AFTER_FORWARD,
        Event.EVAL_BATCH_END,
        Event.EVAL_END,
        Event.EVAL_AFTER_ALL,
        Event.EPOCH_CHECKPOINT,
        Event.ITERATION_END,
        Event.ITERATION_CHECKPOINT,
        Event.FIT_END,
    ]

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        if event in self.events:
            print(f"Event: {event}")


def build_dataloader(num_samples: int, num_features: int) -> DataLoader[dict[str, Tensor]]:
    x = torch.rand((num_samples, num_features))
    y = torch.randint(low=0, high=2, size=(num_samples,))
    dataset = BinaryDataset(x, y)
    dist_sampler = dist.get_sampler(dataset)
    dataloader = DataLoader(dataset, batch_size=16, sampler=dist_sampler)
    return dataloader


def get_best_accelerator() -> Device:
    if torch.cuda.is_available():
        return DeviceGPU()
    if torch.backends.mps.is_available():
        return DeviceMPS()
    return DeviceCPU()


def run() -> None:
    # Default values for dataset creation
    num_features = 10
    num_train_samples = 512
    num_val_samples = 256
    # Change rank 1 dataloader size
    if DIFF_EVAL_DATALOADER_LENS and dist.get_local_rank() == 1:
        num_val_samples += 256
    # Construct everything
    print("Building dataloaders...")
    train_dataloader = build_dataloader(num_train_samples, num_features)
    val_dataloader = build_dataloader(num_val_samples, num_features)
    print(f" Train Dataloader Len: {len(train_dataloader)}")
    print(f"   Val Dataloader Len: {len(val_dataloader)}")
    print("Building model...")
    model = SimpleLinearModel(num_features)
    optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)
    print("Building trainer...")
    trainer = Trainer(
        model=model,
        device=get_best_accelerator(),
        optimizers=optimizer,
        max_duration="2ep",
        log_to_console=True,
        console_log_interval="4ba",
        progress_bar=False,
        dist_timeout=30,
        callbacks=[DebugCallback()],
    )
    # Actually fit (train and eval)
    print("Fitting...")
    trainer.fit(train_dataloader=train_dataloader, eval_dataloader=val_dataloader)


if __name__ == "__main__":
    run()
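
The script can be launched on two ranks, e.g. with the Composer launcher (the script name here is illustrative). With DIFF_EVAL_DATALOADER_LENS = True the eval loop should hang and eventually hit the 30 s dist_timeout; with it set to False the run should complete normally.

composer -n 2 mwe.py
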
JAEarly added the enhancement label on Jun 19, 2024
JAEarly (Contributor, Author) commented Jun 20, 2024

I've opened #3416 to solve this issue by terminating dataloader iteration early on ranks with larger dataloaders.

mvpatel2000 (Contributor) commented:

Agree we should address this -- will move discussion to PR (thanks for opening!)
