forked from Azure/msccl-tools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path reducescatter_allpairs.py
40 lines (32 loc) · 1.5 KB
/
reducescatter_allpairs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import ReduceScatter
def allreduce_allpairs(gpus, protocol, instances=1):
    """Build and emit an all-pairs, in-place ReduceScatter MSCCL program.

    Despite the function name, the collective built here is an in-place
    ``ReduceScatter`` (program name ``"reducescatter_pairs"``), not an
    AllReduce. The indexing below treats each rank's input buffer as
    ``gpus`` blocks of ``gpus`` chunks; rank ``r`` reduces into block ``r``
    (chunks ``r*gpus .. r*gpus + gpus - 1``).

    The algorithm has two phases:

    1. All-pairs exchange: every rank ``r1`` copies block ``r2`` of its input
       into the ``'scratch'`` buffer of every other rank ``r2``.
    2. Local reduction: every rank folds the received scratch chunks into
       its own block of the input buffer.

    Args:
        gpus: Number of ranks; the topology is ``fully_connected(gpus)``.
        protocol: Protocol name forwarded to ``MSCCLProgram``
            (e.g. ``'Simple'``, ``'LL128'`` or ``'LL'``).
        instances: Instance count forwarded to ``MSCCLProgram`` as its fourth
            positional argument. Defaults to 1, the value previously
            hard-coded here, so existing callers are unaffected.

    The finished program is emitted with ``XML()`` and then checked with
    ``Check()``.
    """
    size = gpus
    topology = fully_connected(size)
    # In-place ReduceScatter over `size` ranks with a chunk factor of `size`.
    collective = ReduceScatter(gpus, gpus, True)
    with MSCCLProgram("reducescatter_pairs", topology, collective, instances, protocol=protocol,
                      threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True):
        # Phase 1: rank r1 sends block r2 (chunks r2*size .. r2*size + size - 1)
        # into rank r2's scratch space. Threadblocks are assigned manually:
        # the sender uses tb r2 (one per destination), and the receiver uses
        # tb r1 (one per source).
        for r1 in range(size):
            for r2 in range(size):
                if r1 == r2:
                    continue
                c = chunk(r1, Buffer.input, r2 * size, size=size)
                c.copy(r2, 'scratch', sendtb=r2, recvtb=r1)

        # Phase 2: each rank reduces the size * (size - 1) scratch chunks into
        # its own block. Scratch chunk `index` is paired with input chunk
        # r*size + index % size.
        # NOTE(review): this pairing assumes copies issued without an explicit
        # scratch index are appended in issue order — confirm against msccl.
        # The work is spread over `size` threadblocks (index % size) for
        # parallelism, i.e. 8 threadblocks only when gpus == 8.
        for r in range(size):
            for index in range(size * (size - 1)):
                c = chunk(r, Buffer.input, r * size + (index % size))
                c.reduce(chunk(r, 'scratch', index), sendtb=(index % size))
        XML()
        Check()
# Command-line entry point: generate the program for the requested GPU count
# and protocol.
parser = argparse.ArgumentParser()
parser.add_argument('num_gpus', type=int, help='number of gpus')
parser.add_argument('--protocol', type=str, default='LL', choices=['Simple', 'LL128', 'LL'], help='Protocol')
args = parser.parse_args()
# Reject non-positive rank counts up front rather than handing them to
# fully_connected() and the program builder. parser.error() prints usage and
# exits with status 2.
if args.num_gpus < 1:
    parser.error('num_gpus must be a positive integer')
allreduce_allpairs(args.num_gpus, args.protocol)