Commit
remove utils.py. Sample models created in example files
lessw2020 committed Nov 22, 2023
1 parent 4889e3b commit b215178
Showing 4 changed files with 56 additions and 64 deletions.
19 changes: 16 additions & 3 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,5 +1,7 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
@@ -9,7 +11,7 @@
ColwiseParallel,
RowwiseParallel,
)
-from utils import ToyModel


try:
from torch.distributed.tensor.parallel import (
@@ -37,6 +39,17 @@
"""


class ToyModel(nn.Module):
""" MLP based model """
def __init__(self):
super().__init__()
self.in_proj = nn.Linear(10, 32)
self.relu = nn.ReLU()
self.out_proj = nn.Linear(32, 5)

def forward(self, x):
return self.out_proj(self.relu(self.in_proj(x)))


"""
Main body of the demo of a basic version of sequence parallel by using
@@ -67,8 +80,8 @@ def rank_print(msg):
sp_model = parallelize_module(module = model,
device_mesh = device_mesh,
parallelize_plan = {
"net1": ColwiseParallel(input_layouts=Shard(0)),
"net2": RowwiseParallel(output_layouts=Shard(0)),
"in_proj": ColwiseParallel(input_layouts=Shard(0)),
"out_proj": RowwiseParallel(output_layouts=Shard(0)),
},
)

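For reference, the pieces above assemble into the following minimal sketch of the sequence-parallel setup with the inlined ToyModel. This is an illustrative sketch, not code from the commit: the "cuda" mesh, the WORLD_SIZE lookup, and the (20, 10) input shape are assumptions that mirror the surrounding example.

import os
import torch
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

# 1-D device mesh over all ranks (assumes a torchrun-style launch sets WORLD_SIZE).
world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh("cuda", (world_size,))

model = ToyModel().to("cuda")  # ToyModel as defined in the hunk above

# Shard activations on dim 0 at the module boundary: the input to in_proj and the
# output of out_proj stay sharded across the mesh, as in the example's plan.
sp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)

inp = torch.rand(20, 10, device="cuda")  # each rank feeds its own shard of dim 0
out = sp_model(inp)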
19 changes: 16 additions & 3 deletions distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,6 +1,8 @@

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
@@ -10,7 +12,8 @@
ColwiseParallel,
RowwiseParallel,
)
-from utils import ToyModel




"""
@@ -45,6 +48,16 @@
"""


class ToyModel(nn.Module):
""" MLP based model """
def __init__(self):
super(ToyModel, self).__init__()
self.in_proj = nn.Linear(10, 32)
self.relu = nn.ReLU()
self.out_proj = nn.Linear(32, 5)

def forward(self, x):
return self.out_proj(self.relu(self.in_proj(x)))

"""
Main body of the demo of a basic version of tensor parallel by using
@@ -88,8 +101,8 @@ def rank_print(msg):
tp_model = parallelize_module(module = tp_model,
device_mesh = device_mesh,
parallelize_plan = {
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
"in_proj": ColwiseParallel(),
"out_proj": RowwiseParallel(),
},
)
# Perform a num of iterations of forward/backward
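The last hunk above stops right before the forward/backward iterations it mentions. A minimal sketch of such a loop over the parallelized tp_model follows; the optimizer choice, learning rate, iteration count, and input shape are illustrative assumptions, not taken from the file.

import torch

# tp_model is the ToyModel parallelized with the "in_proj"/"out_proj" plan above.
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=3e-4)

for _ in range(10):  # illustrative iteration count
    # With plain ColwiseParallel/RowwiseParallel the module input defaults to a
    # replicated layout, so each rank feeds a full (batch, 10) tensor.
    inp = torch.rand(20, 10, device="cuda")
    out = tp_model(inp)
    out.sum().backward()
    optimizer.step()
    optimizer.zero_grad()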
26 changes: 24 additions & 2 deletions distributed/tensor_parallelism/two_d_parallel_example.py
@@ -1,6 +1,8 @@

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -16,7 +18,6 @@
from torch.distributed._tensor.device_mesh import init_device_mesh
import os

-from utils import MLP_swiglu


"""
@@ -49,8 +50,30 @@
More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""
def find_multiple(n: int, k: int) -> int:
""" function to find resizing multiple for SwiGLU MLP """
if n % k == 0:
return n
return n + k - (n % k)


class MLP_swiglu(nn.Module):
""" SwiGLU to showcase a Llama style MLP model """

def __init__(self, mlp_dim: int = 1024) -> None:
super().__init__()
hidden_dim = 4 * mlp_dim
scaled_hidden = int(2 * hidden_dim / 3)
rounded_hidden = find_multiple(scaled_hidden, 256)

self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.silu(self.in_proj(x)) * self.gate_proj(x)
x = self.out_proj(x)
return x

"""
Main body of the demo of a basic version of tensor parallel by using
@@ -107,7 +130,6 @@ def rank_print(msg):
base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to(_device)



# Custom parallelization plan for the swiglu MLP model
custom_tp_model = parallelize_module(module = base_model_swiglu,
device_mesh = tp_mesh,
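As a quick check of the sizing logic in find_multiple and MLP_swiglu above, the sketch below works through the default mlp_dim=1024 and runs a local, non-distributed forward pass; it assumes the two definitions from the hunk are in scope.

import torch

# Sizing for mlp_dim=1024, following the constructor above:
#   hidden_dim     = 4 * 1024                 = 4096
#   scaled_hidden  = int(2 * 4096 / 3)        = 2730
#   rounded_hidden = find_multiple(2730, 256) = 2816  (next multiple of 256)
assert find_multiple(2730, 256) == 2816

model = MLP_swiglu(mlp_dim=1024)
x = torch.rand(8, 1024)
print(model(x).shape)  # torch.Size([8, 1024]) -- input and output width match mlp_dim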
56 changes: 0 additions & 56 deletions distributed/tensor_parallelism/utils.py

This file was deleted.
