"""
Mostly based on the official PyTorch tutorial
Link: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
Modified for educational purposes.
Nikolas, AI Summer
"""
import os
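# Restrict which GPUs are visible; this must happen before CUDA is initialized,
# which is why it is set before importing torch.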
gpu_list = "0,1,2,3"
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import time
from utils import setup_for_distributed, save_on_master, is_main_process
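# Helpers from a local utils.py (not shown in this file); a sketch of what they are
# assumed to do is given in the comment block at the end of this script.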
def create_data_loader_cifar10():
    # Augmentations for training; the test set only gets tensor conversion + normalization.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # Per-process batch size; the effective global batch size is batch_size * world_size.
    batch_size = 256

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform_train)
    # DistributedSampler gives each process a distinct shard of the dataset.
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=trainset, shuffle=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              sampler=train_sampler, num_workers=16, pin_memory=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    test_sampler = torch.utils.data.distributed.DistributedSampler(dataset=testset, shuffle=False)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, sampler=test_sampler, num_workers=16)
    return trainloader, testloader
def train(net, trainloader):
    print("Start training...")
    criterion = nn.CrossEntropyLoss()
    # Note: with DDP the learning rate is often scaled with the number of processes
    # (linear scaling rule); the base value from the single-GPU tutorial is kept here.
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    epochs = 1
    num_of_batches = len(trainloader)
    for epoch in range(epochs):  # loop over the dataset multiple times
        # Re-seed the DistributedSampler so every epoch uses a different,
        # but consistent-across-processes, shuffling of the data.
        trainloader.sampler.set_epoch(epoch)
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            images, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize; DDP averages the gradients
            # across processes during backward()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # accumulate statistics
            running_loss += loss.item()

        print(f'[Epoch {epoch + 1}/{epochs}] loss: {running_loss / num_of_batches:.3f}')

    print('Finished Training')
def test(net, PATH, testloader):
    # Optionally reload the checkpoint saved by the main process:
    # if is_main_process():
    #     net.load_state_dict(torch.load(PATH))
    # dist.barrier()
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.cuda(), labels.cuda()
            # calculate outputs by running images through the network
            outputs = net(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
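    # Note: with a DistributedSampler each process only evaluates its own shard of the
    # test set, so `correct` and `total` are per-rank counts. For a single global
    # accuracy one could sum them across ranks first, e.g. (sketch, not used here):
    #   counts = torch.tensor([correct, total], device='cuda')
    #   dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    #   correct, total = int(counts[0]), int(counts[1])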
    acc = 100 * correct // total
    print(f'Accuracy of the network on the 10000 test images: {acc} %')
def init_distributed():
    # Initializes the distributed backend, which takes care of synchronizing processes/GPUs
    dist_url = "env://"  # default
    # only works with torchrun / torch.distributed.launch, which set these env variables
    rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=rank)
    # this will make all .cuda() calls work properly
    torch.cuda.set_device(local_rank)
    # synchronizes all processes at this point before moving on
    dist.barrier()
    # disable printing on every process except the master (rank 0)
    setup_for_distributed(rank == 0)
if __name__ == '__main__':
    start = time.time()
    init_distributed()

    PATH = './cifar_net.pth'
    trainloader, testloader = create_data_loader_cifar10()
    # ResNet-50 without pretrained weights
    net = torchvision.models.resnet50(False).cuda()
    # Convert BatchNorm to SyncBatchNorm so batch statistics are computed across all processes.
    net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    local_rank = int(os.environ['LOCAL_RANK'])
    # Wrap the model with DDP; each process drives the single GPU given by local_rank.
    net = nn.parallel.DistributedDataParallel(net, device_ids=[local_rank])

    start_train = time.time()
    train(net, trainloader)
    end_train = time.time()

    # save the checkpoint from the main process only
    if is_main_process():
        save_on_master(net.state_dict(), PATH)
    dist.barrier()

    # test
    test(net, PATH, testloader)

    end = time.time()
    seconds = (end - start)
    seconds_train = (end_train - start_train)
    print(f"Total elapsed time: {seconds:.2f} seconds, \
Train 1 epoch {seconds_train:.2f} seconds")
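
# For reference: the helpers imported from utils.py above are not part of this file.
# A minimal sketch of what they are assumed to look like (modelled on the torchvision
# reference utilities; the actual utils.py may differ):
#
#   def setup_for_distributed(is_master):
#       """Disable print() on every process except the master."""
#       import builtins as __builtin__
#       builtin_print = __builtin__.print
#       def print(*args, **kwargs):
#           if is_master or kwargs.pop('force', False):
#               builtin_print(*args, **kwargs)
#       __builtin__.print = print
#
#   def is_main_process():
#       return dist.get_rank() == 0
#
#   def save_on_master(*args, **kwargs):
#       if is_main_process():
#           torch.save(*args, **kwargs)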