Commit

code formatting via ruff
lessw2020 committed Nov 22, 2023
1 parent 2f4a083 commit 742966b
Showing 3 changed files with 62 additions and 37 deletions.
37 changes: 22 additions & 15 deletions distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,4 +1,3 @@

import torch
import torch.distributed as dist
import torch.nn as nn
@@ -47,17 +46,19 @@
More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""


def find_multiple(n: int, k: int) -> int:
""" function to find resizing multiple for SwiGLU MLP """
"""function to find resizing multiple for SwiGLU MLP"""
if n % k == 0:
return n
return n + k - (n % k)


class MLP_swiglu(nn.Module):
""" SwiGLU to showcase a Llama style MLP model """
"""SwiGLU to showcase a Llama style MLP model"""

def __init__(self, mlp_dim: int= 1024) -> None:
def __init__(self, mlp_dim: int = 1024) -> None:
super().__init__()
hidden_dim = 4 * mlp_dim
scaled_hidden = int(2 * hidden_dim / 3)
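An aside on the sizing logic above (not part of the diff): a hedged worked example of how find_multiple rounds the SwiGLU hidden width, assuming the scaled width is rounded up to a multiple of 256 as in Llama-style MLPs; the actual multiple is set in a part of the file this commit does not touch.

    mlp_dim = 1024                              # the default from __init__ above
    hidden_dim = 4 * mlp_dim                    # 4096
    scaled_hidden = int(2 * hidden_dim / 3)     # 2730
    hidden = find_multiple(scaled_hidden, 256)  # 2730 % 256 == 170, so 2730 + 256 - 170 = 2816
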
@@ -72,26 +73,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.out_proj(x)
return x


"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
tp_size = 2

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
)
logger = logging.getLogger(__name__)


# understand world topology
_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])


#
def rank_log(msg):
"""helper function to print only on global rank 0"""
if _rank==0:
if _rank == 0:
logger.info(f" {msg}")


print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.")
assert (
_world_size % tp_size == 0
@@ -104,7 +110,7 @@ def rank_log(msg):
# Create a device mesh with 2 dimensions.
# First dim is the data parallel dimension
# Second dim is the tensor parallel dimension.
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp","tp"))
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

rank_log(f"Device Mesh created: {device_mesh=}")
tp_mesh = device_mesh["tp"]
@@ -126,19 +132,20 @@ def rank_log(msg):


# Custom parallelization plan for the swiglu MLP model
custom_tp_model = parallelize_module(module = base_model_swiglu,
device_mesh = tp_mesh,
parallelize_plan = {
"in_proj": ColwiseParallel(),
"gate_proj": ColwiseParallel(),
"out_proj": RowwiseParallel(),
},
custom_tp_model = parallelize_module(
module=base_model_swiglu,
device_mesh=tp_mesh,
parallelize_plan={
"in_proj": ColwiseParallel(),
"gate_proj": ColwiseParallel(),
"out_proj": RowwiseParallel(),
},
)

rank_log(f"Model after parallelization {custom_tp_model=}\n")

# Init FSDP using the dp device mesh
sharded_model = FSDP(custom_tp_model, device_mesh = dp_mesh, use_orig_params=True)
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
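The optimizer and training loop that follow are unchanged by this commit, so the diff for this file ends here. For orientation, a minimal sketch of how a 2D-parallel (FSDP + TP) model like sharded_model is typically driven; the iteration count, input shape, and seeding below are assumptions, not the file's actual contents (sharded_model, rank_log, lr, and torch come from the example above).

    optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True)

    for i in range(10):
        # Seed so every rank builds the same input; a real 2D run would vary the data
        # across the data-parallel dimension only, keeping it identical within a TP group.
        torch.manual_seed(i)
        inp = torch.rand(20, 1024, device="cuda")   # (batch, mlp_dim), assuming the default mlp_dim
        output = sharded_model(inp)
        output.sum().backward()
        optimizer.step()
        optimizer.zero_grad()
        rank_log(f"2D iteration {i} complete")
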
31 changes: 20 additions & 11 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -13,7 +13,9 @@
)


logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
)
logger = logging.getLogger(__name__)


@@ -33,8 +35,10 @@
in the end of the second linear layer.
"""


class ToyModel(nn.Module):
""" MLP based model """
"""MLP based model"""

def __init__(self):
super().__init__()
self.in_proj = nn.Linear(10, 32)
@@ -51,15 +55,19 @@ def forward(self, x):
"""

# create a device mesh based on the given world_size.
device_mesh = init_device_mesh(device_type = "cuda",mesh_shape = (int(os.environ["WORLD_SIZE"]),))
device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
)

_rank = device_mesh.get_rank()


def rank_log(msg):
"""helper function to log only on global rank 0"""
if _rank==0:
if _rank == 0:
logger.info(f" {msg}")


print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.")

rank_log(f"Device Mesh created: {device_mesh=}")
@@ -68,12 +76,13 @@ def rank_log(msg):
model = ToyModel().to("cuda")

# Custom parallelization plan for the model
sp_model = parallelize_module(module = model,
device_mesh = device_mesh,
parallelize_plan = {
"in_proj": ColwiseParallel(input_layouts=Shard(0)),
"out_proj": RowwiseParallel(output_layouts=Shard(0)),
},
sp_model = parallelize_module(
module=model,
device_mesh=device_mesh,
parallelize_plan={
"in_proj": ColwiseParallel(input_layouts=Shard(0)),
"out_proj": RowwiseParallel(output_layouts=Shard(0)),
},
)


@@ -89,7 +98,7 @@

for i in range(num_iters):
# For SP, input can be different across all ranks.
inp = torch.rand(20, 10,device="cuda")
inp = torch.rand(20, 10, device="cuda")
output = sp_model(inp)
output.sum().backward()
optimizer.step()
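A note on the Shard(0) layouts in the plan above (my reading, not part of the diff): with 2 GPUs each rank builds its own (20, 10) input, so the logical global batch is (40, 10) sharded on dim 0. ColwiseParallel(input_layouts=Shard(0)) declares that in_proj receives a dim-0 shard and all-gathers it before the column-wise matmul; RowwiseParallel(output_layouts=Shard(0)) reduce-scatters out_proj's result so each rank keeps only its dim-0 shard. A quick, hedged way to see this locally (sp_model and rank_log as defined above):

    inp = torch.rand(20, 10, device="cuda")    # per-rank shard of the global batch dim
    out = sp_model(inp)                        # all-gather before in_proj, reduce-scatter after out_proj
    rank_log(f"local output shape: {tuple(out.shape)}")   # first dim stays 20 on every rank
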
31 changes: 20 additions & 11 deletions distributed/tensor_parallelism/tensor_parallel_example.py
@@ -12,7 +12,9 @@

import logging

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
)
logger = logging.getLogger(__name__)


@@ -49,7 +51,8 @@


class ToyModel(nn.Module):
""" MLP based model """
"""MLP based model"""

def __init__(self):
super(ToyModel, self).__init__()
self.in_proj = nn.Linear(10, 32)
@@ -59,6 +62,7 @@ def __init__(self):
def forward(self, x):
return self.out_proj(self.relu(self.in_proj(x)))


"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
@@ -67,16 +71,20 @@ def forward(self, x):
# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type = "cuda",mesh_shape = (_world_size,))
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()


def rank_log(msg):
"""helper function to print only on global rank 0"""
if _rank==0:
if _rank == 0:
logger.info(f" {msg}")


print(f"Starting PyTorch TP example on rank {_rank}.")
assert _world_size % 2 == 0, f"TP examples require even number of GPUs, but got {_world_size} gpus"
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"

rank_log(f"Device Mesh created: {device_mesh=}")

@@ -88,12 +96,13 @@ def rank_log(msg):
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)

# Custom parallelization plan for the model
tp_model = parallelize_module(module = tp_model,
device_mesh = device_mesh,
parallelize_plan = {
"in_proj": ColwiseParallel(),
"out_proj": RowwiseParallel(),
},
tp_model = parallelize_module(
module=tp_model,
device_mesh=device_mesh,
parallelize_plan={
"in_proj": ColwiseParallel(),
"out_proj": RowwiseParallel(),
},
)
# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
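The iteration loop at the end of this file is unchanged by the commit and not shown. One thing worth checking after parallelize_module is how the plan actually split the weights; a hedged sketch, assuming the parameters come back as DTensors from the colwise/rowwise styles (tp_model and rank_log are the names used above):

    # in_proj's weight should be sharded on the output-features dim (colwise),
    # out_proj's weight on the input-features dim (rowwise).
    for name, param in tp_model.named_parameters():
        rank_log(
            f"{name}: global shape {tuple(param.shape)}, "
            f"local shape {tuple(param.to_local().shape)}, placements {param.placements}"
        )
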
