diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py
deleted file mode 100644
index 9fa59b5c49..0000000000
--- a/examples/distributed_inference/llama3_model.py
+++ /dev/null
@@ -1,538 +0,0 @@
-# Taken and modified pytorch lightening
-# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
-
-
-from dataclasses import dataclass
-from typing import Any, Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.distributed._tensor import Replicate, Shard
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    SequenceParallel,
-    parallelize_module,
-)
-
-
-@dataclass
-class ModelArgs:
-    dim: int = 4096
-    n_layers: int = 32
-    n_heads: int = 32
-    n_kv_heads: Optional[int] = None
-    vocab_size: int = -1  # defined later by tokenizer
-    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
-    ffn_dim_multiplier: Optional[float] = None
-    norm_eps: float = 1e-5
-    rope_theta: float = 10000
-
-    max_batch_size: int = 32
-    max_seq_len: int = 2048
-    # If `True`, then each transformer block init uses its layer ID, and if
-    # `False`, each uses the total number of transformer blocks
-    depth_init: bool = True
-    device: str = "cuda"
-
-
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
-    """Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
-
-    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
-    and the end index 'end'. The 'theta' parameter scales the frequencies.
-    The returned tensor contains complex values in complex64 data type.
-
-    Args:
-        dim (int): Dimension of the frequency tensor.
-        end (int): End index for precomputing frequencies.
-        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
-
-    Returns:
-        torch.Tensor: Precomputed frequency tensor with complex exponentials.
-
-    """
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)
-    freqs = torch.outer(t, freqs).float()
-    return torch.polar(torch.ones_like(freqs), freqs)  # complex64
-
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-    """Reshape frequency tensor for broadcasting it with another tensor.
-
-    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
-    for the purpose of broadcasting the frequency tensor during element-wise operations.
-
-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
-
-    Args:
-        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
-        x (torch.Tensor): Target tensor for broadcasting compatibility.
-
-    Returns:
-        torch.Tensor: Reshaped frequency tensor.
-
-    """
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Apply rotary embeddings to input tensors using the given frequency tensor.
-
-    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
-    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
-    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
-    returned as real tensors.
-
-    Args:
-        xq (torch.Tensor): Query tensor to apply rotary embeddings.
-        xk (torch.Tensor): Key tensor to apply rotary embeddings.
-        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-
-    """
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
-    bs, slen, n_kv_heads, head_dim = x.shape
-    if n_rep == 1:
-        return x
-    return (
-        x[:, :, :, None, :]
-        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-    )
-
-
-class RMSNorm(nn.Module):
-    """Initialize the RMSNorm normalization layer.
-
-    Args:
-        dim (int): The dimension of the input tensor.
-        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-    Attributes:
-        eps (float): A small value added to the denominator for numerical stability.
-        weight (nn.Parameter): Learnable scaling parameter.
-
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
-class Attention(nn.Module):
-    """Multi-head attention module.
-
-    Args:
-        model_args (ModelArgs): Model configuration arguments.
-
-    Attributes:
-        n_kv_heads (int): Number of key and value heads.
-        n_heads (int): Number of query heads.
-        n_rep (int): Number of repetitions for local heads.
-        head_dim (int): Dimension size of each attention head.
-        wq (Linear): Linear transformation for queries.
-        wk (Linear): Linear transformation for keys.
-        wv (Linear): Linear transformation for values.
-        wo (Linear): Linear transformation for output.
-
-    """
-
-    def __init__(self, model_args: ModelArgs):
-        super().__init__()
-        self.n_heads = model_args.n_heads
-        self.n_kv_heads = (
-            model_args.n_heads
-            if model_args.n_kv_heads is None
-            else model_args.n_kv_heads
-        )
-        self.n_rep = self.n_heads // self.n_kv_heads
-        self.head_dim = model_args.dim // model_args.n_heads
-
-        self.wq = nn.Linear(
-            model_args.dim, model_args.n_heads * self.head_dim, bias=False
-        )
-        self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(
-            model_args.n_heads * self.head_dim, model_args.dim, bias=False
-        )
-
-    def init_weights(self, init_std: float) -> None:
-        for linear in (self.wq, self.wk, self.wv):
-            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
-        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> Any:
-        """Forward pass of the attention module.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor.
-
-        Returns:
-            torch.Tensor: Output tensor after attention.
-
-        """
-        bs, seqlen, _ = x.shape
-        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-
-        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
-        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
-        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
-
-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
-        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-
-        # we use casual mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
-        output = output.transpose(
-            1, 2
-        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        output = output.view(bs, seqlen, -1)
-        return self.wo(output)
-
-
-class FeedForward(nn.Module):
-    """FeedForward module.
-
-    Args:
-        dim (int): Input dimension.
-        hidden_dim (int): Hidden dimension of the feedforward layer.
-        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
-        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
-
-    Attributes:
-        w1 (Linear): Linear transformation for the first layer.
-        w2 (Linear): Linear transformation for the second layer.
-        w3 (Linear): Linear transformation for the third layer.
-
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        hidden_dim: int,
-        multiple_of: int,
-        ffn_dim_multiplier: Optional[float],
-    ):
-        super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        # custom dim factor multiplier
-        if ffn_dim_multiplier is not None:
-            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-
-    def forward(self, x) -> Any:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-    def init_weights(self, init_std: float) -> None:
-        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
-        for linear in (self.w2, self.w3):
-            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
-
-
-class TransformerBlock(nn.Module):
-    """TransformerBlock Module.
-
-    Args:
-        layer_id (int): Identifier for the layer.
-        model_args (ModelArgs): Model configuration arguments.
-
-    Attributes:
-        n_heads (int): Number of attention heads.
-        dim (int): Dimension size of the model.
-        head_dim (int): Dimension size of each attention head.
-        attention (Attention): Attention module.
-        feed_forward (FeedForward): FeedForward module.
-        layer_id (int): Identifier for the layer.
-        attention_norm (RMSNorm): Layer normalization for attention output.
-        ffn_norm (RMSNorm): Layer normalization for feedforward output.
-
-    """
-
-    def __init__(self, layer_id: int, model_args: ModelArgs):
-        super().__init__()
-        self.n_heads = model_args.n_heads
-        self.dim = model_args.dim
-        self.attention = Attention(model_args)
-        self.feed_forward = FeedForward(
-            dim=model_args.dim,
-            hidden_dim=4 * model_args.dim,
-            multiple_of=model_args.multiple_of,
-            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
-        )
-        self.layer_id = layer_id
-        self.num_layers = model_args.n_layers
-
-        self.attention_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
-        self.ffn_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
-
-        if model_args.depth_init:
-            self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
-        else:
-            self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ):
-        """Perform a forward pass through the TransformerBlock.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
-
-        Returns:
-            torch.Tensor: Output tensor after applying attention and feedforward layers.
-
-        """
-        h = x + self.attention(self.attention_norm(x), freqs_cis)
-        return h + self.feed_forward(self.ffn_norm(h))
-
-    def init_weights(self):
-        for norm in (self.attention_norm, self.ffn_norm):
-            norm.reset_parameters()
-        self.attention.init_weights(self.weight_init_std)
-        self.feed_forward.init_weights(self.weight_init_std)
-
-
-class ParallelTransformer(nn.Module):
-    """Transformer Module.
-
-    Args:
-        model_args (ModelArgs): Model configuration arguments.
-
-    Attributes:
-        model_args (ModelArgs): Model configuration arguments.
-        vocab_size (int): Vocabulary size.
-        n_layers (int): Number of layers in the model.
-        tok_embeddings (ParallelEmbedding): Token embeddings.
-        layers (torch.nn.ModuleList): List of Transformer blocks.
-        norm (RMSNorm): Layer normalization for the model output.
-        output (ColumnParallelLinear): Linear layer for final output.
-        freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
-
-    """
-
-    def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh = None):
-        # Here we use distributed model initialization to avoid memory overflow
-        super().__init__()
-        self.model_args = model_args
-        self.vocab_size = model_args.vocab_size
-        self.n_layers = model_args.n_layers
-
-        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
-        self.tok_embeddings.to(model_args.device)
-        self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh)
-
-        # TODO persistent should be set to false, since this buffer can be recomputed.
-        # however, we set it to true for 2 reasons.  (1) due to pytorch/pytorch#123411,
-        # compile or pipeline-tracer will not correctly handle non-persistent buffers,
-        # so we need to fix that.  (2) if we initialize pipeline-parallel models from
-        # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
-        # initialized by the checkpoint, or we need to add a separate initializer for
-        # just the non-persistent buffers that is called after loading checkpoints.
-        self.register_buffer(
-            "freqs_cis",
-            self._precompute_freqs_cis().to(model_args.device),
-            persistent=True,
-        )
-
-        self.layers = torch.nn.ModuleDict().to(model_args.device)
-        for layer_id in range(model_args.n_layers):
-            block = TransformerBlock(layer_id, model_args).to(model_args.device)
-            self.layers[str(layer_id)] = block
-            self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh)
-
-        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to(
-            model_args.device
-        )
-        self.norm = self.parallel_norm(self.norm, tp_mesh)
-        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to(
-            model_args.device
-        )
-        self.output = self.parallel_output(self.output, tp_mesh)
-        self.init_weights()
-
-    def parallel_transformer_block(self, transformer_block, tp_mesh):
-        if tp_mesh.size() <= 1:
-            return
-        plan = {
-            "attention": PrepareModuleInput(
-                input_layouts=(Shard(1), None),
-                desired_input_layouts=(Replicate(), None),
-            ),
-            "attention.wq": ColwiseParallel(),
-            "attention.wk": ColwiseParallel(),
-            "attention.wv": ColwiseParallel(),
-            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
-            "feed_forward": PrepareModuleInput(
-                input_layouts=(Shard(1),),
-                desired_input_layouts=(Replicate(),),
-            ),
-            "feed_forward.w1": ColwiseParallel(),
-            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
-            "feed_forward.w3": ColwiseParallel(),
-            "ffn_norm": SequenceParallel(),
-        }
-
-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
-        # Apply the plan for the current transformer block
-        parallelize_module(transformer_block, tp_mesh, plan)
-
-    def parallel_embeddings(self, embedding, tp_mesh):
-        plan = {
-            "tok_embeddings": RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            )
-        }
-        return parallelize_module(embedding, tp_mesh, plan)
-
-    def parallel_output(self, output, tp_mesh):
-        plan = {
-            "output": ColwiseParallel(
-                input_layouts=Shard(1),
-            ),
-        }
-        return parallelize_module(output, tp_mesh, plan)
-
-    def parallel_norm(self, norm, tp_mesh):
-        plan = {
-            "norm": SequenceParallel(),
-        }
-        return parallelize_module(norm, tp_mesh, plan)
-
-    def reset_parameters(self):
-        with torch.device(self.freqs_cis.device):
-            self.freqs_cis = self._precompute_freqs_cis()
-
-    def init_weights(self):
-        """[Note: On ``init_weights`` vs.
-
-        ``reset_parameters``]
-        Modules may define ``reset_parameters`` to initialize parameter values.
-        ``reset_parameters`` is meant to only initialize directly owned
-        parameters/buffers, not those of their child modules, and it can be
-        used to give the initial values for these tensors.
-        Separately, users may want custom initialization for their modules,
-        different from that in ``reset_parameters``. For this, we define
-        ``init_weights``. We only call it in the constructor of this
-        ``Transformer`` root module to avoid reinitializing tensors.
-
-        """
-        with torch.device(self.freqs_cis.device):
-            self.freqs_cis = self._precompute_freqs_cis()
-        nn.init.normal_(self.tok_embeddings.weight)
-        for layer in self.layers.values():
-            layer.init_weights()
-        self.norm.reset_parameters()
-        final_out_std = self.model_args.dim**-0.5
-        cutoff_factor = 3
-        nn.init.trunc_normal_(
-            self.output.weight,
-            mean=0.0,
-            std=final_out_std,
-            a=-cutoff_factor * final_out_std,
-            b=cutoff_factor * final_out_std,
-        )
-
-    def _precompute_freqs_cis(self) -> torch.Tensor:
-        return precompute_freqs_cis(
-            self.model_args.dim // self.model_args.n_heads,
-            # Need to compute until at least the max token limit for generation
-            # (use 2x max sequence length to be safe)
-            self.model_args.max_seq_len * 2,
-            self.model_args.rope_theta,
-        )
-
-    def forward(self, tokens: torch.Tensor):
-        """Perform a forward pass through the Transformer model.
-
-        Args:
-            tokens (torch.Tensor): Input token indices.
-
-        Returns:
-            torch.Tensor: Output logits after applying the Transformer model.
-
-        """
-        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
-        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
-
-        for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
-
-        h = self.norm(h) if self.norm else h
-        return self.output(h).float() if self.output else h
-
-    @classmethod
-    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
-        """Initialize a Transformer model from a ModelArgs object.
-
-        Args:
-            model_args (ModelArgs): Model configuration arguments.
-
-        Returns:
-            Transformer: Transformer model.
-
-        """
-        return cls(model_args)
diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py
deleted file mode 100644
index 998c378be2..0000000000
--- a/examples/distributed_inference/tensor_parallel_llama3.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Taken and modified pytorch lightening
-# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
-import logging
-import os
-import time
-
-import torch
-from llama3_model import ModelArgs, ParallelTransformer
-from tensor_parallel_initialize_dist import initialize_distributed_env
-from torch.distributed._composable.fsdp import MixedPrecisionPolicy
-from torch.distributed._composable.fsdp.fully_shard import fully_shard
-from torch.distributed._tensor import Replicate, Shard
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    checkpoint_wrapper,
-)
-
-device_mesh, _world_size, _rank, logger = initialize_distributed_env(
-    "./tensor_parallel_llama3"
-)
-# Import should be after initialization of the TRT-LLM plugin .so path
-import tensorrt_llm
-
-logger.info(f"Starting PyTorch TP example on rank {_rank}.")
-assert (
-    _world_size % 2 == 0
-), f"TP examples require even number of GPUs, but got {_world_size} gpus"
-
-model_args = ModelArgs(
-    vocab_size=32000,
-    dim=1024,
-    n_layers=4,
-    n_heads=8,
-    rope_theta=500000.0,
-    n_kv_heads=8,
-    device="cuda",
-)
-
-with torch.no_grad():
-    model = ParallelTransformer(model_args, device_mesh)
-    torch.manual_seed(0)
-    inp = torch.randint(32000, (8, 256), device="cuda")
-    python_result = model(inp)
-    torch_tensorrt.runtime.set_multi_device_safe_mode(True)
-    model = torch.compile(
-        model,
-        fullgraph=True,
-        backend="torch_tensorrt",
-        options={
-            "truncate_long_and_double": True,
-            "enabled_precisions": {torch.float32, torch.float16},
-            "use_python_runtime": True,
-            "workspace_size": 1 << 33,
-            "debug": False,
-            "use_aot_joint_export": False,
-        },
-        dynamic=False,
-    )
-    for i in range(15):
-        # seeding with dp_rank to ensure identical inputs for TP groups
-        torch.manual_seed(i)
-        start = time.time()
-        output = model(inp)
-        end = time.time()
-        if i == 0:
-            logger.info(f"Compilation time is {end-start}")
-            assert (
-                python_result - output
-            ).std() < 0.01, "Compilation result is not correct."
-        elif _rank == 0:
-            logger.info(f"Inference time is {end-start}")
diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py
index 837648fdb4..d2e3c590c6 100755
--- a/examples/distributed_inference/tensor_parallel_simple_example.py
+++ b/examples/distributed_inference/tensor_parallel_simple_example.py
@@ -2,6 +2,7 @@
 
 import tensorrt as trt
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch_tensorrt
 from tensor_parallel_initialize_dist import initialize_distributed_env
@@ -15,7 +16,6 @@
 device_mesh, _world_size, _rank, logger = initialize_distributed_env(
     "./tensor_parallel_simple_example"
 )
-import tensorrt_llm
 
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -65,7 +65,6 @@ def forward(self, x):
 inp = torch.rand(20, 10, device="cuda")
 python_result = tp_model(inp)
 
-
 backend = "torch_tensorrt"
 tp_model = torch.compile(
     tp_model,
@@ -75,23 +74,28 @@ def forward(self, x):
         "enabled_precisions": {torch.float32, torch.float16},
         "use_python_runtime": True,
         "min_block_size": 1,
-        "use_aot_joint_export": False,
+        "use_distributed_mode_trace": True,
     },
-    dynamic=False,
+    dynamic=None,
 )
 
-for i in range(10):
-    # For TP, input needs to be same across all TP ranks.
-    # Setting the random seed is to mimic the behavior of dataloader.
-    torch.manual_seed(i)
-    inp = torch.rand(20, 10, device="cuda")
-    start = time.time()
-    output = tp_model(inp)
-    end = time.time()
-    if i == 0:
-        logger.info(f"Compilation time is {end-start}")
-        assert (
-            python_result - output
-        ).std() < 0.01, "Compilation result is not correct."
-    elif _rank == 0:
-        logger.info(f"Inference time is {end-start}")
+try:
+    for i in range(10):
+        # For TP, input needs to be same across all TP ranks.
+        # Setting the random seed is to mimic the behavior of dataloader.
+        torch.manual_seed(i)
+        inp = torch.rand(20, 10, device="cuda")
+        start = time.time()
+        output = tp_model(inp)
+        end = time.time()
+        if i == 0:
+            logger.info(f"Compilation time is {end-start}")
+            assert (
+                python_result - output
+            ).std() < 0.01, "Compilation result is not correct."
+        elif _rank == 0:
+            logger.info(f"Inference time is {end-start}")
+finally:
+    # This cleans up the distributed process group
+    if dist.is_initialized():
+        dist.destroy_process_group()
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index ba404a4102..379a196e2e 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -46,9 +46,9 @@
 IMMUTABLE_WEIGHTS = True
 ENABLE_WEIGHT_STREAMING = False
 ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
-USE_AOT_JOINT_EXPORT = True
 TILING_OPTIMIZATION_LEVEL = "none"
 L2_LIMIT_FOR_TILING = -1
+USE_DISTRIBUTED_MODE_TRACE = False
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index fc23ad76cf..d9b0e05e4d 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -35,7 +35,7 @@
     TILING_OPTIMIZATION_LEVEL,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
-    USE_AOT_JOINT_EXPORT,
+    USE_DISTRIBUTED_MODE_TRACE,
     USE_EXPLICIT_TYPING,
     USE_FAST_PARTITIONER,
     USE_FP32_ACC,
@@ -94,9 +94,9 @@ class CompilationSettings:
         enable_weight_streaming (bool): Enable weight streaming.
         enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
             True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
-        use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Use aot_autograd to trace the graph. Enable this when DTensors or other distributed tensors are present in the model
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -137,9 +137,9 @@ class CompilationSettings:
     immutable_weights: bool = IMMUTABLE_WEIGHTS
     enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
     enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
-    use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
+    use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index ef04745562..f3e1b3e1fa 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -10,11 +10,11 @@
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch.distributed.tensor import DTensor
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
-    modify_reshape_complex_nodes,
     post_lowering,
     remove_detach,
     remove_sym_nodes,
@@ -52,25 +52,39 @@ def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings, engine_cache = parse_dynamo_kwargs(kwargs)
-    if settings.use_aot_joint_export:
-        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
-    logger.debug("Wrapping the backend with aot_autograd\n")
-    _pretraced_backend_autograd = functools.partial(
-        _pretraced_backend, settings=settings, engine_cache=engine_cache
-    )
-    settings_aot_autograd = {}
-    settings_aot_autograd["decompostions"] = get_decompositions(
-        settings.enable_experimental_decompositions
-    )
-    # This is added since detach lowering leads to alias nodes
-    # Error - View operation returned a tensor that is the same as the input base tensor
-    # torch nop_decompositions in torch/_decomp/decompositions.py
-    if aten.detach in settings_aot_autograd["decompositions"]:
-        del settings_aot_autograd["decompositions"][aten.detach]
-    return aot_autograd(
-        fw_compiler=_pretraced_backend_autograd,
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
-    )(gm, sample_inputs)
+
+    if settings.use_distributed_mode_trace:
+        logger.debug(
+            "Wrapping the backend with aot_autograd for Distributed examples\n"
+        )
+        _pretraced_backend_autograd = functools.partial(
+            _pretraced_backend, settings=settings, engine_cache=engine_cache
+        )
+        settings_aot_autograd = {}
+        settings_aot_autograd["decompositions"] = get_decompositions(
+            settings.enable_experimental_decompositions
+        )
+        # Remove the detach decompositions: lowering detach produces alias nodes,
+        # which fail with "View operation returned a tensor that is the same as the
+        # input base tensor" (see the nop_decompositions in torch/_decomp/decompositions.py)
+        to_delete = {
+            key
+            for key in settings_aot_autograd["decompositions"]
+            if "detach" in key._name
+        }
+        for key in to_delete:
+            del settings_aot_autograd["decompositions"][key]
+
+        return aot_autograd(
+            fw_compiler=_pretraced_backend_autograd,
+            decompositions=settings_aot_autograd["decompositions"],
+        )(gm, sample_inputs)
+    if any(isinstance(tensor, DTensor) for tensor in sample_inputs):
+        logger.warning(
+            "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple"
+        )
+    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
 
 
 def _pretraced_backend(
@@ -110,18 +124,8 @@ def _pretraced_backend(
             # Remove detach nodes
             remove_detach(gm, settings)
 
-            complexInputIndices = []
-            for i, torch_input in enumerate(torch_inputs):
-                if torch_inputs[i].dtype == torch.complex64:
-                    complexInputIndices.append(i)
-                    torch_input_real = torch_inputs[i].real
-                    torch_input_imaginary = torch_inputs[i].imag
-                    torch_inputs[i] = torch.stack(
-                        (torch_input_real, torch_input_imaginary), dim=-1
-                    )
-
             # Invoke AOTAutograd to translate operators to aten
-            if settings.use_aot_joint_export:
+            if not settings.use_distributed_mode_trace:
                 gm = aot_export_joint_simple(
                     gm,
                     sample_inputs,
@@ -137,12 +141,6 @@ def _pretraced_backend(
 
             logger.debug("Lowered Input graph:\n " + str(gm.graph))
 
-            if complexInputIndices:
-                modify_reshape_complex_nodes(gm, complexInputIndices)
-                logger.debug(
-                    "Input graph after modifying complex nodes:\n " + str(gm.graph)
-                )
-
             torchtrt_inputs = prepare_inputs(
                 torch_inputs, disable_memory_format_check=True
             )
diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
index 17850fabce..79611c7552 100644
--- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -3,6 +3,7 @@
 import logging
 from typing import Dict, Sequence, Tuple, Union
 
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
     tensorrt_fused_nccl_reduce_scatter_op,
 )
 
-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 if load_tensorrt_llm():
@@ -30,7 +29,7 @@ def fused_nccl_gather(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_gather(
+        return impl.nccl_ops.nccl_gather(
             ctx,
             target,
             SourceIR.ATEN,
@@ -46,7 +45,7 @@ def fused_nccl_reduce_scatter(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_reduce_scatter(
+        return impl.nccl_ops.nccl_reduce_scatter(
             ctx,
             target,
             SourceIR.ATEN,
@@ -54,7 +53,6 @@ def fused_nccl_reduce_scatter(
             [args[0]],
         )
 
-    breakpoint()
 else:
     _LOGGER.debug(
         "Did not load torch.distributed converters since TensorRT-LLM is not available"
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
index 013268f803..c28c5bcc7d 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
@@ -3,12 +3,11 @@
 from typing import Optional, Tuple, Union
 
 import numpy as np
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name
 
-import tensorrt as trt
-
 
 # class for AllReduce
 class AllReduceStrategy(IntEnum):
@@ -94,7 +93,7 @@ def nccl_reduce_scatter(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )
 
-    p_dtype = trt.float16
+    p_dtype = trt.float32
     pf_dtype = trt.PluginField(
         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
     )
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
index f709f177d6..02cb2ccd56 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
@@ -49,7 +49,6 @@ def fuse_distributed_ops(
             == torch.ops._c10d_functional.wait_tensor.default
         ):
             wait_tensor_node = list(node.users)[0]
-            fused_op = None
             if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
                 with gm.graph.inserting_after(wait_tensor_node):
                     fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@ def fuse_distributed_ops(
                         args=(node.args[0], node.args[1], node.args[2]),
                     )
             else:
-                fused_node = gm.graph.create_node(
-                    op="call_function",
-                    target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
-                    args=(node.args[0], node.args[1], node.args[2], node.args[3]),
-                )
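+                # Insert the fused op after the wait_tensor node, mirroring the all_gather branch above, so it lands at a valid point in the graph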
+                with gm.graph.inserting_after(wait_tensor_node):
+                    fused_node = gm.graph.create_node(
+                        op="call_function",
+                        target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
+                        args=(node.args[0], node.args[1], node.args[2], node.args[3]),
+                    )
 
             wait_tensor_node.replace_all_uses_with(fused_node)
             fused_node.meta.update(node.meta)
diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py
new file mode 100644
index 0000000000..e3062249fa
--- /dev/null
+++ b/tests/py/dynamo/distributed/distributed_utils.py
@@ -0,0 +1,60 @@
+import logging
+import os
+
+import numpy as np
+import tensorrt as trt
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor.device_mesh import init_device_mesh
+
+
+def set_environment_variables_pytest():
+    os.environ["WORLD_SIZE"] = str(1)
+    os.environ["RANK"] = str(0)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["USE_TRTLLM_PLUGINS"] = "1"
+
+
+def initialize_logger(rank, logger_file_name):
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+    return logger
+
+
+# This is required to initialize the distributed environment when launching with mpirun
+def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variable to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so"
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use nccl backend
+    dist.init_process_group("nccl")
+
+    # set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    logger = initialize_logger(rank, logger_file_name)
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank, logger
diff --git a/tests/py/dynamo/distributed/test_distributed_simple_example.py b/tests/py/dynamo/distributed/test_distributed_simple_example.py
new file mode 100644
index 0000000000..202469e2ea
--- /dev/null
+++ b/tests/py/dynamo/distributed/test_distributed_simple_example.py
@@ -0,0 +1,97 @@
+import time
+
+import tensorrt as trt
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch_tensorrt
+from distributed_utils import initialize_distributed_env
+from torch.distributed._tensor import Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_simple_example"
+)
+
+"""
+This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
+"""
+
+
+class ToyModel(nn.Module):
+    """MLP based model"""
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.in_proj = nn.Linear(10, 3200)
+        self.relu = nn.ReLU()
+        self.out_proj = nn.Linear(3200, 1600)
+        self.in_proj2 = nn.Linear(1600, 500)
+        self.out_proj2 = nn.Linear(500, 100)
+
+    def forward(self, x):
+        x = self.out_proj(self.relu(self.in_proj(x)))
+        x = self.relu(x)
+        x = self.out_proj2(self.relu(self.in_proj2(x)))
+        return x
+
+
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
+
+# Create the model and move it to GPU - init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to("cuda")
+
+
+# Custom parallelization plan for the model
+tp_model = parallelize_module(
+    module=tp_model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
+        "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
+    },
+)
+torch.manual_seed(0)
+inp = torch.rand(20, 10, device="cuda")
+python_result = tp_model(inp)
+
+backend = "torch_tensorrt"
+tp_model = torch.compile(
+    tp_model,
+    backend=backend,
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+        "use_distributed_mode_trace": True,
+    },
+    dynamic=None,
+)
+
+try:
+    for i in range(10):
+        # For TP, input needs to be same across all TP ranks.
+        # Setting the random seed is to mimic the behavior of dataloader.
+        torch.manual_seed(i)
+        inp = torch.rand(20, 10, device="cuda")
+        start = time.time()
+        output = tp_model(inp)
+        end = time.time()
+        if i == 0:
+            logger.info(f"Compilation time is {end-start}")
+            assert (
+                python_result - output
+            ).std() < 0.01, "Compilation result is not correct."
+        elif _rank == 0:
+            logger.info(f"Inference time is {end-start}")
+finally:
+    # This cleans up the distributed process group
+    if dist.is_initialized():
+        dist.destroy_process_group()
diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py
new file mode 100644
index 0000000000..89c94300b7
--- /dev/null
+++ b/tests/py/dynamo/distributed/test_nccl_ops.py
@@ -0,0 +1,76 @@
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from distributed_utils import set_environment_variables_pytest
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+set_environment_variables_pytest()
+dist.init_process_group(backend="nccl", init_method="env://")
+group = dist.new_group(ranks=[0])
+group_name = group.group_name
+world_size = 1
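+# A single-process NCCL group (rank 0, world_size 1) is sufficient for these converter tests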
+
+from conversion.harness import DispatchTestCase
+
+
+class TestGatherNcclOpsConverter(DispatchTestCase):
+    @parameterized.expand([8])
+    def test_nccl_ops(self, linear_layer_dim):
+        class DistributedGatherModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
+                    x, world_size, group_name
+                )
+                gathered_tensor = torch.ops._c10d_functional.wait_tensor(
+                    gathered_tensor
+                )
+                return gathered_tensor
+
+        inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
+        self.run_test(
+            DistributedGatherModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    @parameterized.expand([8])
+    def test_nccl_ops_scatter(self, linear_layer_dim):
+
+        class DistributedReduceScatterModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                scatter_reduce_tensor = (
+                    torch.ops._c10d_functional.reduce_scatter_tensor(
+                        x, "sum", world_size, group_name
+                    )
+                )
+                scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(
+                    scatter_reduce_tensor
+                )
+                return scatter_reduce_tensor
+
+        inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]
+
+        self.run_test(
+            DistributedReduceScatterModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/distributed/test_nccl_ops.sh b/tests/py/dynamo/distributed/test_nccl_ops.sh
new file mode 100644
index 0000000000..dd54700048
--- /dev/null
+++ b/tests/py/dynamo/distributed/test_nccl_ops.sh
@@ -0,0 +1,137 @@
+#!/bin/bash
+
+check_command() {
+    command -v "$1" >/dev/null 2>&1
+}
+
+ensure_installed() {
+    local pkg="$1"
+    if ! check_command "$pkg"; then
+        echo "$pkg is not installed. Installing $pkg..."
+
+        # Determine if sudo is needed
+        if check_command sudo; then
+            SUDO="sudo"
+        else
+            SUDO=""
+        fi
+
+        # Detect OS and install accordingly
+        OS="$(uname -s)"
+        if [[ "$OS" == "Linux" ]]; then
+            if check_command apt-get; then
+                $SUDO apt-get update && $SUDO apt-get install -y "$pkg"
+            fi
+        else
+            echo "Unsupported OS: $OS. Please install $pkg manually."
+            exit 1
+        fi
+    else
+        echo "$pkg is already installed."
+    fi
+}
+
+ensure_mpi_installed() {
+    local pkg="$1"
+    if dpkg -l | grep -q "$pkg"; then
+        echo "$pkg is already installed."
+    else
+        echo "$pkg is not installed. Installing $pkg..."
+
+        # Determine if sudo is needed
+        if check_command sudo; then
+            SUDO="sudo"
+        else
+            SUDO=""
+        fi
+
+        # Detect OS and install accordingly
+        OS="$(uname -s)"
+        if [[ "$OS" == "Linux" ]]; then
+            if check_command apt-get; then
+                $SUDO apt-get update && $SUDO apt-get install -y "$pkg"
+            fi
+        else
+            echo "Unsupported OS: $OS. Please install $pkg manually."
+            exit 1
+        fi
+    fi
+}
+
+ensure_pytest_installed(){
+    if check_command pip; then
+        echo "pip is installed, installing pytest..."
+        pip install pytest
+    else
+        echo "pip is not installed. Please install pip first."
+        exit 1
+    fi
+}
+
+echo "Setting up the environment"
+
+OS="$(uname -s)"
+ARCH="$(uname -m)"
+
+
+#getting the file name for TensorRT-LLM download
+if [[ "$OS" == "Linux" && "$ARCH" == "x86_64"]]; then
+    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl"
+elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64"]]; then
+    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl"
+else
+    echo "Unsupported platform: OS=$OS ARCH=$ARCH"
+    exit 1
+fi
+
+# Download the selected file
+URL="https://pypi.nvidia.com/tensorrt-llm/$FILE"
+echo "Downloading $FILE from $URL..."
+
+#Installing wget
+ensure_installed wget
+
+#Downloading the file
+filename=$(basename "$URL")
+if [ -f "$filename" ]; then
+    echo "File already exists: $filename"
+else
+    wget "$URL"
+fi
+echo "Download complete: $FILE"
+
+UNZIP_DIR="tensorrt_llm_unzip"
+if [[ ! -d "$UNZIP_DIR" ]]; then
+    echo "Creating directory: $UNZIP_DIR"
+    mkdir -p "$UNZIP_DIR"
+    echo "extracting $FILE to $UNZIP_DIR ..."
+    #Installing unzip
+    ensure_installed unzip
+    #unzip the TensorRT-LLM package
+    unzip -q "$FILE" -d "$UNZIP_DIR"
+    echo "Unzip complete"
+fi
+
+
+export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+echo ${TRTLLM_PLUGINS_PATH}
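+# TRTLLM_PLUGINS_PATH is exported so that the NCCL converter tests can locate the extracted plugin library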
+
+ensure_mpi_installed libmpich-dev
+ensure_mpi_installed libopenmpi-dev
+
+run_tests() {
+    cd ..
+    export PYTHONPATH=$(pwd)
+    echo "Running pytest on distributed/test_nccl_ops.py..."
+    pytest distributed/test_nccl_ops.py
+}
+
+run_mpi_tests(){
+    cd distributed
+    echo "Running test_distributed_simple_example with mpirun..."---
+    mpirun -n 1 --allow-run-as-root python test_distributed_simple_example.py
+}
+
+ensure_pytest_installed
+run_tests
+run_mpi_tests
\ No newline at end of file