Replies: 1 comment 1 reply
Hi @civat, reduce-scatter performs the reduction locally at every hop, which is what keeps the communication volume down. Let's focus on rank 0 in an 8-GPU setting (ranks 0-7) and consider the 8 GPUs arranged in a ring. For the reduction of the first 1/8 of the gradient, GPU 0 first sends its local first 1/8, (GPU0, first 1/8), to GPU 1. GPU 1 adds its own local first 1/8, (GPU1, first 1/8), to the received (GPU0, first 1/8), obtaining the partial result (GPU0+GPU1, first 1/8), and passes it on to GPU 2, and so forth, until GPU 7 finishes the local reduction and holds (GPU0+GPU1+...+GPU7, first 1/8). GPU 7 then sends it back to GPU 0, i.e. GPU 0 receives the fully reduced (GPU0+GPU1+...+GPU7, first 1/8). So for its own chunk, GPU 0 sends 1/8 of the elements once and receives the fully reduced 1/8 once; it never has to collect seven raw copies from the other ranks, because the reduction happens locally at every hop. Every GPU rank does exactly the same for its own chunk (e.g., GPU 1 initiates and finally receives the second 1/8, etc.), and the chunk rings are pipelined so that in each step every GPU also forwards one partially reduced 1/8 chunk belonging to another rank. Counting all steps, each GPU sends roughly Ψ elements and receives roughly Ψ elements over the whole reduce-scatter, which is the Ψ the paper quotes. The 14Ψ/8 in your estimate comes from adding the send volume and the receive volume together; sends and receives overlap on a full-duplex link (and every element one rank sends is an element another rank receives), so the communication volume is counted once, giving ≈ Ψ rather than 14Ψ/8.
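To make the counting concrete, here is a minimal pure-Python sketch of the ring walkthrough above. It is not DeepSpeed or NCCL code; the ranks, chunks, and counters are just illustrative names, and each "chunk" is a single integer standing in for Ψ/8 elements.

```python
# Minimal sketch of the ring reduce-scatter walkthrough above (illustrative
# names only, not DeepSpeed/NCCL code). Each "chunk" is one integer standing
# in for Ψ/8 elements; grads[rank][c] is a rank's local value for chunk c.
N = 8
grads = [[rank * 100 + c for c in range(N)] for rank in range(N)]

sent = [0] * N        # chunks sent by each rank
received = [0] * N    # chunks received by each rank
result = [None] * N   # result[c] = fully reduced chunk c, ending up on rank c

for c in range(N):                      # one ring pass per chunk, as in the text
    acc = grads[c][c]                   # rank c starts from its own local chunk c
    for hop in range(1, N):             # pass the partial sum around the ring
        src, dst = (c + hop - 1) % N, (c + hop) % N
        sent[src] += 1
        received[dst] += 1
        acc += grads[dst][c]            # local reduction on arrival
    # last hop: the final rank sends the fully reduced chunk back to rank c
    sent[(c + N - 1) % N] += 1
    received[c] += 1
    result[c] = acc

assert all(result[c] == sum(grads[r][c] for r in range(N)) for c in range(N))
print(sent)       # [8, 8, ..., 8]: each rank sends N chunks of size Ψ/N ≈ Ψ
print(received)   # likewise ≈ Ψ received per rank
# (The canonical ring staggers the chunks and skips the final send-back,
#  giving (N-1)*Ψ/N per rank, which is the same Ψ up to one chunk.)
```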
In the paper "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models", the authors give a simple explanation for data parallel communication volume:
State-of-art implementation of all-reduce uses a two-step approach, where the first step is a reduce-scatter operation, which reduces different part of the data on different process. The next step is an all-gather operation where each process gathers the reduced data on all the process. The result of these two steps is an all-reduce. Both reduce-scatter and all-gather are implemented using a pipelined approach, that results in a total data movement of Ψ elements (for a data with Ψ elements) for each. Therefore, the standard DP incurs 2Ψ data movement during each training step.
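To make sure I read this right, here is a small self-contained check (plain Python, no communication library; the buffer layout is just my toy setup) that reduce-scatter followed by all-gather reproduces an all-reduce:

```python
# Toy check that all-reduce == reduce-scatter + all-gather (sum reduction).
# Plain Python lists stand in for per-rank buffers; no real communication.
N = 4                   # ranks
psi = 8                 # Ψ elements per buffer, split into N chunks of Ψ/N
data = [[rank + i for i in range(psi)] for rank in range(N)]   # data[rank]

# Reference all-reduce: elementwise sum across ranks, replicated everywhere.
allreduce = [sum(data[r][i] for r in range(N)) for i in range(psi)]

# Step 1: reduce-scatter -- rank k keeps only the reduced k-th chunk (Ψ/N elements).
chunk = psi // N
reduced = [[sum(data[r][k * chunk + j] for r in range(N)) for j in range(chunk)]
           for k in range(N)]

# Step 2: all-gather -- every rank collects the N reduced chunks into a full buffer.
gathered = [x for k in range(N) for x in reduced[k]]

assert gathered == allreduce   # two-step result matches all-reduce on every rank
```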
My question is: why does reduce-scatter (or all-gather) result in a total data movement of Ψ elements?
I'm not familiar with collective operations in a distributed environment, so here is my understanding. Suppose we have 8 ranks and data with Ψ elements, and initially each rank holds a full copy of the data. Any given rank, say Rank1, is responsible for the reduction of its corresponding data segment (Ψ/8 elements). To do this, Rank1 needs to collect 7Ψ/8 elements from the others, resulting in 7Ψ/8 data movement. Besides that, Rank1 needs to send its remaining 7Ψ/8 elements to the others for the reduction of the remaining segments, which seems to be another 7Ψ/8 of data movement. So the total data movement for Rank1 is 14Ψ/8.
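Writing my counting out in units of Ψ, just to be explicit about where 14Ψ/8 comes from (the variable names are only for illustration):

```python
# My counting above for N = 8 ranks, in units of Ψ.
from fractions import Fraction

N = 8
recv_for_own_segment = (N - 1) * Fraction(1, N)     # collect 7Ψ/8 from the others
send_for_other_segments = (N - 1) * Fraction(1, N)  # send 7Ψ/8 to the others
print(recv_for_own_segment + send_for_other_segments)  # 7/4, i.e. 14Ψ/8
```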
Which step is wrong?
If possible, please give a deep analysis of all-gather too. Thanks a lot!