Skip to content

Commit ef30530

Browse files
gregmacnamarafacebook-github-bot
authored andcommitted
Correcting Perf Estimates for TWRW (#2782)
Summary: Correcting estimates for TWRW Comms in sharding perf estimators. FWD needs adjustment for size of A2A process group. BWD was missing A2A. Differential Revision: D69993523
1 parent b418a44 commit ef30530

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

torchrec/distributed/planner/shard_estimators.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -715,21 +715,35 @@ def _get_twrw_sharding_perf(
715715

716716
# inter host comm
717717
if world_size > local_world_size:
718-
inter_host_fwd_fwd_output_write_size = (
719-
batch_outputs * world_size * emb_dim * fwd_a2a_comm_data_type_size
720-
)
721-
fwd_comms += (
722-
inter_host_fwd_fwd_output_write_size
723-
* (local_world_size / world_size)
724-
/ inter_host_bw
718+
inter_host_fwd_output_write_size = (
719+
batch_outputs
720+
* (
721+
world_size / local_world_size
722+
) # this is the size of the procress group.
723+
* emb_dim
724+
* fwd_a2a_comm_data_type_size
725725
)
726+
fwd_comms += inter_host_fwd_output_write_size / inter_host_bw
726727

727728
fwd_compute = (
728729
input_read_size + embedding_lookup_size + fwd_output_write_size
729730
) / device_bw
730731

732+
# intra host comm (i.e. all gather)
731733
bwd_comms = bwd_output_write_size / intra_host_bw
732734

735+
# inter host comm (i.e. all to all)
736+
if world_size > local_world_size:
737+
inter_host_bwd_output_write_size = (
738+
batch_outputs
739+
* (
740+
world_size / local_world_size
741+
) # this is the size of the procress group.
742+
* emb_dim
743+
* bwd_a2a_comm_data_type_size
744+
)
745+
bwd_comms += inter_host_bwd_output_write_size / inter_host_bw
746+
733747
bwd_grad_indice_weights_kernel = (
734748
fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0
735749
)

0 commit comments

Comments
 (0)