main-multigpu-shared.py
#!/usr/bin/env python
import itertools
import multiprocessing as mp
import os
import logging

import torch
import torch.distributed as dist

from common import MemoryMonitor
from dataset import TorchSharedTensorDataset

local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
logger = logging.getLogger(__name__)


def main():
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    print(f"local_rank={local_rank}, world_size={world_size}")

    monitor = MemoryMonitor()
    # Only local rank 0 is flagged as the creator of the shared dataset.
    ds = TorchSharedTensorDataset(is_rank0=(local_rank == 0), num_gbs=40)
    print(monitor.table())

    loader = torch.utils.data.DataLoader(ds,
                                         num_workers=0,
                                         shuffle=False,
                                         batch_size=1)

    # Gather the PIDs of all ranks so memory usage can be monitored across processes.
    pids = [os.getpid()]
    all_pids = [None for _ in range(world_size)]
    dist.all_gather_object(all_pids, pids)
    all_pids = list(itertools.chain.from_iterable(all_pids))
    monitor = MemoryMonitor(all_pids)

    for _ in range(100):
        for d in loader:
            d = d.cuda()
        if local_rank == 0:
            print(monitor.table())
        dist.barrier()
        logger.warning(f'{local_rank}, {d.min()}, {d.max()}')  # just make sure the data is correct

    dist.barrier()
    dist.destroy_process_group()
    print('done')


if __name__ == "__main__":
    main()
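Usage note (not part of the file): the script reads LOCAL_RANK and WORLD_SIZE from the environment, so it needs to be started by a distributed launcher that sets them, such as torchrun. A typical single-node invocation, with the process count chosen to match the number of GPUs (the value below is only an example), would be:

    torchrun --nproc_per_node=2 main-multigpu-shared.py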