diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index 1017a7051e..adb134dd06 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,4 +1,3 @@
-
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -47,17 +46,19 @@
 More details can be seen in the slide:
 https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
 """
+
+
 def find_multiple(n: int, k: int) -> int:
-    """ function to find resizing multiple for SwiGLU MLP """
+    """function to find resizing multiple for SwiGLU MLP"""
     if n % k == 0:
         return n
     return n + k - (n % k)
 
 
 class MLP_swiglu(nn.Module):
-    """ SwiGLU to showcase a Llama style MLP model """
+    """SwiGLU to showcase a Llama style MLP model"""
 
-    def __init__(self, mlp_dim: int= 1024) -> None:
+    def __init__(self, mlp_dim: int = 1024) -> None:
         super().__init__()
         hidden_dim = 4 * mlp_dim
         scaled_hidden = int(2 * hidden_dim / 3)
@@ -72,13 +73,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.out_proj(x)
         return x
 
+
 """
 Main body of the demo of a basic version of tensor parallel by using
 PyTorch native APIs.
 """
 tp_size = 2
-logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
+)
 logger = logging.getLogger(__name__)
 
@@ -86,12 +90,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 _rank = int(os.environ["RANK"])
 _world_size = int(os.environ["WORLD_SIZE"])
 
+
 #
 def rank_log(msg):
     """helper function to print only on global rank 0"""
-    if _rank==0:
+    if _rank == 0:
         logger.info(f" {msg}")
 
+
 print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.")
 assert (
     _world_size % tp_size == 0
@@ -104,7 +110,7 @@ def rank_log(msg):
 # Create a device mesh with 2 dimensions.
 # First dim is the data parallel dimension
 # Second dim is the tensor parallel dimension.
-device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp","tp"))
+device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
 
 rank_log(f"Device Mesh created: {device_mesh=}")
 tp_mesh = device_mesh["tp"]
@@ -126,19 +132,20 @@ def rank_log(msg):
 
 
 # Custom parallelization plan for the swiglu MLP model
-custom_tp_model = parallelize_module(module = base_model_swiglu,
-                                     device_mesh = tp_mesh,
-                                     parallelize_plan = {
-                                         "in_proj": ColwiseParallel(),
-                                         "gate_proj": ColwiseParallel(),
-                                         "out_proj": RowwiseParallel(),
-                                     },
+custom_tp_model = parallelize_module(
+    module=base_model_swiglu,
+    device_mesh=tp_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(),
+        "gate_proj": ColwiseParallel(),
+        "out_proj": RowwiseParallel(),
+    },
 )
 
 rank_log(f"Model after parallelization {custom_tp_model=}\n")
 
 # Init FSDP using the dp device mesh
-sharded_model = FSDP(custom_tp_model, device_mesh = dp_mesh, use_orig_params=True)
+sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)
 
 # Create an optimizer for the parallelized and sharded model.
 lr = 3e-3
diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py
index 203d2afa05..069e981b74 100644
--- a/distributed/tensor_parallelism/sequence_parallel_example.py
+++ b/distributed/tensor_parallelism/sequence_parallel_example.py
@@ -13,7 +13,9 @@
 )
 
 
-logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
+)
 logger = logging.getLogger(__name__)
 
@@ -33,8 +35,10 @@
 in the end of the second linear layer.
 """
 
+
 class ToyModel(nn.Module):
-    """ MLP based model """
+    """MLP based model"""
+
     def __init__(self):
         super().__init__()
         self.in_proj = nn.Linear(10, 32)
@@ -51,15 +55,19 @@ def forward(self, x):
 """
 
 # create a device mesh based on the given world_size.
-device_mesh = init_device_mesh(device_type = "cuda",mesh_shape = (int(os.environ["WORLD_SIZE"]),))
+device_mesh = init_device_mesh(
+    device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
+)
 _rank = device_mesh.get_rank()
 
+
 def rank_log(msg):
     """helper function to log only on global rank 0"""
-    if _rank==0:
+    if _rank == 0:
         logger.info(f" {msg}")
 
+
 print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.")
 
 rank_log(f"Device Mesh created: {device_mesh=}")
@@ -68,12 +76,13 @@ def rank_log(msg):
 model = ToyModel().to("cuda")
 
 # Custom parallelization plan for the model
-sp_model = parallelize_module(module = model,
-                              device_mesh = device_mesh,
-                              parallelize_plan = {
-                                  "in_proj": ColwiseParallel(input_layouts=Shard(0)),
-                                  "out_proj": RowwiseParallel(output_layouts=Shard(0)),
-                              },
+sp_model = parallelize_module(
+    module=model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
+    },
 )
 
@@ -89,7 +98,7 @@ def rank_log(msg):
 
 for i in range(num_iters):
     # For SP, input can be different across all ranks.
-    inp = torch.rand(20, 10,device="cuda")
+    inp = torch.rand(20, 10, device="cuda")
     output = sp_model(inp)
     output.sum().backward()
     optimizer.step()
diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py
index 1b2bc073e0..cdf70f3799 100755
--- a/distributed/tensor_parallelism/tensor_parallel_example.py
+++ b/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -12,7 +12,9 @@
 import logging
 
 
-logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
+)
 logger = logging.getLogger(__name__)
 
@@ -49,7 +51,8 @@
 
 
 class ToyModel(nn.Module):
-    """ MLP based model """
+    """MLP based model"""
+
     def __init__(self):
         super(ToyModel, self).__init__()
         self.in_proj = nn.Linear(10, 32)
@@ -59,6 +62,7 @@ def __init__(self):
     def forward(self, x):
         return self.out_proj(self.relu(self.in_proj(x)))
 
+
 """
 Main body of the demo of a basic version of tensor parallel by using
 PyTorch native APIs.
@@ -67,16 +71,20 @@ def forward(self, x):
 
 # create a device mesh based on the given world_size.
 _world_size = int(os.environ["WORLD_SIZE"])
-device_mesh = init_device_mesh(device_type = "cuda",mesh_shape = (_world_size,))
+device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
 _rank = device_mesh.get_rank()
 
+
 def rank_log(msg):
     """helper function to print only on global rank 0"""
-    if _rank==0:
+    if _rank == 0:
         logger.info(f" {msg}")
 
+
 print(f"Starting PyTorch TP example on rank {_rank}.")
 
-assert _world_size % 2 == 0, f"TP examples require even number of GPUs, but got {_world_size} gpus"
+assert (
+    _world_size % 2 == 0
+), f"TP examples require even number of GPUs, but got {_world_size} gpus"
 
 rank_log(f"Device Mesh created: {device_mesh=}")
@@ -88,12 +96,13 @@ def rank_log(msg):
 optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
 
 # Custom parallelization plan for the model
-tp_model = parallelize_module(module = tp_model,
-                              device_mesh = device_mesh,
-                              parallelize_plan = {
-                                  "in_proj": ColwiseParallel(),
-                                  "out_proj": RowwiseParallel(),
-                              },
+tp_model = parallelize_module(
+    module=tp_model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(),
+        "out_proj": RowwiseParallel(),
+    },
 )
 # Perform a num of iterations of forward/backward
 # and optimizations for the sharded module.
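
Note: the final hunk ends at the "# Perform a num of iterations of forward/backward" comment. For reference, below is a minimal sketch of the training loop that comment introduces, not part of the patch above. It reuses tp_model, optimizer, and rank_log from the parallelized example; num_iters, the per-iteration seeding, and the (20, 10) input shape are illustrative assumptions rather than code taken from this diff.

# Illustrative sketch (assumption, not part of the patch): iterate forward,
# backward, and optimizer steps on the tensor-parallel model defined above.
num_iters = 10
rank_log("Tensor Parallel training starting...")

for i in range(num_iters):
    # Unlike the sequence-parallel example, plain TP expects every rank to see
    # the same (replicated) input, so seed identically on all ranks per step.
    torch.manual_seed(i)
    inp = torch.rand(20, 10, device="cuda")

    output = tp_model(inp)
    output.sum().backward()
    optimizer.step()
    rank_log(f"Tensor Parallel iteration {i} completed")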