diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py deleted file mode 100644 index 9fa59b5c49..0000000000 --- a/examples/distributed_inference/llama3_model.py +++ /dev/null @@ -1,538 +0,0 @@ -# Taken and modified pytorch lightening -# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning - - -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import torch -import torch.nn.functional as F -from torch import nn -from torch.distributed._tensor import Replicate, Shard -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - PrepareModuleInput, - RowwiseParallel, - SequenceParallel, - parallelize_module, -) - - -@dataclass -class ModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 # defined later by tokenizer - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - rope_theta: float = 10000 - - max_batch_size: int = 32 - max_seq_len: int = 2048 - # If `True`, then each transformer block init uses its layer ID, and if - # `False`, each uses the total number of transformer blocks - depth_init: bool = True - device: str = "cuda" - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - """Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - return torch.polar(torch.ones_like(freqs), freqs) # complex64 - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), - and the first seqlen elements will be sliced, but dim must match x. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - - Returns: - torch.Tensor: Reshaped frequency tensor. - - """ - ndim = x.ndim - assert 0 <= 1 < ndim - seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Apply rotary embeddings to input tensors using the given frequency tensor. 
- - This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided - frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are - returned as real tensors. - - Args: - xq (torch.Tensor): Query tensor to apply rotary embeddings. - xk (torch.Tensor): Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - - """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class RMSNorm(nn.Module): - """Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - - """ - - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x: torch.Tensor): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: torch.Tensor): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) # type: ignore - - -class Attention(nn.Module): - """Multi-head attention module. - - Args: - model_args (ModelArgs): Model configuration arguments. - - Attributes: - n_kv_heads (int): Number of key and value heads. - n_heads (int): Number of query heads. - n_rep (int): Number of repetitions for local heads. - head_dim (int): Dimension size of each attention head. - wq (Linear): Linear transformation for queries. - wk (Linear): Linear transformation for keys. - wv (Linear): Linear transformation for values. - wo (Linear): Linear transformation for output. 
- - """ - - def __init__(self, model_args: ModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.n_kv_heads = ( - model_args.n_heads - if model_args.n_kv_heads is None - else model_args.n_kv_heads - ) - self.n_rep = self.n_heads // self.n_kv_heads - self.head_dim = model_args.dim // model_args.n_heads - - self.wq = nn.Linear( - model_args.dim, model_args.n_heads * self.head_dim, bias=False - ) - self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear( - model_args.n_heads * self.head_dim, model_args.dim, bias=False - ) - - def init_weights(self, init_std: float) -> None: - for linear in (self.wq, self.wk, self.wv): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> Any: - """Forward pass of the attention module. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed frequency tensor. - - Returns: - torch.Tensor: Output tensor after attention. - - """ - bs, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) - xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim) - xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - - # we use casual mask for training - output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) - output = output.view(bs, seqlen, -1) - return self.wo(output) - - -class FeedForward(nn.Module): - """FeedForward module. - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. - - Attributes: - w1 (Linear): Linear transformation for the first layer. - w2 (Linear): Linear transformation for the second layer. - w3 (Linear): Linear transformation for the third layer. 
- - """ - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x) -> Any: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float) -> None: - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - -class TransformerBlock(nn.Module): - """TransformerBlock Module. - - Args: - layer_id (int): Identifier for the layer. - model_args (ModelArgs): Model configuration arguments. - - Attributes: - n_heads (int): Number of attention heads. - dim (int): Dimension size of the model. - head_dim (int): Dimension size of each attention head. - attention (Attention): Attention module. - feed_forward (FeedForward): FeedForward module. - layer_id (int): Identifier for the layer. - attention_norm (RMSNorm): Layer normalization for attention output. - ffn_norm (RMSNorm): Layer normalization for feedforward output. - - """ - - def __init__(self, layer_id: int, model_args: ModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.dim = model_args.dim - self.attention = Attention(model_args) - self.feed_forward = FeedForward( - dim=model_args.dim, - hidden_dim=4 * model_args.dim, - multiple_of=model_args.multiple_of, - ffn_dim_multiplier=model_args.ffn_dim_multiplier, - ) - self.layer_id = layer_id - self.num_layers = model_args.n_layers - - self.attention_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) - self.ffn_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) - - if model_args.depth_init: - self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 - else: - self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - ): - """Perform a forward pass through the TransformerBlock. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. - - Returns: - torch.Tensor: Output tensor after applying attention and feedforward layers. - - """ - h = x + self.attention(self.attention_norm(x), freqs_cis) - return h + self.feed_forward(self.ffn_norm(h)) - - def init_weights(self): - for norm in (self.attention_norm, self.ffn_norm): - norm.reset_parameters() - self.attention.init_weights(self.weight_init_std) - self.feed_forward.init_weights(self.weight_init_std) - - -class ParallelTransformer(nn.Module): - """Transformer Module. - - Args: - model_args (ModelArgs): Model configuration arguments. - - Attributes: - model_args (ModelArgs): Model configuration arguments. - vocab_size (int): Vocabulary size. - n_layers (int): Number of layers in the model. - tok_embeddings (ParallelEmbedding): Token embeddings. - layers (torch.nn.ModuleList): List of Transformer blocks. - norm (RMSNorm): Layer normalization for the model output. - output (ColumnParallelLinear): Linear layer for final output. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 
- - """ - - def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh = None): - # Here we use distributed model initialization to avoid memory overflow - super().__init__() - self.model_args = model_args - self.vocab_size = model_args.vocab_size - self.n_layers = model_args.n_layers - - self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - self.tok_embeddings.to(model_args.device) - self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh) - - # TODO persistent should be set to false, since this buffer can be recomputed. - # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, - # compile or pipeline-tracer will not correctly handle non-persistent buffers, - # so we need to fix that. (2) if we initialize pipeline-parallel models from - # a seed checkpoint rather than calling init_weights, we need freqs_cis to be - # initialized by the checkpoint, or we need to add a separate initializer for - # just the non-persistent buffers that is called after loading checkpoints. - self.register_buffer( - "freqs_cis", - self._precompute_freqs_cis().to(model_args.device), - persistent=True, - ) - - self.layers = torch.nn.ModuleDict().to(model_args.device) - for layer_id in range(model_args.n_layers): - block = TransformerBlock(layer_id, model_args).to(model_args.device) - self.layers[str(layer_id)] = block - self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh) - - self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to( - model_args.device - ) - self.norm = self.parallel_norm(self.norm, tp_mesh) - self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to( - model_args.device - ) - self.output = self.parallel_output(self.output, tp_mesh) - self.init_weights() - - def parallel_transformer_block(self, transformer_block, tp_mesh): - if tp_mesh.size() <= 1: - return - plan = { - "attention": PrepareModuleInput( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), - ), - "attention.wq": ColwiseParallel(), - "attention.wk": ColwiseParallel(), - "attention.wv": ColwiseParallel(), - "attention.wo": RowwiseParallel(output_layouts=Shard(1)), - "attention_norm": SequenceParallel(), - "feed_forward": PrepareModuleInput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - "feed_forward.w1": ColwiseParallel(), - "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), - "feed_forward.w3": ColwiseParallel(), - "ffn_norm": SequenceParallel(), - } - - # Adjust attention module to use the local number of heads - attn_layer = transformer_block.attention - attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() - attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() - - # Apply the plan for the current transformer block - parallelize_module(transformer_block, tp_mesh, plan) - - def parallel_embeddings(self, embedding, tp_mesh): - plan = { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ) - } - return parallelize_module(embedding, tp_mesh, plan) - - def parallel_output(self, output, tp_mesh): - plan = { - "output": ColwiseParallel( - input_layouts=Shard(1), - ), - } - return parallelize_module(output, tp_mesh, plan) - - def parallel_norm(self, norm, tp_mesh): - plan = { - "norm": SequenceParallel(), - } - return parallelize_module(norm, tp_mesh, plan) - - def reset_parameters(self): - with torch.device(self.freqs_cis.device): - self.freqs_cis = self._precompute_freqs_cis() - - def 
init_weights(self): - """[Note: On ``init_weights`` vs. - - ``reset_parameters``] - Modules may define ``reset_parameters`` to initialize parameter values. - ``reset_parameters`` is meant to only initialize directly owned - parameters/buffers, not those of their child modules, and it can be - used to give the initial values for these tensors. - Separately, users may want custom initialization for their modules, - different from that in ``reset_parameters``. For this, we define - ``init_weights``. We only call it in the constructor of this - ``Transformer`` root module to avoid reinitializing tensors. - - """ - with torch.device(self.freqs_cis.device): - self.freqs_cis = self._precompute_freqs_cis() - nn.init.normal_(self.tok_embeddings.weight) - for layer in self.layers.values(): - layer.init_weights() - self.norm.reset_parameters() - final_out_std = self.model_args.dim**-0.5 - cutoff_factor = 3 - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) - - def _precompute_freqs_cis(self) -> torch.Tensor: - return precompute_freqs_cis( - self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # (use 2x max sequence length to be safe) - self.model_args.max_seq_len * 2, - self.model_args.rope_theta, - ) - - def forward(self, tokens: torch.Tensor): - """Perform a forward pass through the Transformer model. - - Args: - tokens (torch.Tensor): Input token indices. - - Returns: - torch.Tensor: Output logits after applying the Transformer model. - - """ - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - - for layer in self.layers.values(): - h = layer(h, self.freqs_cis) - - h = self.norm(h) if self.norm else h - return self.output(h).float() if self.output else h - - @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "Transformer": - """Initialize a Transformer model from a ModelArgs object. - - Args: - model_args (ModelArgs): Model configuration arguments. - - Returns: - Transformer: Transformer model. 
- - """ - return cls(model_args) diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py deleted file mode 100644 index 998c378be2..0000000000 --- a/examples/distributed_inference/tensor_parallel_llama3.py +++ /dev/null @@ -1,70 +0,0 @@ -# Taken and modified pytorch lightening -# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning -import logging -import os -import time - -import torch -from llama3_model import ModelArgs, ParallelTransformer -from tensor_parallel_initialize_dist import initialize_distributed_env -from torch.distributed._composable.fsdp import MixedPrecisionPolicy -from torch.distributed._composable.fsdp.fully_shard import fully_shard -from torch.distributed._tensor import Replicate, Shard -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, -) - -device_mesh, _world_size, _rank, logger = initialize_distributed_env( - "./tensor_parallel_llama3" -) -# Import should be after initialization of the TRT-LLM plugin .so path -import tensorrt_llm - -logger.info(f"Starting PyTorch TP example on rank {_rank}.") -assert ( - _world_size % 2 == 0 -), f"TP examples require even number of GPUs, but got {_world_size} gpus" - -model_args = ModelArgs( - vocab_size=32000, - dim=1024, - n_layers=4, - n_heads=8, - rope_theta=500000.0, - n_kv_heads=8, - device="cuda", -) - -with torch.no_grad(): - model = ParallelTransformer(model_args, device_mesh) - torch.manual_seed(0) - inp = torch.randint(32000, (8, 256), device="cuda") - python_result = model(inp) - torch_tensorrt.runtime.set_multi_device_safe_mode(True) - model = torch.compile( - model, - fullgraph=True, - backend="torch_tensorrt", - options={ - "truncate_long_and_double": True, - "enabled_precisions": {torch.float32, torch.float16}, - "use_python_runtime": True, - "workspace_size": 1 << 33, - "debug": False, - "use_aot_joint_export": False, - }, - dynamic=False, - ) - for i in range(15): - # seeding with dp_rank to ensure identical inputs for TP groups - torch.manual_seed(i) - start = time.time() - output = model(inp) - end = time.time() - if i == 0: - logger.info(f"Compilation time is {end-start}") - assert ( - python_result - output - ).std() < 0.01, "Compilation result is not correct." 
- elif _rank == 0: - logger.info(f"Inference time is {end-start}") diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 837648fdb4..d2e3c590c6 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -2,6 +2,7 @@ import tensorrt as trt import torch +import torch.distributed as dist import torch.nn as nn import torch_tensorrt from tensor_parallel_initialize_dist import initialize_distributed_env @@ -15,7 +16,6 @@ device_mesh, _world_size, _rank, logger = initialize_distributed_env( "./tensor_parallel_simple_example" ) -import tensorrt_llm """ This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py @@ -65,7 +65,6 @@ def forward(self, x): inp = torch.rand(20, 10, device="cuda") python_result = tp_model(inp) - backend = "torch_tensorrt" tp_model = torch.compile( tp_model, @@ -75,23 +74,28 @@ def forward(self, x): "enabled_precisions": {torch.float32, torch.float16}, "use_python_runtime": True, "min_block_size": 1, - "use_aot_joint_export": False, + "use_distributed_mode_trace": True, }, - dynamic=False, + dynamic=None, ) -for i in range(10): - # For TP, input needs to be same across all TP ranks. - # Setting the random seed is to mimic the behavior of dataloader. - torch.manual_seed(i) - inp = torch.rand(20, 10, device="cuda") - start = time.time() - output = tp_model(inp) - end = time.time() - if i == 0: - logger.info(f"Compilation time is {end-start}") - assert ( - python_result - output - ).std() < 0.01, "Compilation result is not correct." - elif _rank == 0: - logger.info(f"Inference time is {end-start}") +try: + for i in range(10): + # For TP, input needs to be same across all TP ranks. + # Setting the random seed is to mimic the behavior of dataloader. + torch.manual_seed(i) + inp = torch.rand(20, 10, device="cuda") + start = time.time() + output = tp_model(inp) + end = time.time() + if i == 0: + logger.info(f"Compilation time is {end-start}") + assert ( + python_result - output + ).std() < 0.01, "Compilation result is not correct." + elif _rank == 0: + logger.info(f"Inference time is {end-start}") +finally: + # This cleans up the distributed process group + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index ba404a4102..379a196e2e 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -46,9 +46,9 @@ IMMUTABLE_WEIGHTS = True ENABLE_WEIGHT_STREAMING = False ENABLE_CROSS_COMPILE_FOR_WINDOWS = False -USE_AOT_JOINT_EXPORT = True TILING_OPTIMIZATION_LEVEL = "none" L2_LIMIT_FOR_TILING = -1 +USE_DISTRIBUTED_MODE_TRACE = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index fc23ad76cf..d9b0e05e4d 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -35,7 +35,7 @@ TILING_OPTIMIZATION_LEVEL, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, - USE_AOT_JOINT_EXPORT, + USE_DISTRIBUTED_MODE_TRACE, USE_EXPLICIT_TYPING, USE_FAST_PARTITIONER, USE_FP32_ACC, @@ -94,9 +94,9 @@ class CompilationSettings: enable_weight_streaming (bool): Enable weight streaming. 
enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built. True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows - use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -137,9 +137,9 @@ class CompilationSettings: immutable_weights: bool = IMMUTABLE_WEIGHTS enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS - use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING + use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE _SETTINGS_TO_BE_ENGINE_INVARIANT = ( diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index ef04745562..f3e1b3e1fa 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -10,11 +10,11 @@ from torch._dynamo.backends.common import aot_autograd from torch._dynamo.utils import detect_fake_mode from torch._functorch.aot_autograd import aot_export_joint_simple +from torch.distributed.tensor import DTensor from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo._compiler import compile_module from torch_tensorrt.dynamo.lowering import ( get_decompositions, - modify_reshape_complex_nodes, post_lowering, remove_detach, remove_sym_nodes, @@ -52,25 +52,39 @@ def aot_torch_tensorrt_aten_backend( gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any ) -> torch.nn.Module: settings, engine_cache = parse_dynamo_kwargs(kwargs) - if settings.use_aot_joint_export: - return _pretraced_backend(gm, sample_inputs, settings, engine_cache) - logger.debug("Wrapping the backend with aot_autograd\n") - _pretraced_backend_autograd = functools.partial( - _pretraced_backend, settings=settings, engine_cache=engine_cache - ) - settings_aot_autograd = {} - settings_aot_autograd["decompostions"] = get_decompositions( - settings.enable_experimental_decompositions - ) - # This is added since detach lowering leads to alias nodes - # Error - View operation returned a tensor that is the same as the input base tensor - # torch nop_decompositions in torch/_decomp/decompositions.py - if aten.detach in settings_aot_autograd["decompositions"]: - del settings_aot_autograd["decompositions"][aten.detach] - return aot_autograd( - fw_compiler=_pretraced_backend_autograd, - decompositions=get_decompositions(settings.enable_experimental_decompositions), - )(gm, sample_inputs) + + if settings.use_distributed_mode_trace: + logger.debug( + "Wrapping the backend with aot_autograd for Distributed examples\n" + ) + _pretraced_backend_autograd = functools.partial( + 
_pretraced_backend, settings=settings, engine_cache=engine_cache + ) + settings_aot_autograd = {} + settings_aot_autograd["decompositions"] = get_decompositions( + settings.enable_experimental_decompositions + ) + # This is added since detach lowering leads to alias nodes + # Error - View operation returned a tensor that is the same as the input base tensor + # torch nop_decompositions in torch/_decomp/decompositions.py + # transpose key deleted since not desirable to lower it to permute + to_delete = { + key + for key in settings_aot_autograd["decompositions"] + if "detach" in key._name + } + for key in to_delete: + del settings_aot_autograd["decompositions"][key] + + return aot_autograd( + fw_compiler=_pretraced_backend_autograd, + decompositions=settings_aot_autograd["decompositions"], + )(gm, sample_inputs) + if any(isinstance(tensor, DTensor) for tensor in sample_inputs): + logger.warning( + "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple" + ) + return _pretraced_backend(gm, sample_inputs, settings, engine_cache) def _pretraced_backend( @@ -110,18 +124,8 @@ def _pretraced_backend( # Remove detach nodes remove_detach(gm, settings) - complexInputIndices = [] - for i, torch_input in enumerate(torch_inputs): - if torch_inputs[i].dtype == torch.complex64: - complexInputIndices.append(i) - torch_input_real = torch_inputs[i].real - torch_input_imaginary = torch_inputs[i].imag - torch_inputs[i] = torch.stack( - (torch_input_real, torch_input_imaginary), dim=-1 - ) - # Invoke AOTAutograd to translate operators to aten - if settings.use_aot_joint_export: + if not settings.use_distributed_mode_trace: gm = aot_export_joint_simple( gm, sample_inputs, @@ -137,12 +141,6 @@ def _pretraced_backend( logger.debug("Lowered Input graph:\n " + str(gm.graph)) - if complexInputIndices: - modify_reshape_complex_nodes(gm, complexInputIndices) - logger.debug( - "Input graph after modifying complex nodes:\n " + str(gm.graph) - ) - torchtrt_inputs = prepare_inputs( torch_inputs, disable_memory_format_check=True ) diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 17850fabce..79611c7552 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -3,6 +3,7 @@ import logging from typing import Dict, Sequence, Tuple, Union +import tensorrt as trt from torch.fx.node import Argument, Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl @@ -16,8 +17,6 @@ tensorrt_fused_nccl_reduce_scatter_op, ) -import tensorrt as trt - _LOGGER: logging.Logger = logging.getLogger(__name__) if load_tensorrt_llm(): @@ -30,7 +29,7 @@ def fused_nccl_gather( kwargs: Dict[str, Argument], name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.distributed.nccl_gather( + return impl.nccl_ops.nccl_gather( ctx, target, SourceIR.ATEN, @@ -46,7 +45,7 @@ def fused_nccl_reduce_scatter( kwargs: Dict[str, Argument], name: str, ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.distributed.nccl_reduce_scatter( + return impl.nccl_ops.nccl_reduce_scatter( ctx, target, SourceIR.ATEN, @@ -54,7 +53,6 @@ def fused_nccl_reduce_scatter( [args[0]], ) - breakpoint() else: _LOGGER.debug( "Did not load torch.distributed converters since TensorRT-LLM is not available" diff --git 
a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index 013268f803..c28c5bcc7d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -3,12 +3,11 @@ from typing import Optional, Tuple, Union import numpy as np +import tensorrt as trt from torch.fx.node import Argument, Target from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name -import tensorrt as trt - # class for AllReduce class AllReduceStrategy(IntEnum): @@ -94,7 +93,7 @@ def nccl_reduce_scatter( "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 ) - p_dtype = trt.float16 + p_dtype = trt.float32 pf_dtype = trt.PluginField( "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py index f709f177d6..02cb2ccd56 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -49,7 +49,6 @@ def fuse_distributed_ops( == torch.ops._c10d_functional.wait_tensor.default ): wait_tensor_node = list(node.users)[0] - fused_op = None if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default: with gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( @@ -58,11 +57,12 @@ def fuse_distributed_ops( args=(node.args[0], node.args[1], node.args[2]), ) else: - fused_node = gm.graph.create_node( - op="call_function", - target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function - args=(node.args[0], node.args[1], node.args[2], node.args[3]), - ) + with gm.graph.inserting_after(wait_tensor_node): + fused_node = gm.graph.create_node( + op="call_function", + target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function + args=(node.args[0], node.args[1], node.args[2], node.args[3]), + ) wait_tensor_node.replace_all_uses_with(fused_node) fused_node.meta.update(node.meta) diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py new file mode 100644 index 0000000000..e3062249fa --- /dev/null +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -0,0 +1,60 @@ +import logging +import os + +import numpy as np +import tensorrt as trt +import torch +import torch.distributed as dist +from torch.distributed._tensor.device_mesh import init_device_mesh + + +def set_environment_variables_pytest(): + os.environ["WORLD_SIZE"] = str(1) + os.environ["RANK"] = str(0) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(29500) + os.environ["USE_TRTLLM_PLUGINS"] = "1" + + +def initialize_logger(rank, logger_file_name): + logger = logging.getLogger() + logger.setLevel(logging.INFO) + fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") + fh.setLevel(logging.INFO) + logger.addHandler(fh) + return logger + + +# This is required for env initialization since we use mpirun +def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + + # Set up environment variable to run with mpirun + os.environ["RANK"] = 
str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so"
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use the NCCL backend
+    dist.init_process_group("nccl")
+
+    # Set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    logger = initialize_logger(rank, logger_file_name)
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank, logger
diff --git a/tests/py/dynamo/distributed/test_distributed_simple_example.py b/tests/py/dynamo/distributed/test_distributed_simple_example.py
new file mode 100644
index 0000000000..202469e2ea
--- /dev/null
+++ b/tests/py/dynamo/distributed/test_distributed_simple_example.py
@@ -0,0 +1,97 @@
+import time
+
+import tensorrt as trt
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch_tensorrt
+from distributed_utils import initialize_distributed_env
+from torch.distributed._tensor import Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_simple_example"
+)
+
+"""
+This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
+"""
+
+
+class ToyModel(nn.Module):
+    """MLP based model"""
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.in_proj = nn.Linear(10, 3200)
+        self.relu = nn.ReLU()
+        self.out_proj = nn.Linear(3200, 1600)
+        self.in_proj2 = nn.Linear(1600, 500)
+        self.out_proj2 = nn.Linear(500, 100)
+
+    def forward(self, x):
+        x = self.out_proj(self.relu(self.in_proj(x)))
+        x = self.relu(x)
+        x = self.out_proj2(self.relu(self.in_proj2(x)))
+        return x
+
+
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
+
+# Create model and move it to GPU. init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to("cuda")
+
+
+# Custom parallelization plan for the model
+tp_model = parallelize_module(
+    module=tp_model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
+        "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
+    },
+)
+torch.manual_seed(0)
+inp = torch.rand(20, 10, device="cuda")
+python_result = tp_model(inp)
+
+backend = "torch_tensorrt"
+tp_model = torch.compile(
+    tp_model,
+    backend=backend,
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+        "use_distributed_mode_trace": True,
+    },
+    dynamic=None,
+)
+
+try:
+    for i in range(10):
+        # For TP, input needs to be same across all TP ranks.
+        # Setting the random seed is to mimic the behavior of dataloader.
+ torch.manual_seed(i) + inp = torch.rand(20, 10, device="cuda") + start = time.time() + output = tp_model(inp) + end = time.time() + if i == 0: + logger.info(f"Compilation time is {end-start}") + assert ( + python_result - output + ).std() < 0.01, "Compilation result is not correct." + elif _rank == 0: + logger.info(f"Inference time is {end-start}") +finally: + # This cleans up the distributed process group + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py new file mode 100644 index 0000000000..89c94300b7 --- /dev/null +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -0,0 +1,76 @@ +import os + +import torch +import torch.distributed as dist +import torch.nn as nn +from distributed_utils import set_environment_variables_pytest +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +set_environment_variables_pytest() +dist.init_process_group(backend="nccl", init_method="env://") +group = dist.new_group(ranks=[0]) +group_name = group.group_name +world_size = 1 + +from conversion.harness import DispatchTestCase + + +class TestGatherNcclOpsConverter(DispatchTestCase): + @parameterized.expand([8]) + def test_nccl_ops(self, linear_layer_dim): + class DistributedGatherModel(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.fc = torch.nn.Linear(input_dim, input_dim) + + def forward(self, x): + x = self.fc(x) + gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( + x, world_size, group_name + ) + gathered_tensor = torch.ops._c10d_functional.wait_tensor( + gathered_tensor + ) + return gathered_tensor + + inputs = [torch.randn(1, linear_layer_dim).to("cuda")] + self.run_test( + DistributedGatherModel(linear_layer_dim).cuda(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + @parameterized.expand([8]) + def test_nccl_ops_scatter(self, linear_layer_dim): + + class DistributedReduceScatterModel(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.fc = torch.nn.Linear(input_dim, input_dim) + + def forward(self, x): + x = self.fc(x) + scatter_reduce_tensor = ( + torch.ops._c10d_functional.reduce_scatter_tensor( + x, "sum", world_size, group_name + ) + ) + scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor( + scatter_reduce_tensor + ) + return scatter_reduce_tensor + + inputs = [torch.zeros(1, linear_layer_dim).to("cuda")] + + self.run_test( + DistributedReduceScatterModel(linear_layer_dim).cuda(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/distributed/test_nccl_ops.sh b/tests/py/dynamo/distributed/test_nccl_ops.sh new file mode 100644 index 0000000000..dd54700048 --- /dev/null +++ b/tests/py/dynamo/distributed/test_nccl_ops.sh @@ -0,0 +1,137 @@ +#!/bin/bash + +check_command() { + command -v "$1" >/dev/null 2>&1 +} + +ensure_installed() { + local pkg="$1" + if ! check_command "$pkg"; then + echo "$pkg is not installed. Installing $pkg..." + + # Determine if sudo is needed + if check_command sudo; then + SUDO="sudo" + else + SUDO="" + fi + + # Detect OS and install accordingly + OS="$(uname -s)" + if [[ "$OS" == "Linux" ]]; then + if check_command apt-get; then + $SUDO apt-get update && $SUDO apt-get install -y "$pkg" + fi + else + echo "Unsupported OS: $OS. Please install $pkg manually." + exit 1 + fi + else + echo "$pkg is already installed." 
+    fi
+}
+
+ensure_mpi_installed() {
+    local pkg="$1"
+    if dpkg -l | grep -q "$pkg"; then
+        echo "$pkg is already installed."
+    else
+        echo "$pkg is not installed. Installing $pkg..."
+
+        # Determine if sudo is needed
+        if check_command sudo; then
+            SUDO="sudo"
+        else
+            SUDO=""
+        fi
+
+        # Detect OS and install accordingly
+        OS="$(uname -s)"
+        if [[ "$OS" == "Linux" ]]; then
+            if check_command apt-get; then
+                $SUDO apt-get update && $SUDO apt-get install -y "$pkg"
+            fi
+        else
+            echo "Unsupported OS: $OS. Please install $pkg manually."
+            exit 1
+        fi
+    fi
+}
+
+ensure_pytest_installed(){
+    if check_command pip; then
+        echo "pip is installed, installing pytest..."
+        pip install pytest
+    else
+        echo "pip is not installed. Please install pip first."
+        exit 1
+    fi
+}
+
+echo "Setting up the environment"
+
+OS="$(uname -s)"
+ARCH="$(uname -m)"
+
+
+# Get the TensorRT-LLM wheel name for this platform
+if [[ "$OS" == "Linux" && "$ARCH" == "x86_64" ]]; then
+    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl"
+elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64" ]]; then
+    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl"
+else
+    echo "Unsupported platform: OS=$OS ARCH=$ARCH"
+    exit 1
+fi
+
+# Download the selected file
+URL="https://pypi.nvidia.com/tensorrt-llm/$FILE"
+echo "Downloading $FILE from $URL..."
+
+# Install wget
+ensure_installed wget
+
+# Download the file
+filename=$(basename "$URL")
+if [ -f "$filename" ]; then
+    echo "File already exists: $filename"
+else
+    wget "$URL"
+fi
+echo "Download complete: $FILE"
+
+UNZIP_DIR="tensorrt_llm_unzip"
+if [[ ! -d "$UNZIP_DIR" ]]; then
+    echo "Creating directory: $UNZIP_DIR"
+    mkdir -p "$UNZIP_DIR"
+    echo "Extracting $FILE to $UNZIP_DIR ..."
+    # Install unzip
+    ensure_installed unzip
+    # Unzip the TensorRT-LLM package
+    unzip -q "$FILE" -d "$UNZIP_DIR"
+    echo "Unzip complete"
+fi
+
+
+export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+echo ${TRTLLM_PLUGINS_PATH}
+
+ensure_mpi_installed libmpich-dev
+ensure_mpi_installed libopenmpi-dev
+
+run_tests() {
+    cd ..
+    export PYTHONPATH=$(pwd)
+    echo "Running pytest on distributed/test_nccl_ops.py..."
+    pytest distributed/test_nccl_ops.py
+}
+
+run_mpi_tests(){
+    cd distributed
+    echo "Running test_distributed_simple_example with mpirun..."
+    mpirun -n 1 --allow-run-as-root python test_distributed_simple_example.py
+}
+
+ensure_pytest_installed
+run_tests
+run_mpi_tests
\ No newline at end of file
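
Usage sketch (illustrative; not applied by this diff): the snippet below condenses the new tensor-parallel example, showing how a DTensor-sharded module is compiled with the "torch_tensorrt" backend once the use_distributed_mode_trace option introduced here is enabled. It reuses initialize_distributed_env() from tests/py/dynamo/distributed/distributed_utils.py and a reduced version of ToyModel and its parallelization plan; the compile options are taken from test_distributed_simple_example.py.

    import torch
    import torch.nn as nn
    import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" dynamo backend)
    from distributed_utils import initialize_distributed_env
    from torch.distributed._tensor import Shard
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    device_mesh, _world_size, _rank, logger = initialize_distributed_env(
        "./tp_compile_sketch"
    )


    class TinyTPModel(nn.Module):
        # Reduced version of ToyModel from the example above.
        def __init__(self):
            super().__init__()
            self.in_proj = nn.Linear(10, 32)
            self.relu = nn.ReLU()
            self.out_proj = nn.Linear(32, 5)

        def forward(self, x):
            return self.out_proj(self.relu(self.in_proj(x)))


    tp_model = parallelize_module(
        module=TinyTPModel().to("cuda"),
        device_mesh=device_mesh,
        parallelize_plan={
            "in_proj": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj": RowwiseParallel(output_layouts=Shard(0)),
        },
    )

    tp_model = torch.compile(
        tp_model,
        backend="torch_tensorrt",
        options={
            "enabled_precisions": {torch.float32, torch.float16},
            "use_python_runtime": True,
            "min_block_size": 1,
            # New setting from this change: trace through aot_autograd
            # (instead of aot_export_joint_simple) so DTensor inputs are handled.
            "use_distributed_mode_trace": True,
        },
        dynamic=None,
    )

    # Inputs must be identical across all TP ranks.
    out = tp_model(torch.rand(20, 10, device="cuda"))

Such a script is launched the same way the shell script above drives the test, e.g. mpirun -n 1 --allow-run-as-root python <script>.py, with TRTLLM_PLUGINS_PATH pointing at libnvinfer_plugin_tensorrt_llm.so.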
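For the converter path, a minimal sketch (assuming a single CUDA device and the single-rank NCCL setup used by set_environment_variables_pytest()) of the eager pattern that fuse_distributed_ops rewrites into tensorrt_fused_nccl_gather_op and that impl.nccl_ops.nccl_gather then lowers through the TensorRT-LLM plugin; it mirrors DistributedGatherModel from test_nccl_ops.py.

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    # Single-rank NCCL group, matching set_environment_variables_pytest() above.
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl", init_method="env://")
    group = dist.new_group(ranks=[0])


    class GatherModel(nn.Module):
        """fc -> all_gather_into_tensor -> wait_tensor: the functional-collective
        pair that the lowering pass fuses into a single NCCL plugin node."""

        def __init__(self, dim: int):
            super().__init__()
            self.fc = nn.Linear(dim, dim)

        def forward(self, x):
            x = self.fc(x)
            gathered = torch.ops._c10d_functional.all_gather_into_tensor(
                x, dist.get_world_size(), group.group_name
            )
            return torch.ops._c10d_functional.wait_tensor(gathered)


    model = GatherModel(8).cuda()
    out = model(torch.randn(1, 8, device="cuda"))
    dist.destroy_process_group()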