[2D] Update 2d example to use get_local_rank (#1203)
update 2d example to use get_local_rank
wz337 authored Dec 8, 2023
1 parent c0b889d · commit 30b310a
Showing 1 changed file with 1 addition and 5 deletions.
distributed/tensor_parallelism/fsdp_tp_example.py (1 addition, 5 deletions)

@@ -114,15 +114,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 tp_mesh = device_mesh["tp"]
 dp_mesh = device_mesh["dp"]
 
-# To support identical inputs for TP groups, we need the dp process group
-dp_pg = device_mesh.get_dim_groups()[0]
-
 # For TP, input needs to be same across all TP ranks.
 # while for SP, input can be different across all ranks.
 # We will use dp_rank for setting the random seed
 # to mimic the behavior of the dataloader.
-dp_rank = dist.get_rank(dp_pg)
-
+dp_rank = dp_mesh.get_local_rank()
 
 # create model and move it to GPU with id rank
 _mlp_dim = 1024
